Debugging tips
JAX
JAX has a number of useful debugging tools, including:
- jax.debug.print, to print values, even inside jit-compiled code.
- jit-compatible runtime error checking with jax.experimental.checkify.
- the jax_debug_nans flag, to automatically detect when NaNs are produced in jit-compiled code.
- disable_jit, a context manager that disables jit() behavior.
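The four tools listed above can be tried together in a short sketch. It assumes a recent JAX release. The functions scaled, safe_sqrt and double are hypothetical examples; jax.debug.print, checkify.checkify / checkify.check, the jax_debug_nans config flag and jax.disable_jit are the public JAX APIs named in the list.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify


# 1. jax.debug.print: prints runtime values even inside jit-compiled code
#    (a plain Python print would only show an abstract tracer here).
@jax.jit
def scaled(x):
    y = 2.0 * x
    jax.debug.print("y = {y}", y=y)
    return y

scaled(jnp.arange(3.0))


# 2. checkify: a runtime check that works under jit. checkify.checkify
#    transforms the function so it returns (error, output) instead of raising.
def safe_sqrt(x):
    checkify.check(jnp.all(x >= 0), "safe_sqrt got a negative input")
    return jnp.sqrt(x)

checked_sqrt = jax.jit(checkify.checkify(safe_sqrt))
err, _ = checked_sqrt(jnp.array([4.0, -1.0]))
print("checkify error:", err.get())  # message string, or None if the check passed


# 3. jax_debug_nans: raise as soon as jit-compiled code produces a NaN.
jax.config.update("jax_debug_nans", True)
try:
    jax.jit(jnp.log)(-1.0)  # log(-1) is NaN
    nan_detected = False
except FloatingPointError as e:
    nan_detected = True
    print("NaN detected:", e)
finally:
    jax.config.update("jax_debug_nans", False)  # restore the default


# 4. disable_jit: run jit-decorated functions eagerly, so ordinary Python
#    print statements and debuggers see concrete values.
@jax.jit
def double(x):
    print("x inside double:", x)  # a tracer normally; a concrete array here
    return 2.0 * x

with jax.disable_jit():
    double(jnp.array(3.0))
```

The sketch switches jax_debug_nans off again afterwards, since leaving it on adds a NaN check to every jitted computation.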
TFP
- TFP objects (bijectors, distributions, PSD kernels) have a validate_args boolean argument to __init__. If True, it runs additional (possibly expensive) runtime checks, e.g. to verify that parameters like length_scale are positive. In TFP, we enable validate_args in unit tests and use it as a debugging tool.
- Reproducibility: All functions and methods in TFP that rely on random number generation, such as the sample method of distributions, take a seed argument, which in JAX is an instance of jax.random.PRNGKey. This argument is mandatory in TFP-on-JAX and ensures reproducible random number generation. See the jax.random documentation for more details.
- Tests of sample statistics: TFP’s internal test_util module includes assertAllMeansClose, which asserts that the mean of a sample is as expected and diagnoses the statistical significance of failures.
# Imports
from jax import numpy as jnp, tree_util
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfpk = tfp.math.psd_kernels
# Demo of `validate_args`.
print('Without runtime arg validation, the kernel with negative amplitude happily builds.')
k = tfpk.MaternFiveHalves(amplitude=-1., validate_args=False)
print('With runtime arg validation:')
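try:
  k = tfpk.MaternFiveHalves(amplitude=-1., validate_args=True)
except Exception as e:  # validation fails while constructing the kernel
  print(f'{type(e).__name__}: {e}')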
What is “AutoCompositeTensor”?
You might have noticed that the base classes of the bijectors and PSD kernels are AutoCompositeTensorBijector and AutoCompositeTensorPSDKernel. In TensorFlow, objects that inherit from CompositeTensor have a recipe that allows them to be flattened into collections of Tensors and rebuilt, so that they can cross tf.function boundaries and interact with TF control flow similarly to Tensors (e.g., be passed in a while_loop’s carried state).

JAX has a similar notion called Pytree. Subclassing the AutoCompositeTensor* versions of TFP base classes means that the class will be registered as a Pytree node (making use of shared CompositeTensor/Pytree machinery in TFP). For the Flax model to return a GP in JIT-compiled code, it’s necessary for the GP and its PSD kernel to be Pytrees.
# A GP whose kernel is itself a Pytree node.
gp = tfd.GaussianProcess(
    tfpk.MaternFiveHalves(length_scale=jnp.ones([5])),
    observation_noise_variance=jnp.array([0.5]))
# Flatten the GP into its array leaves plus a tree definition...
gp_flat, gp_tree = tree_util.tree_flatten(gp)
print(f'GP flattened into arrays: {gp_flat}')
# ...and rebuild an equivalent GP from them.
rebuilt_gp = tree_util.tree_unflatten(gp_tree, gp_flat)
assert isinstance(rebuilt_gp, tfd.GaussianProcess)