
Debugging tips

JAX

JAX has a number of useful debugging tools, including:

  • jax.debug.print to print values, even inside of jit-compiled code.

  • jit-able runtime error checking with jax.experimental.checkify.

  • jax_debug_nans flag to automatically detect when NaNs are produced in jit-compiled code.

  • disable_jit, a context manager that disables jit() behavior.
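As a minimal sketch of how these four tools might be used together (the functions scaled_log and checked_log are hypothetical names made up for this illustration, not part of JAX):

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


@jax.jit
def scaled_log(x):
    # Prints the runtime value, even though the function is jit-compiled.
    jax.debug.print("scaled_log received x = {x}", x=x)
    return 2.0 * jnp.log(x)


def checked_log(x):
    # Registers a runtime check that is preserved under jit.
    checkify.check(jnp.all(x > 0), "x must be positive")
    return jnp.log(x)


# checkify wraps the function so that it returns (error, output).
checked_log_jit = jax.jit(checkify.checkify(checked_log))

y = scaled_log(jnp.array([1.0, 2.0]))

ok_err, ok_out = checked_log_jit(jnp.array([1.0, 2.0]))
bad_err, bad_out = checked_log_jit(jnp.array([1.0, -2.0]))
print("check passed:", ok_err.get() is None)
print("check failed with:", bad_err.get())

# disable_jit runs the function op by op, so ordinary Python debugging
# (print statements, pdb) sees concrete values instead of tracers.
with jax.disable_jit():
    y_eager = scaled_log(jnp.array([1.0, 2.0]))

# jax_debug_nans turns NaN production into a FloatingPointError.
jax.config.update("jax_debug_nans", True)
try:
    jnp.log(jnp.array(-1.0))
except FloatingPointError as e:
    print("NaN detected:", e)
finally:
    # Restore the default so later code is unaffected.
    jax.config.update("jax_debug_nans", False)
```

Calling bad_err.throw() instead of bad_err.get() would raise the error as a Python exception, which may be more convenient outside of jit-compiled code.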

TFP

  • TFP objects (bijectors, distributions, PSD kernels) have a validate_args boolean arg to __init__. If True, it runs additional (possibly expensive) runtime checks, e.g. to verify that parameters like length_scale are nonnegative. In TFP, we enable validate_args in unit tests, and use it as a debugging tool.

  • Reproducibility: All functions and methods in TFP that rely on random number generation, such as the sample method of distributions, take a seed arg, which in JAX is an instance of jax.random.PRNGKey. This arg is mandatory in TFP-on-JAX, and ensures reproducible random number generation. See the jax.random documentation for more details.

  • Tests of sample statistics: TFP’s internal test_util module includes assertAllMeansClose, which asserts that the mean of a sample is as expected, and diagnoses the statistical significance of failures.

#@title Imports
from jax import numpy as jnp, tree_util
from tensorflow_probability.substrates import jax as tfp

tfd = tfp.distributions
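tfpk = tfp.math.psd_kernels

# Demo of `validate_args`.
print('Without runtime arg validation, the kernel with negative amplitude happily builds.')
k = tfpk.MaternFiveHalves(amplitude=-1., validate_args=False)
print('With runtime arg validation:')
try:
    # The negative amplitude is expected to fail the runtime check here.
    k = tfpk.MaternFiveHalves(amplitude=-1., validate_args=True)
except Exception as e:  # Caught broadly so that later cells still run.
    print(f'{type(e).__name__}: {e}')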

What is “AutoCompositeTensor”?

You might have noticed that the base classes of the bijectors and PSD kernels are AutoCompositeTensorBijector and AutoCompositeTensorPSDKernel. In TensorFlow, objects that inherit from CompositeTensor have a recipe that allows them to be flattened into collections of Tensors and rebuilt, so that they can cross tf.function boundaries and interact with TF control flow similarly to Tensors (e.g., be passed in a while_loop’s carried state). JAX has a similar notion called Pytree. Subclassing the AutoCompositeTensor* versions of TFP base classes means that the class will be registered as a Pytree node (making use of shared CompositeTensor/Pytree machinery in TFP). For the Flax model to return a GP in JIT-compiled code, it’s necessary for the GP and its PSD kernel to be Pytrees.

gp = tfd.GaussianProcess(
    tfpk.MaternFiveHalves(length_scale=jnp.ones([5])),
    observation_noise_variance=jnp.array([0.5]))
gp_flat, gp_tree = tree_util.tree_flatten(gp)
print(f'GP flattened into arrays: {gp_flat}')
rebuilt_gp = tree_util.tree_unflatten(gp_tree, gp_flat)
assert isinstance(rebuilt_gp, tfd.GaussianProcess)