{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "W6GCupiD5r5C" }, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/vizier/blob/main/docs/advanced_topics/tfp/bijectors.ipynb)\n", "\n", "# Bijectors\n", "TFP [bijectors](https://github.com/tensorflow/probability/tree/main/tensorflow_probability/python/bijectors) represent (mostly) invertible, smooth functions. For Bayesopt modeling in Vizier, they are used:\n", "- To constrain parameter values, so that optimization can run in an unconstrained space.\n", "- For input warping or output warping (e.g. the [Yeo-Johnson](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.yeojohnson.html) bijector).\n", "\n", "Each bijector implements at least 3 methods:\n", " * `forward`,\n", " * `inverse`, and\n", " * (at least) one of `forward_log_det_jacobian` and `inverse_log_det_jacobian`.\n", "\n", "When bijectors are used to transform distributions (with `tfd.TransformedDistribution`), the log-det-Jacobian accounts for the change in volume under the transformation, so that the transformed distribution's PDF still integrates to 1.\n", "\n", "Bijectors also cache the forward and inverse computations, and log-det-Jacobians. This has two purposes:\n", "- Avoid repeating potentially expensive computations (as with the `CholeskyOuterProduct` bijector).\n", "- Maintain numerical precision so that `b.inverse(b.forward(x)) == x`.\n", "\n", "Although TFP library bijectors are written in TensorFlow (and automatically converted to JAX with TFP's rewrite machinery), user-defined bijectors can be written in JAX directly. For example, a complete JAX reimplementation of the `Exp` bijector is below. 
TFP's library already contains an `Exp` bijector that supports JAX, so it isn't actually necessary to implement this.\n", "\n", "While it's rare that Vizier users will have to implement new TFP components, we include this as an example to show how it would be done using TFP's JAX backend." ] }, { "cell_type": "markdown", "metadata": { "id": "7D8qV28SP1pX" }, "source": [ "### Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "24HaTNrftvz7" }, "outputs": [], "source": [ "from jax import numpy as jnp\n", "import numpy as np\n", "from tensorflow_probability.substrates import jax as tfp\n", "\n", "tfd = tfp.distributions\n", "tfb = tfp.bijectors\n", "tfpk = tfp.math.psd_kernels" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mzGSCSrv3n0Q" }, "outputs": [], "source": [ "class Exp(tfb.AutoCompositeTensorBijector):\n", "\n", "  def __init__(self,\n", "               validate_args=False,\n", "               name='exp'):\n", "    \"\"\"Instantiates the `Exp` bijector.\"\"\"\n", "    parameters = dict(locals())\n", "    super(Exp, self).__init__(\n", "        forward_min_event_ndims=0,  # The bijector acts elementwise on scalars.\n", "        validate_args=validate_args,\n", "        parameters=parameters,  # TODO(emilyaf): explain why this is necessary.\n", "        name=name)\n", "\n", "  @classmethod\n", "  def _parameter_properties(cls, dtype):\n", "    return dict()\n", "\n", "  @classmethod\n", "  def _is_increasing(cls):\n", "    return True\n", "\n", "  def _forward(self, x):\n", "    return jnp.exp(x)\n", "\n", "  def _inverse(self, y):\n", "    return jnp.log(y)\n", "\n", "  def _inverse_log_det_jacobian(self, y):\n", "    # The inverse is log(y), whose derivative is 1/y, so log|1/y| = -log(y).\n", "    return -jnp.log(y)\n", "\n", "# Make sure it gives the same results as the TFP library bijector.\n", "x = np.random.normal(size=[5])\n", "tfp_exp = tfb.Exp()\n", "my_exp = Exp()\n", "np.testing.assert_allclose(tfp_exp.forward(x), my_exp.forward(x))\n", "np.testing.assert_allclose(tfp_exp.forward_log_det_jacobian(x),\n", "                           my_exp.forward_log_det_jacobian(x), 
rtol=1e-6)" ] }, { "cell_type": "markdown", "metadata": { "id": "jBjXmTPm5R8a" }, "source": [ "TFP's bijector library includes:\n", "- Simple bijectors, for example (there are many more):\n", "  - `Scale(k)` multiplies its input by `k`.\n", "  - `Shift(k)` adds `k` to its input.\n", "  - `Sigmoid()` computes the sigmoid function.\n", "  - `FillScaleTriL()` packs its input, a vector, into a lower-triangular matrix.\n", "  - ...\n", "- `Invert` wraps any bijector instance and swaps its forward and inverse methods, e.g. `inv_sigmoid = tfb.Invert(tfb.Sigmoid())`.\n", "- `Chain` composes a series of bijectors. The function $f(x) = 3 + 2x$ can be expressed as `tfb.Chain([tfb.Shift(3.), tfb.Scale(2.)])`. Note that the bijectors in the list are applied from right to left.\n", "- `JointMap` applies a nested structure of bijectors to an identical nested structure of inputs. `build_constraining_bijector` returns a `JointMap` of this kind. Vizier's `get_constraints` function can be used to generate a `JointMap` based on the `Constraint`s of the `ModelParameter`s defined in the coroutine.\n", "- `Restructure` packs the elements of one nested structure (e.g. a list) into a different structure (e.g. a dict). `spm.build_restructure_bijector`, for example, builds a `Chain` bijector that takes a vector of parameters, splits it into a list, and packs the elements of the list into a dictionary with the same structure as the Flax parameters dict." ] }, { "cell_type": "markdown", "metadata": { "id": "bWUSY9L2BKIZ" }, "source": [ "### Exercise: Bijectors\n", "Write a bijector (with `Chain`) that computes the function $f(x) = e^{x^2 + 1}$." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7l92sJ8v2Ajl" }, "outputs": [], "source": [ "b = tfb.Chain([...])\n", "\n", "f = lambda x: jnp.exp(x**2 + 1)\n", "x = np.random.normal(size=[5])\n", "np.testing.assert_allclose(f(x), b.forward(x))" ] }, { "cell_type": "markdown", "metadata": { "id": "WAKVc5mNQE-u" }, "source": [ "### Solution" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mZzhXbr4SD-W" }, "outputs": [], "source": [ "b = tfb.Chain([tfb.Exp(), tfb.Shift(1.), tfb.Square()])" ] } ], "metadata": { "colab": { "last_runtime": { "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", "kind": "private" }, "private_outputs": true, "provenance": [ { "file_id": "1bhe9vVJps8t8IsIU4sbInYvcQPgBlLWn", "timestamp": 1667943943389 } ], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }