
Bayesian Optimization Modeling

The goal of this tutorial is to introduce Bayesian optimization workflows in OSS Vizier, including the underlying TensorFlow Probability (TFP) components and JAX/Flax functionality. The target audience is researchers and practitioners already well-versed in Bayesian optimization, who want to define and train their own Gaussian Process surrogate models for Bayesian optimization in OSS Vizier.

Additional resources for TFP

If you’re new to TFP, a good place to start is A tour of TensorFlow Probability. TFP began as a TensorFlow-only library, but now has a JAX backend that is entirely independent of TensorFlow (such that “Tensor-Friendly Probability” might be a better backronym). This Colab uses TFP’s JAX backend (see the “Imports” cell for how to import it).

Additional resources for Flax

OSS Vizier’s Bayesian Optimization models are defined as Flax modules.


import chex
import jax
from jax import numpy as jnp, random, tree_util
import numpy as np
import optax
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from tensorflow_probability.substrates import jax as tfp
from typing import Any

# Vizier models can freely access modules from vizier._src
from vizier._src.benchmarks.experimenters.synthetic import bbob
from vizier.jax import optimizers
from vizier._src.jax import stochastic_process_model as spm

tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

Defining a GP surrogate model and hyperparameters

To write a GP surrogate model, first write a coroutine that yields parameter specifications (ModelParameter) and returns a GP distribution. Downstream, the parameter specifications are used to define Flax module parameters. The inputs to the coroutine function represent the index points of the GP (in the remainder of this Colab, we refer to “inputs” and “index points” interchangeably).

The rationale for the coroutine design is that it lets us automate the application of the parameter constraint and initialization functions (for example, those corresponding to hyperpriors). It also lets us specify the model parameters together with how their values are used to instantiate a GP.
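To see why a coroutine is convenient, here is a minimal, framework-free sketch of the pattern that assumes only JAX. ParamSpec, toy_coroutine, run_init and run_apply are hypothetical names chosen for illustration; they are not Vizier's API. The driver sends None to start the generator, receives a parameter specification at each yield, and sends a value back. The generator's return value, which arrives on StopIteration, is the model. In Vizier, StochasticProcessModel plays the driver role inside a Flax module.

```python
from typing import Callable, Dict

import jax
import jax.numpy as jnp


class ParamSpec:
  """Hypothetical parameter specification (a stand-in, not spm.ModelParameter)."""

  def __init__(self, name: str, init_fn: Callable[[jax.Array], jax.Array]):
    self.name = name
    self.init_fn = init_fn


def toy_coroutine(inputs):
  # Each `yield` hands a spec to the driver and receives the parameter's value.
  scale = yield ParamSpec('scale', lambda key: jnp.exp(jax.random.normal(key)))
  shift = yield ParamSpec('shift', lambda key: jax.random.normal(key))
  # The return value is the "model" built from the parameter values.
  return scale * inputs + shift


def run_init(coroutine, inputs, key) -> Dict[str, jax.Array]:
  """Runs the coroutine once, initializing each parameter from its init_fn."""
  params = {}
  gen = coroutine(inputs)
  value = None  # A freshly created generator must first be sent None.
  try:
    while True:
      spec = gen.send(value)
      key, subkey = jax.random.split(key)
      value = spec.init_fn(subkey)
      params[spec.name] = value
  except StopIteration:
    return params


def run_apply(coroutine, params, inputs):
  """Runs the coroutine again, feeding it existing parameter values."""
  gen = coroutine(inputs)
  value = None
  try:
    while True:
      spec = gen.send(value)
      value = params[spec.name]
  except StopIteration as stop:
    return stop.value


demo_inputs = jnp.arange(3.0)
coroutine_params = run_init(toy_coroutine, demo_inputs, jax.random.PRNGKey(0))
demo_prediction = run_apply(
    toy_coroutine, {'scale': 2.0, 'shift': 1.0}, demo_inputs)  # [1., 3., 5.]
```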

Coroutine example

The following cell shows a coroutine defining a GP with a squared exponential kernel and two parameters: the length scale of the kernel and the observation noise variance of the GP.

def simple_gp_coroutine(inputs: chex.Array=None):
  length_scale = yield spm.ModelParameter.from_prior(
      tfd.Gamma(1., 1., name='length_scale'))
  amplitude = 2.  # Non-trainable parameters may be defined as constants.
  kernel = tfpk.ExponentiatedQuadratic(
      amplitude=amplitude, length_scale=length_scale)
  observation_noise_variance = yield spm.ModelParameter(
      init_fn=lambda x: jnp.exp(random.normal(x)),
      constraint=spm.Constraint(bounds=(0.0, 100.0), bijector=tfb.Softplus()),
      regularizer=lambda x: x**2,
      name='observation_noise_variance')
  return tfd.GaussianProcess(
      kernel,
      index_points=inputs,
      observation_noise_variance=observation_noise_variance)


ModelParameter may be used to define hyperpriors.

Parameter specifications from priors

The length scale parameter has a Gamma prior. This is equivalent to defining a ModelParameter with a regularizer that computes the Gamma negative log likelihood and an initialization function that samples from the Gamma distribution. Because no constraint was specified, a default one is assigned: the “default event space bijector” of the TFP distribution (each TFP distribution has a constraining bijector that maps the real line to the support of the distribution).
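To make the equivalence concrete, here is a plain-JAX sketch of the three pieces that from_prior derives from tfd.Gamma(1., 1.). It uses jax.random.gamma and jax.scipy.stats.gamma in place of TFP. The names gamma_init_fn, gamma_regularizer and to_support are hypothetical, and using softplus as the constraining map is an assumption about the Gamma distribution's default event space bijector.

```python
import jax
import jax.numpy as jnp
from jax.scipy import stats

# Gamma(concentration=1, rate=1), as in tfd.Gamma(1., 1.) above.
concentration, rate = 1.0, 1.0


def gamma_init_fn(key):
  # jax.random.gamma draws from Gamma(concentration, rate=1); rescale by the rate.
  return jax.random.gamma(key, concentration) / rate


def gamma_regularizer(x):
  # Negative log density of the prior; the SciPy-style `scale` is 1 / rate.
  return -stats.gamma.logpdf(x, concentration, scale=1.0 / rate)


def to_support(u):
  # Assumed constraint: softplus maps the real line onto (0, inf).
  return jax.nn.softplus(u)


initial_value = gamma_init_fn(jax.random.PRNGKey(0))  # A positive sample.
penalty = gamma_regularizer(2.0)  # For Gamma(1, 1), -log p(x) = x, so this is 2.0.
constrained = to_support(jnp.array([-3.0, 0.0, 3.0]))  # All entries are positive.
```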

Specifying parameters explicitly

Observation noise variance, which is passed to the Gaussian process and represents the scalar variance of zero-mean Gaussian noise in the observed labels, is not given a tfd.Distribution prior. Instead, its initialization, constraining, and regularization functions are defined individually. Note that the initialization function returns values in the constrained space.


ModelParameter lets you define constraints on a model parameter through a Constraint object, which is initialized with a bounds tuple and a bijector.

Although the constraints are defined as part of the ModelParameter, the Flax model itself does not use them; it expects to receive parameter values that are already in the constrained space. It is therefore the responsibility of the user or optimizer to pass GP parameter values that are already in the constrained space.
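Since the model only ever sees constrained values, an unconstrained optimizer has to do the mapping itself. Below is a minimal plain-JAX sketch: gradient descent runs on an unconstrained variable u, and softplus (mirroring tfb.Softplus in the examples above) maps u to a positive noise variance before it reaches the model. model_loss is a hypothetical stand-in for a model loss, and its minimum at 0.5 is arbitrary.

```python
import jax
import jax.numpy as jnp


def inverse_softplus(y):
  # log(exp(y) - 1), written as y + log(1 - exp(-y)) for numerical stability.
  return y + jnp.log(-jnp.expm1(-y))


def model_loss(noise_variance):
  # Hypothetical model loss that expects a constrained (positive) value.
  return (jnp.log(noise_variance) - jnp.log(0.5)) ** 2


def unconstrained_loss(u):
  # The caller maps the optimizer's unconstrained variable into the constrained space.
  return model_loss(jax.nn.softplus(u))


grad_fn = jax.jit(jax.grad(unconstrained_loss))
u = inverse_softplus(jnp.array(2.0))  # Start at the constrained value 2.0.
for _ in range(200):
  u = u - 0.5 * grad_fn(u)
noise_variance = jax.nn.softplus(u)  # Positive for any real u.
```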

Exercise: Write a GP model

Write an ARD Gaussian Process model with three parameters: signal_variance, length_scale, and observation_noise_variance. (This is a slightly simplified version of the Vizier GP.)

  • signal_variance and observation_noise_variance are both:

    • regularized by the function \(f(x) = 0.01\log(x)^2\)

    • bounded to be positive.

  • signal_variance parameterizes a Matern 5/2 kernel, where the amplitude of the kernel is the square root of signal_variance. Use tfpk.MaternFiveHalves.

  • length_scale has a \(LogNormal(0, 1)\) prior for each dimension. Assume the inputs have data_dimensionality dimensions (the solution below sets this to 2), and use tfd.Sample to build a distribution over a data_dimensionality-dimensional vector of IID LogNormal variables. (Note that the length_scale parameter is a vector; all other parameters are scalars.)

  • In TFP, ARD kernels are implemented with tfpk.FeatureScaled, with scale_diag representing the length scale along each dimension.
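For intuition about what the ARD kernel in this exercise computes, here is a plain-JAX sketch intended to mirror tfpk.FeatureScaled wrapping tfpk.MaternFiveHalves: each input dimension is divided by its length scale, and the Matérn 5/2 formula \(k(r) = \sigma^2 (1 + \sqrt{5} r + \tfrac{5}{3} r^2) e^{-\sqrt{5} r}\) is applied to the scaled distance, with \(\sigma^2\) = signal_variance (so the amplitude is its square root). The function names are hypothetical; this is not a replacement for the TFP kernels used in the solution.

```python
import jax
import jax.numpy as jnp


def matern52_ard(x, y, signal_variance, length_scale):
  """Matern 5/2 kernel with per-dimension length scales, for points of shape [d]."""
  # Scale each dimension by its own length scale, as FeatureScaled does.
  r = jnp.sqrt(jnp.sum(((x - y) / length_scale) ** 2))
  sqrt5_r = jnp.sqrt(5.0) * r
  # amplitude = sqrt(signal_variance), so amplitude ** 2 = signal_variance.
  return signal_variance * (1.0 + sqrt5_r + sqrt5_r ** 2 / 3.0) * jnp.exp(-sqrt5_r)


def kernel_matrix(xs, ys, signal_variance, length_scale):
  """Pairwise kernel values between rows of xs [n, d] and ys [m, d]."""
  row = lambda x: jax.vmap(
      lambda y: matern52_ard(x, y, signal_variance, length_scale))(ys)
  return jax.vmap(row)(xs)


def variance_regularizer(x):
  # f(x) = 0.01 * log(x) ** 2, the penalty from the exercise.
  return 0.01 * jnp.log(x) ** 2


points = jax.random.normal(jax.random.PRNGKey(0), (5, 4))
length_scale = jnp.array([0.5, 1.0, 2.0, 4.0])
K = kernel_matrix(points, points, signal_variance=3.0, length_scale=length_scale)
```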

def vizier_gp_coroutine(inputs: chex.Array=None):
  ...  # Your code here.

data_dimensionality = 2

def vizier_gp_coroutine(inputs: chex.Array=None):
  """A coroutine that follows the `ModelCoroutine` protocol."""
  signal_variance = yield spm.ModelParameter(
      init_fn=lambda x: tfb.Softplus()(random.normal(x)),
      constraint=spm.Constraint(bounds=(0.0, 100.0), bijector=tfb.Softplus()),
      regularizer=lambda x: 0.01 * jnp.log(x)**2,
      name='signal_variance')
  length_scale = yield spm.ModelParameter.from_prior(
      tfd.Sample(
          tfd.LogNormal(loc=0., scale=1.),
          sample_shape=[data_dimensionality],
          name='length_scale'),
      constraint=spm.Constraint(bounds=(0.0, None)))
  kernel = tfpk.MaternFiveHalves(
      amplitude=jnp.sqrt(signal_variance), validate_args=True)
  kernel = tfpk.FeatureScaled(
      kernel, scale_diag=length_scale, validate_args=True)
  observation_noise_variance = yield spm.ModelParameter(
      init_fn=lambda x: jnp.exp(random.normal(x)),
      constraint=spm.Constraint(bounds=(0.0, 100.0), bijector=tfb.Softplus()),
      regularizer=lambda x: 0.01 * jnp.log(x)**2,
      name='observation_noise_variance')
  return tfd.GaussianProcess(
      kernel,
      index_points=inputs,
      observation_noise_variance=observation_noise_variance)

To build a GP Flax module, instantiate a StochasticProcessModel with a GP coroutine as shown below. The module runs the coroutine in the setup and __call__ methods to initialize the parameters and then instantiate the GP object with the given parameters.

Recall that Flax modules have two primary methods: init, which initializes parameters, and apply, which computes the model’s forward pass given a set of parameters and input data.

model = spm.StochasticProcessModel(coroutine=vizier_gp_coroutine)

# Sample some fake data.
# Assume we have `num_points` observations, each with `data_dimensionality` features.
num_points = 12

# Sample a set of index points.
index_points = np.random.normal(
    size=[num_points, data_dimensionality]).astype(np.float32)

# Sample function values observed at the index points
observations = np.random.normal(size=[num_points]).astype(np.float32)

# Call the Flax module's `init` method to obtain initial parameter values.
init_params = model.init(random.PRNGKey(0), index_points)

We can inspect the initial parameter values of the Flax model and check that they match the ModelParameter definitions in our coroutine.


To instantiate a GP with a set of parameters and index points, use the Flax module’s apply method. apply also returns the regularization losses for the parameters, in mutables. The regularization losses are treated as mutable state because they are recomputed internally with each forward pass of the model. For more on mutable state in Flax, see this tutorial.

gp, mutables = model.apply(
    {'params': init_params['params']}, index_points, mutable=True)
assert isinstance(gp, tfd.GaussianProcess)

Optimizing hyperparameters

Exercise: Loss function

Write down a loss function that takes a parameters dict and returns the loss value, using model.apply. The function will close over the observed data.

The loss should be the sum of the GP negative log likelihood and the regularization losses. The regularization loss values are computed when the module is called, using the ModelParameter regularization functions. They are stored in a mutable variable collection called "losses", using the Flax method sow.
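For reference, the GP negative log likelihood of observations \(y\) under a zero-mean GP with kernel matrix \(K\) and noise variance \(\sigma^2\) is \(\tfrac{1}{2} y^\top (K + \sigma^2 I)^{-1} y + \tfrac{1}{2}\log\lvert K + \sigma^2 I\rvert + \tfrac{n}{2}\log 2\pi\). The sketch below computes it in plain JAX from a Cholesky factor and adds regularization terms summed with jax.tree_util.tree_reduce(jnp.add, ...). The nested dict of tuples is an assumed stand-in for the sown "losses" collection (Flax's sow stores values in tuples by default); the toy kernel matrix and loss values are made up.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def gp_negative_log_likelihood(k_matrix, observations, noise_variance):
  """-log N(observations | 0, K + noise_variance * I), via a Cholesky factor."""
  n = observations.shape[0]
  chol = jnp.linalg.cholesky(k_matrix + noise_variance * jnp.eye(n))
  alpha = cho_solve((chol, True), observations)  # (K + noise_variance * I)^-1 y
  return (0.5 * observations @ alpha
          + jnp.sum(jnp.log(jnp.diag(chol)))  # Half the log-determinant.
          + 0.5 * n * jnp.log(2.0 * jnp.pi))


# Hypothetical regularization losses, nested the way a sown collection may be.
toy_losses = {'length_scale': (jnp.array(0.2),),
              'signal_variance': (jnp.array(0.1),)}
total_regularization = jax.tree_util.tree_reduce(jnp.add, toy_losses)

toy_x = jnp.linspace(0.0, 1.0, 5)
toy_K = jnp.exp(-0.5 * (toy_x[:, None] - toy_x[None, :]) ** 2 / 0.3 ** 2)
toy_y = jnp.sin(3.0 * toy_x)
toy_loss = gp_negative_log_likelihood(toy_K, toy_y, 0.1) + total_regularization
```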

def loss_fn(params):
  loss = ...  # Your code here.
  return loss, {}  # Return an empty dict as auxiliary state.


def loss_fn(params):
  gp, mutables = model.apply({'params': params},
                             index_points,
                             mutable=True)
  loss = (-gp.log_prob(observations) +
          jax.tree_util.tree_reduce(jnp.add, mutables['losses']))  # add the regularization losses.
  return loss, {}

The gradients of the loss have the same structure as the params dict.

grads = jax.grad(loss_fn, has_aux=True)(init_params['params'])[0]

We can use jax.tree_util to take a step along the gradient (though in practice, with Optax, we can use update and apply_updates to update the parameters at each train step).
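Repeating such steps gives a bare-bones training loop. The sketch below is plain JAX with a hypothetical toy loss over a dict of parameters. It uses jax.value_and_grad with has_aux=True, which, for a function that returns (loss, aux) like loss_fn above, returns ((loss, aux), grads). The update is the same tree_map pattern; with Optax, optax.sgd followed by optax.apply_updates would perform the same plain gradient step.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter dict shaped like a Flax `params` collection.
toy_params = {'length_scale': jnp.array([1.0, 2.0]), 'noise': jnp.array(0.5)}
toy_learning_rate = 0.1


def toy_loss_fn(p):
  # Returns (loss, aux) like `loss_fn` above; the aux dict is empty.
  loss = jnp.sum((p['length_scale'] - 3.0) ** 2) + (p['noise'] - 0.1) ** 2
  return loss, {}


@jax.jit
def train_step(p):
  # With has_aux=True, value_and_grad returns ((loss, aux), grads).
  (loss, _), grads = jax.value_and_grad(toy_loss_fn, has_aux=True)(p)
  new_p = jax.tree_util.tree_map(lambda x, g: x - toy_learning_rate * g, p, grads)
  return new_p, loss


for _ in range(100):
  toy_params, toy_loss_value = train_step(toy_params)
```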

learning_rate = 1e-3
updated_params = jax.tree_util.tree_map(
    lambda p, g: p - learning_rate * g,
    init_params['params'], grads)

Optimize hyperparameters with Vizier optimizers

Flax modules are often optimized using Optax, which requires the developer to write a routine that initializes parameter values and then repeatedly computes the loss gradients and updates the parameter values accordingly.

Vizier Optimizers is a library of optimizers that automates finding optimal Flax parameter values and wraps optimizers from libraries such as Optax and JAXopt in a common interface. To use a Vizier Optimizer, specify the following:

  • setup: a function that generates the initial parameter values.

  • loss_fn: a function used to compute the loss value and its gradients. For example, the loss function of a GP model is the negative log marginal likelihood plus the parameter regularization losses.

  • rng: a PRNGKey that controls pseudo-randomness.

  • constraints: bounds on the parameters (optional).
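To make this interface concrete, here is a toy plain-JAX sketch of an optimizer that takes a setup function, a loss_fn returning (loss, aux), and an rng, runs several random restarts, and keeps the best result, in the spirit of the random_restarts setting used below. random_restart_optimize and its arguments are hypothetical; it uses plain gradient descent rather than L-BFGS-B and ignores constraints, so it is not how Vizier implements its optimizers.

```python
from typing import Any, Callable, Dict, Tuple

import jax
import jax.numpy as jnp

Params = Dict[str, jax.Array]


def random_restart_optimize(
    setup: Callable[[jax.Array], Params],
    loss_fn: Callable[[Params], Tuple[jax.Array, Any]],
    rng: jax.Array,
    num_restarts: int = 20,
    num_steps: int = 200,
    learning_rate: float = 0.05,
) -> Tuple[Params, jax.Array]:
  """Hypothetical sketch: gradient descent from several random initializations."""
  value_and_grad = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))
  best_params, best_loss = None, jnp.inf
  for key in jax.random.split(rng, num_restarts):
    params = setup(key)  # Fresh initial values for each restart.
    for _ in range(num_steps):
      _, grads = value_and_grad(params)
      params = jax.tree_util.tree_map(
          lambda p, g: p - learning_rate * g, params, grads)
    final_loss, _ = loss_fn(params)
    if final_loss < best_loss:  # Keep the restart with the lowest loss.
      best_params, best_loss = params, final_loss
  return best_params, best_loss


def toy_setup(key):
  return {'x': jax.random.uniform(key, minval=-2.0, maxval=2.0)}


def toy_double_well(params):
  # Local minima near x = -1 and x = +1; the tilt makes the x < 0 one lower.
  x = params['x']
  return (x ** 2 - 1.0) ** 2 + 0.3 * x, {}


rr_params, rr_loss = random_restart_optimize(
    toy_setup, toy_double_well, jax.random.PRNGKey(0))
```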

Below we use the Vizier JaxoptLbfgsB optimizer to run a constrained L-BFGS-B algorithm. Unconstrained optimizers (e.g., Adam) use a bijector to map between the unconstrained space, where the search is performed, and the constrained space, where the loss function is evaluated. In contrast, constrained optimizers (e.g., L-BFGS-B) use the constraint bounds directly in the search.

To pass the constraint bounds to the JaxoptLbfgsB optimizer, we use the spm.get_constraints function, which traverses the parameters defined in the module coroutine and extracts their bounds.
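The difference between the two strategies shows up on a one-dimensional toy problem whose unconstrained minimum lies outside the bounds. This sketch is an illustration only: it uses SciPy's L-BFGS-B (scipy.optimize.minimize with method='L-BFGS-B') as the constrained optimizer, and plain gradient descent through a scaled sigmoid as the unconstrained optimizer with a bijector. The objective and bounds are made up.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

lower, upper = 0.0, 2.0  # Hypothetical bounds for a single parameter.


def objective(x):
  # The unconstrained minimum, x = 3, lies outside [lower, upper].
  return (x - 3.0) ** 2


# Constrained search: L-BFGS-B keeps every iterate inside the bounds.
result = minimize(lambda v: float(objective(v[0])), x0=np.array([1.0]),
                  method='L-BFGS-B', bounds=[(lower, upper)])
x_constrained = result.x[0]  # Lands on the upper bound.


def to_constrained(u):
  # Bijector from the real line onto the open interval (lower, upper).
  return lower + (upper - lower) * jax.nn.sigmoid(u)


# Unconstrained search: gradient descent on u; the loss is evaluated at to_constrained(u).
grad_u = jax.jit(jax.grad(lambda u: objective(to_constrained(u))))
u = jnp.array(0.0)
for _ in range(500):
  u = u - 0.1 * grad_u(u)
x_via_bijector = float(to_constrained(u))  # Approaches the bound but never reaches it.
```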

setup = lambda rng: model.init(rng, index_points)['params']
model_optimizer = optimizers.JaxoptLbfgsB(
    optimizers.LbfgsBOptions(random_restarts=20, best_n=None))
constraints = spm.get_constraints(model)
optimal_params, _ = model_optimizer(setup, loss_fn, random.PRNGKey(0),
                                    constraints=constraints)

Predict on new inputs, conditional on observations

To compute the posterior predictive GP on unseen points, conditioned on observed data, use the precompute_predictive and posterior_predictive methods of the Flax module. precompute_predictive must be called first; it runs and stores the Cholesky decomposition of the kernel matrix for the observed data. posterior_predictive then returns a posterior predictive GP at new index points, avoiding recomputation of the Cholesky.
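The saving comes from reusing one factorization. Here is a plain-JAX sketch of the idea, assuming a zero-mean GP with a hypothetical RBF kernel (rbf, precompute and predict are illustrative names, not the module's methods). precompute factors \(K + \sigma^2 I = L L^\top\) once and solves for \(\alpha = (K + \sigma^2 I)^{-1} y\); predict then computes the posterior mean \(K_*^\top \alpha\) and the latent variance \(k_{**} - v^\top v\), with \(v = L^{-1} K_*\), using only triangular solves.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve, solve_triangular


def rbf(xs, ys, length_scale=0.5):
  # Squared exponential kernel between rows of xs [n, d] and ys [m, d].
  sq_dist = jnp.sum((xs[:, None, :] - ys[None, :, :]) ** 2, axis=-1)
  return jnp.exp(-0.5 * sq_dist / length_scale ** 2)


def precompute(x_obs, y_obs, noise_variance):
  # Done once per set of observations: the O(n^3) Cholesky and one solve.
  n = x_obs.shape[0]
  chol = jnp.linalg.cholesky(rbf(x_obs, x_obs) + noise_variance * jnp.eye(n))
  alpha = cho_solve((chol, True), y_obs)
  return chol, alpha


def predict(cache, x_obs, x_new):
  # Reuses the cached factor; no new Cholesky decomposition is needed.
  chol, alpha = cache
  k_cross = rbf(x_obs, x_new)  # [n, m]
  mean = k_cross.T @ alpha
  v = solve_triangular(chol, k_cross, lower=True)
  variance = jnp.diag(rbf(x_new, x_new)) - jnp.sum(v ** 2, axis=0)
  return mean, variance


x_obs = jnp.linspace(0.0, 1.0, 6)[:, None]
y_obs = jnp.sin(3.0 * x_obs[:, 0])
cache = precompute(x_obs, y_obs, noise_variance=1e-2)
# Training inputs plus one far-away point.
x_new = jnp.concatenate([x_obs, jnp.array([[3.0]])])
mean, variance = predict(cache, x_obs, x_new)
```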

# Precompute the Cholesky.
_, pp_state = model.apply(
    {'params': optimal_params},
    index_points,
    observations,
    method=model.precompute_predictive,
    mutable=True)

# Predict on new index points.
predictive_index_points = np.random.normal(
    size=[5, data_dimensionality]).astype(np.float32)
pp_dist = model.apply(
    {'params': optimal_params, **pp_state},
    predictive_index_points,
    method=model.posterior_predictive)

# `posterior_predictive` returns a TFP distribution, whose mean, variance, and 
# samples we can use to compute an acquisition function.
assert pp_dist.mean().shape == (5,)

Optimize a black-box function

For an end-to-end example of Bayesian optimization, we’ll use the GP surrogate model defined above along with an Upper Confidence Bound acquisition function to try to find the maximum of the Weierstrass function. First, visualize the function surface.

# Use the Weierstrass function from Vizier's Black-Box Optimization Benchmarking
# (BBOB) library.
bb_fun = bbob.Weierstrass

# Sample a set of index points in a 2D space.
num_points = 6
max_x = np.array(2.).astype(np.float32)
index_points = random.uniform(
    random.PRNGKey(1),
    shape=[num_points, data_dimensionality], dtype=jnp.float32) * max_x

# Compute function values observed at the index points.
observations = np.apply_along_axis(
    bb_fun, axis=1, arr=index_points).astype(np.float32)

# Define a grid of points in the function domain for plotting.
n_grid = 100
x = y = np.linspace(0, max_x, n_grid, dtype=np.float32)
X, Y = np.meshgrid(x, y)
x_grid = np.vstack([X.ravel(), Y.ravel()]).T
y_grid = np.apply_along_axis(bb_fun, axis=1, arr=x_grid)
Z = y_grid.reshape(X.shape)

# Plot the black-box function values.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, alpha=0.5)
ax.scatter(index_points[:, 0], index_points[:, 1], observations, color='r', 
           label='Initial observed data')
plt.title('Black-box (Weierstrass) function values and observed data')

Next, run a few iterations of Bayesian optimization to maximize the black-box function given the observed data. A single iteration consists of the following steps:

  1. Optimize the GP hyperparameters.

  2. Find a suggestion that maximizes an Upper Confidence Bound acquisition function. In this example, we use grid search for the optimization.

  3. Evaluate the black-box function on the suggestion and append it to the set of observed data.

(Note that this simple Bayesopt algorithm is for educational purposes and that we’d expect Vizier’s GP bandit algorithm to give better results.)

num_bayesopt_iter = 5

# At each iteration, redefine the loss function given the current observed data.
def build_loss_fn(index_points, observations):
  def loss_fn(params):
    gp, mutables = model.apply({'params': params},
                               index_points,
                               mutable=True)
    loss = (-gp.log_prob(observations) +
            jax.tree_util.tree_reduce(jnp.add, mutables['losses']))  # add the regularization losses.
    return loss, {}
  return loss_fn

for i in range(num_bayesopt_iter):
  # Update the loss function to condition on all observed data.
  loss_fn = build_loss_fn(index_points, observations)

  # Optimize the GP hyperparameters.
  optimal_params, _ = model_optimizer(setup, loss_fn, random.PRNGKey(0),
                                      constraints=constraints)

  # Compute the posterior predictive distribution over a grid of points in the
  # function domain (x_grid).
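  _, pp_state = model.apply(
      {'params': optimal_params},
      index_points,
      observations,
      method=model.precompute_predictive,
      mutable=True)
  pp_dist = model.apply(
      {'params': optimal_params, **pp_state},
      x_grid,
      method=model.posterior_predictive)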

  # Compute the acquisition function value at each point in the grid.
  pred_mean = pp_dist.mean()
  ucb_vec = pred_mean + 2. * pp_dist.stddev()

  # Find the grid point with the highest acquisition function value.
  ind = np.argmax(ucb_vec)

  # Evaluate the black box function at the selected point. 
  f_val = bb_fun(x_grid[ind])

  # Visualize the surrogate model mean and acquisition function surface at this
  # iteration.
  fig = plt.figure(figsize=(10, 10))
  ax = fig.add_subplot(121, projection='3d')
  W = pred_mean.reshape(X.shape)
  ax.plot_surface(X, Y, W, alpha=0.5)
  ax.scatter(index_points[:, 0], index_points[:, 1], observations, color='r',
             label='Observed data')
  ax.set_title('Observed data and posterior predictive GP mean')
  ax = fig.add_subplot(122, projection='3d')
  ucb = ucb_vec.reshape(X.shape)
  ax.plot_surface(X, Y, ucb, alpha=0.5)
  ax.scatter(*x_grid[ind], ucb_vec[ind], color='r', label='New suggestion')
  ax.set_title('Acquisition function')

  # Append the new suggestion and function value to the set of observations.
  index_points = np.concatenate([index_points, x_grid[ind][np.newaxis]])
  observations = np.concatenate(
      [observations, np.array(f_val).astype(np.float32)[np.newaxis]])
  print(f'Iteration: {i}')
  print(f'Acquisition function value at suggestion: {ucb_vec[ind]}')
  print(f'Black-box function value at suggestion: {f_val}')

Deeper dive on selected topics in TFP

As shown above, the Flax GP model makes use of a number of TFP components:

  • Distributions specify parameter priors (e.g. tfd.Gamma). The stochastic process model itself is also a TFP distribution, tfd.GaussianProcess.

  • Bijectors (e.g. tfb.Softplus) are used to constrain parameters for optimization, and may also be used for input/output warping.

  • PSD kernels (e.g. tfpk.ExponentiatedQuadratic) specify the kernel function for the stochastic process.

The next sections of this Colab introduce these and how they’re used in Bayesopt modeling.

tfd.GaussianProcess and friends

The stochastic process Flax modules return a TFP distribution in the Gaussian Process family (an instance of tfd.GaussianProcess, tfd.StudentTProcess, or tfp.experimental.distributions.MultiTaskGaussianProcess).

This Colab doesn’t go into detail on TFP distributions, since advanced usage and implementation of distributions is rarely required for Bayesopt modeling with Vizier. For an overview of TFP distributions, see TensorFlow Distributions: A Gentle Introduction.

Some of the methods of the Gaussian Process distribution are demonstrated below. Gaussian Process Regression in TFP is also worth reading.

# Build a kernel function (see "PSD kernels" section below) and GP.
num_points = 6
index_points = random.uniform(
    random.PRNGKey(2), shape=[num_points, data_dimensionality], dtype=jnp.float32)
observations = random.uniform(
    random.PRNGKey(3), shape=[num_points], dtype=jnp.float32)

kernel = tfpk.MaternFiveHalves(
    validate_args=True)  # Run additional runtime checks; possibly expensive.
observation_noise_variance = jnp.ones([], dtype=observations.dtype)
gp = tfd.GaussianProcess(
    kernel,
    index_points=index_points,
    observation_noise_variance=observation_noise_variance,
    cholesky_fn=lambda x: tfp.experimental.distributions.marginal_fns.retrying_cholesky(x)[0],  # See commentary below.
)

# Take 4 samples from the GP at the index points.
s = gp.sample(4, seed=random.PRNGKey(0))
assert s.shape == (4, num_points)

# Compute the log likelihood of the sampled values.
lp = gp.log_prob(s)
assert lp.shape == (4,)

# GPs can also be instantiated without index points, in which case the index
# points must be passed to method calls.
gp_no_index_pts = tfd.GaussianProcess(
    kernel, observation_noise_variance=observation_noise_variance)
s = gp_no_index_pts.sample(index_points=index_points, seed=random.PRNGKey(0))

# Predictive GPs conditioned on observations can be built with
# `GaussianProcess.posterior_predictive`. The Flax module's
# `precompute_predictive` and `posterior_predictive` methods call this GP method.
gprm = gp.posterior_predictive(
    observations=observations,
    predictive_index_points=predictive_index_points)

# `gprm` is an instance of `tfd.GaussianProcessRegressionModel`. This class can
# also be instantiated directly (as a side note -- this isn't necessary for
# modeling with Vizier).
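same_gprm = tfd.GaussianProcessRegressionModel(
    kernel,
    index_points=predictive_index_points,
    observation_index_points=index_points,
    observations=observations,
    observation_noise_variance=observation_noise_variance)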

Aside from the kernel, index points, and noise variance, note the cholesky_fn arg to the GaussianProcess constructor. cholesky_fn is a callable that takes a matrix and returns a Cholesky-like lower triangular factor. The default function adds a jitter of 1e-6 to the diagonal and then calls jnp.linalg.cholesky. An alternative, used in the Vizier GP, is tfp.experimental.distributions.marginal_fns.retrying_cholesky, which adds progressively larger jitter until the Cholesky decomposition succeeds.
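To illustrate the retry idea, here is a hypothetical plain-JAX sketch; retrying_cholesky_sketch is not TFP's function, and its jitter schedule is an arbitrary choice. It first tries the factorization without jitter, then with jitter growing by factors of ten, and stops at the first finite result. It relies on jnp.linalg.cholesky reporting a failed factorization as NaN entries rather than raising, and because it uses a Python loop with a data-dependent check it cannot run under jax.jit.

```python
import jax.numpy as jnp


def retrying_cholesky_sketch(matrix, initial_jitter=1e-6, max_retries=6):
  """Hypothetical helper: adds growing diagonal jitter until Cholesky succeeds.

  Returns the lower-triangular factor and the jitter that was used.
  """
  eye = jnp.eye(matrix.shape[-1], dtype=matrix.dtype)
  # First try without jitter, then 1e-6, 1e-5, ... (with the defaults).
  jitters = [0.0] + [initial_jitter * 10.0 ** i for i in range(max_retries)]
  for jitter in jitters:
    chol = jnp.linalg.cholesky(matrix + jitter * eye)
    # A failed factorization shows up as NaN entries rather than an exception.
    if bool(jnp.all(jnp.isfinite(chol))):
      return chol, jitter
  raise ValueError('Cholesky failed for every jitter level tried.')


# A rank-one PSD matrix, which is singular and may need jitter to factor.
singular = jnp.ones((3, 3))
chol, used_jitter = retrying_cholesky_sketch(singular)
```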

A side note on batch shape in TFP

tl;dr: Don’t worry about batch shape.

TFP objects have a notion of batch shape, which is useful for vectorized computations. For more on this, see Understanding TensorFlow Distributions Shapes.

For the purposes of Bayesopt in Vizier, JAX’s vmap means that our TFP objects can have a single parameterization with empty batch shape. For example, the following loss function takes a scalar amplitude, and the kernel and GP both have empty batch shape.

def loss_fn(amplitude):  # `amplitude` is a scalar.
  k = tfpk.ExponentiatedQuadratic(amplitude=amplitude)  # batch shape []
  gp = tfd.GaussianProcess(k, index_points=index_points)   # batch shape []
  return -gp.log_prob(observations)

initial_amplitude = np.random.uniform(size=[50])

losses = jax.vmap(loss_fn)(initial_amplitude)
assert losses.shape == (50,)

We could also vectorize the loss computation by using a batched GP. In this simple case, the code is identical except that vmap is removed. Now, the kernel and GP represent a “batch” of kernels and GPs, each with different parameter values. Working with batch shape requires additional accounting on the part of the user to ensure that parameter shapes broadcast correctly, the correct dimensions are reduced over, etc. For Vizier’s use case, we find it simpler to rely on vmap.

def loss_fn(amplitude):  # `amplitude` has shape [50].
  k = tfpk.ExponentiatedQuadratic(amplitude=amplitude)  # batch shape [50]
  gp = tfd.GaussianProcess(k, index_points=index_points)   # batch shape [50]
  return -gp.log_prob(observations)

initial_amplitude = np.random.uniform(size=[50])

# No vmap.
losses = loss_fn(initial_amplitude)
assert losses.shape == (50,)