
Bijectors

TFP bijectors represent (mostly) invertible, smooth functions. For Bayesopt modeling in Vizier, they are used:

  • to constrain parameter values, so that optimization can take place in an unconstrained space;

  • for input warping or output warping (e.g. the Yeo-Johnson bijector).

Each bijector implements at least 3 methods:

  • forward,

  • inverse, and

  • (at least) one of forward_log_det_jacobian and inverse_log_det_jacobian.

When bijectors are used to transform distributions (with tfd.TransformedDistribution), the log-det-Jacobian term accounts for how the transformation stretches or compresses volume, so that the transformed distribution's PDF still integrates to 1.

Bijectors also cache the forward and inverse computations, and log-det-Jacobians. This has two purposes:

  • Avoid repeating potentially expensive computations (as with the CholeskyOuterProduct bijector).

  • Maintain numerical precision, so that b.inverse(b.forward(x)) returns x exactly, even when recomputing the inverse from the rounded forward value would not.

Although TFP library bijectors are written in TensorFlow (and automatically converted to JAX with TFP's rewrite machinery), user-defined bijectors can be written in JAX directly. For example, below is a complete JAX reimplementation of the Exp bijector. TFP's library already contains a JAX-compatible Exp bijector, so implementing it is not actually necessary.

While Vizier users will rarely have to implement new TFP components, this example shows how it would be done with TFP's JAX backend.

Imports

from jax import numpy as jnp
import numpy as np
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
tfb = tfp.bijectors
tfpk = tfp.math.psd_kernels

# A JAX reimplementation of TFP's Exp bijector.
class Exp(tfb.AutoCompositeTensorBijector):

  def __init__(self,
               validate_args=False,
               name='exp'):
    """Instantiates the `Exp` bijector."""
    # Record the constructor arguments. TFP uses `parameters` to rebuild the
    # bijector (e.g. in `copy()` and when AutoCompositeTensor reconstructs it).
    parameters = dict(locals())
    super().__init__(
        forward_min_event_ndims=0,  # Exp acts elementwise on scalars.
        validate_args=validate_args,
        parameters=parameters,
        name=name)

  @classmethod
  def _parameter_properties(cls, dtype):
    # Exp has no parameters (unlike e.g. Shift or Scale).
    return dict()

  @classmethod
  def _is_increasing(cls):
    return True

  def _forward(self, x):
    return jnp.exp(x)

  def _inverse(self, y):
    return jnp.log(y)

  def _inverse_log_det_jacobian(self, y):
    # log|d/dy log(y)| = -log(y). TFP derives the forward log-det-Jacobian from this.
    return -jnp.log(y)

# Make sure it gives the same results as the TFP library bijector.
x = np.random.normal(size=[5]).astype(np.float32)
tfp_exp = tfb.Exp()
my_exp = Exp()
np.testing.assert_allclose(tfp_exp.forward(x), my_exp.forward(x), rtol=1e-6)
# atol covers x near 0, where log(exp(x)) loses relative precision in float32.
np.testing.assert_allclose(tfp_exp.forward_log_det_jacobian(x),
                           my_exp.forward_log_det_jacobian(x),
                           rtol=1e-6, atol=1e-6)

TFP’s bijector library includes:

  • Simple bijectors (just a few examples; there are many more):

    • Scale(k) multiplies its input by k.

    • Shift(k) adds k to its input.

    • Sigmoid() computes the sigmoid function.

    • FillScaleTriL() packs its input, a vector, into a lower-triangular matrix with a positive diagonal, suitable as a Cholesky factor.

  • Invert wraps any bijector instance and swaps its forward and inverse methods, e.g. inv_sigmoid = tfb.Invert(tfb.Sigmoid()).

  • Chain composes a series of bijectors. The function \(f(x) = 3 + 2x\) can be expressed as tfb.Chain([tfb.Shift(3.), tfb.Scale(2.)]). Note that the bijectors in the list are applied from right to left.

  • JointMap applies a nested structure of bijectors to an identical nested structure of inputs. build_constraining_bijector, shown above, returns a JointMap. Vizier's get_constraints function can be used to generate a JointMap from the Constraints of the ModelParameters defined in the coroutine.

  • Restructure packs the elements of one nested structure (e.g. a list) into a different structure (e.g. a dict). spm.build_restructure_bijector, for example, is a Chain bijector that takes a vector of parameters, splits it into a list, and packs the elements of the list into a dictionary with the same structure as the Flax parameters dict.

Exercise: Bijectors

Write a bijector (with Chain) that computes the function \(f(x) = e^{x^2 + 1}\).

b = tfb.Chain([...])

f = lambda x: jnp.exp(x**2 + 1)
x = np.random.normal(size=[5]).astype(np.float32)
np.testing.assert_allclose(f(x), b.forward(x), rtol=1e-5)

Solution

b = tfb.Chain([tfb.Exp(), tfb.Shift(1.), tfb.Square()])