Bijectors
TFP bijectors represent (mostly) invertible, smooth functions. For Bayesopt modeling in Vizier, they are used:
- To constrain parameter values, so that optimization can be carried out in an unconstrained space.
- For input warping or output warping (e.g., the Yeo-Johnson bijector).
Each bijector implements at least three methods: forward, inverse, and (at least) one of forward_log_det_jacobian and inverse_log_det_jacobian.
When bijectors are used to transform distributions (with tfd.TransformedDistribution), the log det Jacobian accounts for how the transformation stretches or compresses volume. This correction ensures that the transformed distribution’s PDF still integrates to 1.
Bijectors also cache the results of forward and inverse computations, as well as log-det-Jacobians. This serves two purposes:
- Avoiding repeated, potentially expensive computations (as with the CholeskyOuterProduct bijector).
- Maintaining numerical precision, so that b.inverse(b.forward(x)) == x.
Below is an illustration of preservation of numerical precision.
Although TFP library bijectors are written in TensorFlow (and automatically converted to JAX with TFP’s rewrite machinery), user-defined bijectors can be written in JAX directly. For example, a complete JAX reimplementation of the Exp bijector is below. TFP’s library already contains an Exp bijector that supports JAX, so implementing it is not actually necessary.
Vizier users will rarely need to implement new TFP components. We include this example to show how it would be done using TFP’s JAX backend.
Imports
from jax import numpy as jnp
import numpy as np
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels
class Exp(tfb.AutoCompositeTensorBijector):

    def __init__(self,
                 validate_args=False,
                 name='exp'):
        """Instantiates the `Exp` bijector."""
        parameters = dict(locals())
        super(Exp, self).__init__(
            forward_min_event_ndims=0,  # Scalar bijector: acts elementwise.
            validate_args=validate_args,
            # Constructor args, exposed as `self.parameters`; TFP uses them, e.g., in `copy()`.
            parameters=parameters,
            name=name)

    @classmethod
    def _parameter_properties(cls, dtype):
        # Exp has no parameters (no shift, scale, etc.) to declare.
        return dict()

    @classmethod
    def _is_increasing(cls):
        return True

    def _forward(self, x):
        return jnp.exp(x)

    def _inverse(self, y):
        return jnp.log(y)

    def _inverse_log_det_jacobian(self, y):
        # log|d/dy log(y)| = -log(y); TFP derives the forward LDJ from this.
        return -jnp.log(y)
# Make sure it gives the same results as the TFP library bijector.
x = np.random.normal(size=[5])
tfp_exp = tfb.Exp()
my_exp = Exp()
np.testing.assert_allclose(tfp_exp.forward(x), my_exp.forward(x))
np.testing.assert_allclose(tfp_exp.forward_log_det_jacobian(x),
                           my_exp.forward_log_det_jacobian(x), rtol=1e-6)
TFP’s bijector library includes:
- Simple bijectors (these are only examples; there are many more):
  - Scale(k) multiplies its input by k.
  - Shift(k) adds k to its input.
  - Sigmoid() computes the sigmoid function.
  - FillScaleTriL() packs its input, a vector, into a lower-triangular matrix (by default with a positive diagonal).
  - …
- Invert wraps any bijector instance and swaps its forward and inverse methods, e.g. inv_sigmoid = tfb.Invert(tfb.Sigmoid()).
- Chain composes a series of bijectors. The function \(f(x) = 3 + 2x\) can be expressed as tfb.Chain([tfb.Shift(3.), tfb.Scale(2.)]). Note that the bijectors in the list are applied from right to left.
- JointMap applies a nested structure of bijectors to an identical nested structure of inputs. build_constraining_bijector, shown above, returns a JointMap. Vizier's get_constraints function could be used to generate a JointMap based on the Constraints of the ModelParameters defined in the coroutine.
- Restructure packs the elements of one nested structure (e.g. a list) into a different structure (e.g. a dict). spm.build_restructure_bijector, for example, is a Chain bijector that takes a vector of parameters, splits it into a list, and packs the elements of the list into a dictionary with the same structure as the Flax parameters dict.
Exercise: Bijectors
Write a bijector (with Chain) that computes the function \(f(x) = e^{x^2 + 1}\).
b = tfb.Chain([...])
f = lambda x: jnp.exp(x**2 + 1)
x = np.random.normal(size=[5])
np.testing.assert_allclose(f(x), b.forward(x), rtol=1e-5)  # Allow for float32 rounding in JAX.
Solution
b = tfb.Chain([tfb.Exp(), tfb.Shift(1.), tfb.Square()])