Bayesian Optimization Modeling
The goal of this tutorial is to introduce Bayesian optimization workflows in OSS Vizier, including the underlying TensorFlow Probability (TFP) components and JAX/Flax functionality. The target audience is researchers and practitioners already well-versed in Bayesian optimization, who want to define and train their own Gaussian Process surrogate models for Bayesian optimization in OSS Vizier.
Additional resources for TFP
If you’re new to TFP, a good place to start is A tour of TensorFlow Probability. TFP began as a TensorFlow-only library, but now has a JAX backend that is entirely independent of TensorFlow (such that “Tensor-Friendly Probability” might be a better backronym). This Colab uses TFP’s JAX backend (see the “Imports” cell for how to import it).
Additional resources for Flax
OSS Vizier’s Bayesian Optimization models are defined as Flax modules.
Imports
import chex
import jax
from jax import numpy as jnp, random, tree_util
import numpy as np
import optax
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from tensorflow_probability.substrates import jax as tfp
from typing import Any
# Vizier models can freely access modules from vizier._src
from vizier._src.benchmarks.experimenters.synthetic import bbob
from vizier.jax import optimizers
from vizier._src.jax import stochastic_process_model as spm
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels
Defining a GP surrogate model and hyperparameters
To write a GP surrogate model, first write a coroutine that yields parameter specifications (ModelParameter) and returns a GP distribution. Downstream, the parameter specifications are used to define Flax module parameters. The inputs to the coroutine function represent the index points of the GP (in the remainder of this Colab, we refer to "inputs" and "index points" interchangeably).
The rationale for the coroutine design is that it lets us automate the application of the parameter constraint and initialization functions (corresponding, for example, to hyperpriors), and enables simultaneous specification of the model parameters and of how their values are used to instantiate a GP.
Coroutine example
The following cell shows a coroutine defining a GP with a squared exponential kernel and two parameters: the length scale of the kernel and the observation noise variance of the GP.
def simple_gp_coroutine(inputs: chex.Array=None):
length_scale = yield spm.ModelParameter.from_prior(
tfd.Gamma(1., 1., name='length_scale'))
amplitude = 2. # Non-trainable parameters may be defined as constants.
kernel = tfpk.ExponentiatedQuadratic(
amplitude=amplitude, length_scale=length_scale)
observation_noise_variance = yield spm.ModelParameter(
init_fn=lambda x: jnp.exp(random.normal(x)),
constraint=spm.Constraint(bounds=(0.0, 100.0), bijector=tfb.Softplus()),
regularizer=lambda x: x**2,
name='observation_noise_variance')
return tfd.GaussianProcess(
kernel,
index_points=inputs,
observation_noise_variance=observation_noise_variance)
ModelParameter
ModelParameter may be used to define hyperpriors.
Parameter specifications from priors
The length scale parameter has a Gamma prior. This is equivalent to defining a ModelParameter with a regularizer that computes the Gamma negative log likelihood and an initialization function that samples from the Gamma distribution. Since the constraint was not specified, a default one is assigned: the "default event space bijector" of the TFP distribution (each TFP distribution has a constraining bijector that maps the real line to the support of the distribution).
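To make the equivalence concrete, here is a sketch of an explicit ModelParameter that is roughly equivalent to the from_prior specification above (illustrative only; the exact defaults chosen by from_prior, including the default event space bijector, may differ):
# Sketch only: an explicit ModelParameter roughly equivalent to
# `ModelParameter.from_prior(tfd.Gamma(1., 1., name='length_scale'))`.
length_scale_prior = tfd.Gamma(1., 1.)
explicit_length_scale = spm.ModelParameter(
    init_fn=lambda key: length_scale_prior.sample(seed=key),  # sample from the prior
    constraint=spm.Constraint(
        bounds=(0.0, None), bijector=tfb.Softplus()),  # assumed default; Gamma support is positive
    regularizer=lambda x: -length_scale_prior.log_prob(x),  # Gamma negative log likelihood
    name='length_scale')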
Specifying parameters explicitly
Observation noise variance, which is passed to the Gaussian process and represents the scalar variance of zero-mean Gaussian noise in the observed labels, is not given a tfd.Distribution prior. Instead, it has its initialization, constraining, and regularization functions defined individually. Note that the initialization function is in the constrained space.
Constraints
ModelParameter lets you define constraints on the model parameters via the Constraint object, which is initialized with a tuple of bounds and a bijector function.
Although the constraints are defined as part of the ModelParameter, the Flax model itself does not apply them; it expects to receive parameter values that are already in the constrained space. It is therefore the responsibility of the user (or the optimizer) to pass GP parameter values that are already constrained.
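For example, if an optimizer works in unconstrained space, it is the optimizer (not the Flax model) that must map values through the constraint bijector before calling the model; a minimal sketch:
# Sketch: mapping an unconstrained value into the constrained space.
# The Softplus bijector matches the one used in the Constraint examples above.
unconstrained_value = jnp.array(-3.0)  # lives on the real line
constrained_value = tfb.Softplus()(unconstrained_value)  # positive, i.e. in the constrained space
# `constrained_value` is what should be passed to the Flax model.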
Exercise: Write a GP model
Write an ARD Gaussian Process model with three parameters: signal_variance, length_scale, and observation_noise_variance. (This is a slightly simplified version of the Vizier GP.)
signal_variance and observation_noise_variance are both regularized by the function \(f(x) = 0.01\log(x)^2\) and bounded to be positive.
signal_variance parameterizes a Matern 5/2 kernel, where the amplitude of the kernel is the square root of signal_variance. Use tfpk.MaternFiveHalves.
length_scale has a \(LogNormal(0, 1)\) prior for each dimension. Assume there are 4 dimensions, and use tfd.Sample to build a 4-dimensional distribution consisting of IID LogNormal distributions. (Note that the length_scale parameter is a vector – all other parameters are scalars.)
In TFP, ARD kernels are implemented with tfpk.FeatureScaled, with scale_diag representing the length scale along each dimension.
def vizier_gp_coroutine(inputs: chex.Array=None):
pass
Solution
data_dimensionality = 2
def vizier_gp_coroutine(inputs: chex.Array=None):
"""A coroutine that follows the `ModelCoroutine` protocol."""
signal_variance = yield spm.ModelParameter(
init_fn=lambda x: tfb.Softplus()(random.normal(x)),
constraint=spm.Constraint(bounds=(0.0, 100.0), bijector=tfb.Softplus()),
regularizer=lambda x: 0.01 * jnp.log(x)**2,
name='signal_variance')
length_scale = yield spm.ModelParameter.from_prior(
tfd.Sample(
tfd.LogNormal(loc=0., scale=1.),
sample_shape=[data_dimensionality],
name='length_scale'),
constraint=spm.Constraint(bounds=(0.0, None)))
kernel = tfpk.MaternFiveHalves(
amplitude=jnp.sqrt(signal_variance), validate_args=True)
kernel = tfpk.FeatureScaled(
kernel, scale_diag=length_scale, validate_args=True)
observation_noise_variance = yield spm.ModelParameter(
init_fn=lambda x: jnp.exp(random.normal(x)),
constraint=spm.Constraint(bounds=(0.0, 100.0), bijector=tfb.Softplus()),
regularizer=lambda x: 0.01 * jnp.log(x)**2,
name='observation_noise_variance')
return tfd.GaussianProcess(
kernel=kernel,
index_points=inputs,
observation_noise_variance=observation_noise_variance,
validate_args=True)
To build a GP Flax module, instantiate a StochasticProcessModel with a GP coroutine, as shown below. The module runs the coroutine in its setup and __call__ methods to initialize the parameters and then instantiates the GP object with the given parameters.
Recall that Flax modules have two primary methods: init, which initializes parameters, and apply, which computes the model's forward pass given a set of parameters and input data.
model = spm.StochasticProcessModel(coroutine=vizier_gp_coroutine)
# Sample some fake data.
# Assume we have `num_points` observations, each with `dim` features.
num_points = 12
# Sample a set of index points.
index_points = np.random.normal(
size=[num_points, data_dimensionality]).astype(np.float32)
# Sample function values observed at the index points
observations = np.random.normal(size=[num_points]).astype(np.float32)
# Call the Flax module's `init` method to obtain initial parameter values.
init_params = model.init(random.PRNGKey(0), index_points)
We can inspect the initial parameter values of the Flax model and see that they match the ModelParameter definitions in our coroutine.
print(init_params['params'])
To instantiate a GP with a set of parameters and index points, use the Flax module's apply method. apply also returns the regularization losses for the parameters, in mutables. The regularization losses are treated as mutable state because they are recomputed internally with each forward pass of the model. For more on mutable state in Flax, see this tutorial.
gp, mutables = model.apply(
init_params,
index_points,
mutable=['losses'])
assert isinstance(gp, tfd.GaussianProcess)
Optimizing hyperparameters
Exercise: Loss function
Write down a loss function that takes a parameters dict and returns the loss value, using model.apply. The function will close over the observed data.
The loss should be the sum of the GP negative log likelihood and the regularization losses. The regularization loss values are computed when the module is called, using the ModelParameter regularization functions. They are stored in a mutable variable collection called "losses", using the Flax method sow.
def loss_fn(params):
...
return loss, {} # Return an empty dict as auxiliary state.
Solution
def loss_fn(params):
gp, mutables = model.apply({'params': params},
index_points,
mutable=['losses'])
loss = (-gp.log_prob(observations) +
jax.tree_util.tree_reduce(jnp.add, mutables['losses'])) # add the regularization losses.
return loss, {}
The gradients of the loss have the same structure as the params dict.
grads = jax.grad(loss_fn, has_aux=True)(init_params['params'])[0]
print(grads)
We can use jax.tree_util to take a step along the gradient (though in practice, with Optax, we can use update and apply_updates to update the parameters at each train step, as sketched after the cell below).
learning_rate = 1e-3
updated_params = jax.tree_util.tree_map(
lambda p, g: p - learning_rate * g,
init_params['params'],
grads)
print(updated_params)
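In practice, with Optax, the same update pattern looks roughly like the following sketch. It ignores the parameter constraints and exists purely to illustrate optax.adam, update, and apply_updates:
# A minimal Optax loop (sketch). Note: this optimizes the parameters directly
# and ignores the constraints, purely to illustrate the Optax update pattern.
optimizer = optax.adam(learning_rate)
params = init_params['params']
opt_state = optimizer.init(params)
for _ in range(100):
  grads, _ = jax.grad(loss_fn, has_aux=True)(params)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)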
Optimize hyperparameters with Vizier optimizers
Flax modules are often optimized with Optax, which requires the developer to write a routine that initializes parameter values and then repeatedly computes the loss function gradients and updates the parameter values accordingly.
Vizier Optimizers is a library that automates the process of finding optimal Flax parameter values and wraps optimizers from libraries such as Optax and Jaxopt in a common interface. To use a Vizier Optimizer, you specify the following:
setup: a function used to generate the initial parameter values.
loss_fn: a function used to compute the loss value and gradients. For example, the loss function of a GP model would be the negative marginal log likelihood plus the parameter regularizations.
rng: a PRNGKey for controlling pseudo-randomization.
constraints: constraints on the parameters (optional).
Below we use the Vizier JaxoptLbfgsB optimizer to run a constrained L-BFGS-B algorithm. Unconstrained optimizers (e.g. Adam) use a bijector function to map between the unconstrained space, where the search is performed, and the constrained space, where the loss function is evaluated. In contrast, constrained optimizers (e.g. L-BFGS-B) use the constraint bounds directly in the search process.
To pass the constraint bounds to the JaxoptLbfgsB optimizer, we use the spm.get_constraints function, which traverses the parameters defined in the module coroutine and extracts their bounds.
setup = lambda rng: model.init(rng, index_points)['params']
model_optimizer = optimizers.JaxoptLbfgsB(
random_restarts=20, best_n=None
)
constraints = spm.get_constraints(model)
optimal_params, _ = model_optimizer(setup, loss_fn, random.PRNGKey(0),
constraints=constraints)
Predict on new inputs, conditional on observations
To compute the posterior predictive GP on unseen points, conditioned on observed data, use the precompute_predictive and posterior_predictive methods of the Flax module. precompute_predictive must be called first; it runs and stores the Cholesky decomposition of the kernel matrix for the observed data. posterior_predictive then returns a posterior predictive GP at new index points, avoiding recomputation of the Cholesky.
# Precompute the Cholesky.
_, pp_state = model.apply(
{'params': optimal_params},
index_points,
observations,
mutable=['predictive'],
method=model.precompute_predictive)
# Predict on new index points.
predictive_index_points = np.random.normal(
size=[5, data_dimensionality]).astype(np.float32)
pp_dist = model.apply(
{'params': optimal_params, **pp_state},
predictive_index_points,
index_points,
observations,
method=model.posterior_predictive)
# `posterior_predictive` returns a TFP distribution, whose mean, variance, and
# samples we can use to compute an acquisition function.
assert pp_dist.mean().shape == (5,)
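For example, a simple Upper Confidence Bound acquisition (the same one used in the Bayesopt loop below) can be computed directly from the predictive distribution:
# Example: an Upper Confidence Bound acquisition value at each predictive point.
ucb = pp_dist.mean() + 2. * pp_dist.stddev()
assert ucb.shape == (5,)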
Optimize a black-box function
For an end-to-end example of Bayesian optimization, we’ll use the GP surrogate model defined above along with an Upper Confidence Bound acquisition function to try to find the maximum of the Weierstrass function. First, visualize the function surface.
# Use the Weierstrass function from Vizier's Black-Box Optimization Benchmarking
# (BBOB) library.
bb_fun = bbob.Weierstrass
# Sample a set of index points in a 2D space.
num_points = 6
max_x = np.array(2.).astype(np.float32)
index_points = random.uniform(
random.PRNGKey(3),
shape=[num_points, data_dimensionality], dtype=jnp.float32) * max_x
# Compute function values observed at the index points.
observations = np.apply_along_axis(
bb_fun, axis=1, arr=index_points).astype(np.float32)
# Define a grid of points in the function domain for plotting.
n_grid = 100
x = y = np.linspace(0, max_x, n_grid, dtype=np.float32)
X, Y = np.meshgrid(x, y)
x_grid = np.vstack([X.ravel(), Y.ravel()]).T
y_grid = np.apply_along_axis(bb_fun, axis=1, arr=x_grid)
Z = y_grid.reshape(X.shape)
# Plot the black-box function values.
fig = plt.figure(figsize=(8, 8))
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(X, Y, Z, alpha=0.5)
ax.scatter(index_points[:, 0], index_points[:, 1], observations, color='r',
label='Initial observed data')
plt.title('Black-box (Weierstrass) function values and observed data')
plt.legend()
plt.show()
Next, run a few iterations of Bayesian optimization to maximize the black-box function given the observed data. A single iteration consists of the following steps:
Optimize the GP hyperparameters.
Find a suggestion that maximizes an Upper Confidence Bound acquisition function. In this example, we use grid search for the optimization.
Evaluate the black-box function on the suggestion and append it to the set of observed data.
(Note that this simple Bayesopt algorithm is for educational purposes and that we’d expect Vizier’s GP bandit algorithm to give better results.)
num_bayesopt_iter = 5
# At each iteration, redefine the loss function given the current observed data.
def build_loss_fn(index_points, observations):
def loss_fn(params):
gp, mutables = model.apply({'params': params},
index_points,
mutable=['losses'])
loss = (-gp.log_prob(observations) +
jax.tree_util.tree_reduce(jnp.add, mutables['losses'])) # add the regularization losses.
return loss, {}
return loss_fn
for i in range(num_bayesopt_iter):
# Update the loss function to condition on all observed data.
loss_fn = build_loss_fn(index_points, observations)
# Optimize the GP hyperparameters.
optimal_params, _ = model_optimizer(setup, loss_fn, random.PRNGKey(0),
constraints=constraints)
# Compute the posterior predictive distribution over a grid of points in the
# function domain (x_grid).
_, pp_state = model.apply(
{'params': optimal_params},
index_points,
observations,
mutable=['predictive'],
method=model.precompute_predictive)
pp_dist = model.apply(
{'params': optimal_params, **pp_state},
x_grid,
index_points,
observations,
method=model.posterior_predictive)
# Compute the acquisition function value at each point in the grid.
pred_mean = pp_dist.mean()
ucb_vec = pred_mean + 2. * pp_dist.stddev()
# Find the grid point with the highest acquisition function value.
ind = np.argmax(ucb_vec)
# Evaluate the black box function at the selected point.
f_val = bb_fun(x_grid[ind])
# Visualize the surrogate model mean and acquisition function surface at this
# iteration.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(121, projection='3d')
W = pred_mean.reshape(X.shape)
ax.plot_surface(X, Y, W, alpha=0.5)
ax.scatter(index_points[:, 0], index_points[:, 1], observations, color='r',
label='Observed data')
ax.set_title('Observed data and posterior predictive GP mean')
ax.legend()
ax = fig.add_subplot(122, projection='3d')
ucb = ucb_vec.reshape(X.shape)
ax.plot_surface(X, Y, ucb, alpha=0.5)
ax.scatter(*x_grid[ind], ucb_vec[ind], color='r', label='New suggestion')
ax.set_title('Acquisition function')
ax.legend()
plt.show()
# Append the new suggestion and function value to the set of observations.
index_points = np.concatenate([index_points, x_grid[ind][np.newaxis]])
observations = np.concatenate(
[observations, np.array(f_val).astype(np.float32)[np.newaxis]])
print(f'Iteration: {i}')
print(f'Acquisition function value at suggestion: {ucb_vec[ind]}')
print(f'Black-box function value at suggestion: {f_val}')
Deeper dive on selected topics in TFP
As shown above, the Flax GP model makes use of a number of TFP components:
Distributions specify parameter priors (e.g. tfd.Gamma). The stochastic process model itself is also a TFP distribution, tfd.GaussianProcess.
Bijectors (e.g. tfb.Softplus) are used to constrain parameters for optimization, and may also be used for input/output warping.
PSD kernels (e.g. tfpk.ExponentiatedQuadratic) specify the kernel function for the stochastic process.
The next sections of this Colab introduce these and how they’re used in Bayesopt modeling.
tfd.GaussianProcess and friends
The stochastic process Flax modules return a TFP distribution in the Gaussian Process family (an instance of tfd.GaussianProcess, tfd.StudentTProcess, or tfde.MultiTaskGaussianProcess).
This Colab doesn’t go into detail on TFP distributions, since advanced usage and implementation of distributions is rarely required for Bayesopt modeling with Vizier. For an overview of TFP distributions, see TensorFlow Distributions: A Gentle Introduction.
Some of the methods of the Gaussian Process distribution are demonstrated below. Gaussian Process Regression in TFP is also worth reading.
# Build a kernel function (see "PSD kernels" section below) and GP.
num_points = 6
index_points = random.uniform(
random.PRNGKey(3),
shape=[num_points, data_dimensionality], dtype=jnp.float32)
observations = random.uniform(
random.PRNGKey(4),
shape=[num_points], dtype=jnp.float32)
kernel = tfpk.MaternFiveHalves(
amplitude=2.,
length_scale=0.3,
validate_args=True # Run additional runtime checks; possibly expensive.
)
observation_noise_variance = jnp.ones([], dtype=observations.dtype)
gp = tfd.GaussianProcess(
kernel,
index_points=index_points,
observation_noise_variance=observation_noise_variance,
cholesky_fn=lambda x: tfp.experimental.distributions.marginal_fns.retrying_cholesky(x)[0], # See commentary below.
validate_args=True)
# Take 4 samples from the GP at the index points.
s = gp.sample(4, seed=random.PRNGKey(0))
assert s.shape == (4, num_points)
# Compute the log likelihood of the sampled values.
lp = gp.log_prob(s)
assert lp.shape == (4,)
# GPs can also be instantiated without index points, in which case the index
# points must be passed to method calls.
gp_no_index_pts = tfd.GaussianProcess(
kernel,
observation_noise_variance=observation_noise_variance)
s = gp_no_index_pts.sample(index_points=index_points, seed=random.PRNGKey(0))
# Predictive GPs conditioned on observations can be built with
# `GaussianProcess.posterior_predictive`. The Flax module's
# `precompute_predictive` and `posterior_predictive` methods call this GP method.
gprm = gp.posterior_predictive(
observations=observations,
predictive_index_points=predictive_index_points)
# `gprm` is an instance of `tfd.GaussianProcessRegressionModel`. This class can
# also be instantiated directly (as a side note -- this isn't necessary for
# modeling with Vizier).
same_gprm = tfd.GaussianProcessRegressionModel(
kernel,
index_points=predictive_index_points,
observation_index_points=index_points,
observations=observations,
observation_noise_variance=observation_noise_variance)
Aside from the kernel, index points, and noise variance, note the cholesky_fn arg to the GaussianProcess constructor. cholesky_fn is a callable that takes a matrix and returns a Cholesky-like lower triangular factor. The default function adds a jitter of 1e-6 to the diagonal and then calls jnp.linalg.cholesky. An alternative, used in the Vizier GP, is tfp.experimental.distributions.marginal_fns.retrying_cholesky, which adds progressively larger jitter until the Cholesky decomposition succeeds.
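As an illustration of the kind of callable cholesky_fn expects, here is a sketch of a custom function that adds a fixed jitter (unlike retrying_cholesky, which retries with progressively larger jitter); the name gp_custom_chol is used only for this example:
# Sketch: a custom `cholesky_fn` that adds a fixed diagonal jitter before factorizing.
def jittered_cholesky(matrix, jitter=1e-6):
  n = matrix.shape[-1]
  return jnp.linalg.cholesky(matrix + jitter * jnp.eye(n, dtype=matrix.dtype))

gp_custom_chol = tfd.GaussianProcess(
    kernel,
    index_points=index_points,
    observation_noise_variance=observation_noise_variance,
    cholesky_fn=jittered_cholesky,
    validate_args=True)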
A side note on batch shape in TFP
tl;dr: Don’t worry about batch shape.
TFP objects have a notion of batch shape, which is useful for vectorized computations. For more on this, see Understanding TensorFlow Distributions Shapes.
For the purposes of Bayesopt in Vizier, JAX's vmap means that our TFP objects can have a single parameterization with empty batch shape. For example, the following loss function takes a scalar amplitude, and the kernel and GP both have empty batch shape.
def loss_fn(amplitude):  # `amplitude` is a scalar.
k = tfpk.ExponentiatedQuadratic(amplitude=amplitude) # batch shape []
gp = tfd.GaussianProcess(k, index_points=index_points) # batch shape []
return -gp.log_prob(observations)
initial_amplitude = np.random.uniform(size=[50])
losses = jax.vmap(loss_fn)(initial_amplitude)
assert losses.shape == (50,)
We could also vectorize the loss computation by using a batched GP. In this simple case, the code is identical except that vmap is removed. Now, the kernel and GP represent a "batch" of kernels and GPs, each with different parameter values. Working with batch shape requires additional accounting on the part of the user to ensure that parameter shapes broadcast correctly, the correct dimensions are reduced over, etc. For Vizier's use case, we find it simpler to rely on vmap.
def loss_fn(amplitude):  # `amplitude` has shape [50].
k = tfpk.ExponentiatedQuadratic(amplitude=amplitude) # batch shape [50]
gp = tfd.GaussianProcess(k, index_points=index_points) # batch shape [50]
return -gp.log_prob(observations)
initial_amplitude = np.random.uniform(size=[50])
# No vmap.
losses = loss_fn(initial_amplitude)
assert losses.shape == (50,)