{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "e-HthPasFSPD" }, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/vizier/blob/main/docs/advanced_topics/tfp/debugging.ipynb)\n", "\n", "# Debugging tips" ] }, { "cell_type": "markdown", "metadata": { "id": "NCrdmOTOnlyn" }, "source": [ "## JAX\n", "JAX has a number of useful [debugging tools](https://jax.readthedocs.io/en/latest/debugging/index.html), including:\n", "\n", "- `jax.debug.print` to print values, even inside of jit-compiled code.\n", "- jit-able runtime error checking with `jax.experimental.checkify`.\n", "- The `jax_debug_nans` flag to automatically detect when NaNs are produced in jit-compiled code.\n", "- [`disable_jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.disable_jit.html), a context manager that disables `jit()` behavior.\n", "\n", "## TFP\n", "- TFP objects (bijectors, distributions, PSD kernels) have a `validate_args` boolean arg to `__init__`. If `True`, the object runs additional (possibly expensive) runtime checks, e.g. to verify that parameters like `length_scale` are positive. In TFP, we enable `validate_args` in unit tests, and use it as a debugging tool.\n", "- Reproducibility: All functions and methods in TFP that rely on random number generation, such as the `sample` method of distributions, take a `seed` arg, which in JAX is a PRNG key created with `jax.random.PRNGKey`. This arg is mandatory in TFP-on-JAX and ensures reproducible random number generation. 
See the `jax.random` [documentation](https://jax.readthedocs.io/en/latest/jax.random.html) for more details.\n", "- Tests of sample statistics: TFP's internal `test_util` module includes [`assertAllMeansClose`](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/internal/test_util.py#L349), which asserts that the mean of a sample is as expected, and diagnoses the statistical significance of failures.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bE3flNf4unva" }, "outputs": [], "source": [ "#@title Imports\n", "from jax import numpy as jnp, tree_util\n", "from tensorflow_probability.substrates import jax as tfp\n", "\n", "tfd = tfp.distributions\n", "tfpk = tfp.math.psd_kernels" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "X5lYkaBDFTYw" }, "outputs": [], "source": [ "# Demo of `validate_args`.\n", "print('Without runtime arg validation, the kernel with negative amplitude happily builds.')\n", "k = tfpk.MaternFiveHalves(amplitude=-1., validate_args=False)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "w4d-80grBNI1" }, "outputs": [], "source": [ "print('With runtime arg validation:')\n", "k = tfpk.MaternFiveHalves(amplitude=-1., validate_args=True)" ] }, { "cell_type": "markdown", "metadata": { "id": "2uatyzy-BgXp" }, "source": [ "## What is \"AutoCompositeTensor\"?\n", "You might have noticed that the base classes of the bijectors and PSD kernels are `AutoCompositeTensorBijector` and `AutoCompositeTensorPSDKernel`. In TensorFlow, objects that inherit from `CompositeTensor` have a recipe that allows them to be flattened into collections of Tensors and rebuilt, so that they can cross `tf.function` boundaries and interact with TF control flow similarly to Tensors (e.g., be passed in a `while_loop`'s carried state). JAX has a similar notion called [Pytree](https://jax.readthedocs.io/en/latest/pytrees.html). 
Subclassing the `AutoCompositeTensor*` versions of TFP base classes means that the class will be registered as a Pytree node (making use of shared CompositeTensor/Pytree machinery in TFP). For the Flax model to return a GP in JIT-compiled code, it's necessary for the GP and its PSD kernel to be Pytrees." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GnhVG6YIBqlG" }, "outputs": [], "source": [ "gp = tfd.GaussianProcess(\n", " tfpk.MaternFiveHalves(length_scale=jnp.ones([5])),\n", " observation_noise_variance=jnp.array([0.5]))\n", "gp_flat, gp_tree = tree_util.tree_flatten(gp)\n", "print(f'GP flattened into arrays: {gp_flat}')\n", "rebuilt_gp = tree_util.tree_unflatten(gp_tree, gp_flat)\n", "assert isinstance(rebuilt_gp, tfd.GaussianProcess)" ] } ], "metadata": { "colab": { "last_runtime": { "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", "kind": "private" }, "private_outputs": true, "provenance": [ { "file_id": "1bhe9vVJps8t8IsIU4sbInYvcQPgBlLWn", "timestamp": 1667943943389 } ], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }