{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "StyZfj4JCmef" }, "source": [ "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/vizier/blob/main/docs/advanced_topics/tfp/kernels.ipynb)\n", "\n", "# PSD kernels\n", "TFP's [PSD kernels](https://github.com/tensorflow/probability/tree/main/tensorflow_probability/python/math/psd_kernels) compute positive semidefinite kernel functions. A PSD kernel instance is a required argument to TFP's Gaussian Process distribution, so specifying a GP model coroutine will generally involve defining a PSD kernel as an intermediate step.\n", "\n", "PSD kernel subclasses take hyperparameters, such as amplitude and length scale, as constructor args. They have three primary public methods: `apply`, `matrix`, and `tensor`, each of which evaluates the kernel function on its inputs in a different way:\n", "\n", "- `apply` computes the value of the kernel function at a pair of (batches of)\n", "  input locations. It's the only required method for subclasses: `matrix` and `tensor`\n", "  are implemented in terms of `apply` (except when a more efficient method exists to compute pairwise kernel matrices).\n", "\n", "- `matrix` computes the value of the kernel *pairwise* on two (batches of)\n", "  lists of input examples. When the two collections are the same, the result is\n", "  called the [Gram matrix](https://en.wikipedia.org/wiki/Gramian_matrix). `matrix` is the most important method for GPs.\n", "\n", "- `tensor` generalizes `matrix`, taking rank `k1` and `k2` collections of\n", "  input examples to a rank `k1 + k2` collection of kernel values. 
(We mention `tensor` for completeness, but it isn't relevant to GPs.)\n", "\n", "PSD kernels have somewhat complex [shape semantics](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/math/psd_kernels/positive_semidefinite_kernel.py#L97), due to the need to define which input dimensions should be included in pairwise computations and which should be treated as batch dimensions (denoting independent sets of input points)." ] }, { "cell_type": "markdown", "metadata": { "id": "nUd5BnyzQRzU" }, "source": [ "### Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1VlROI1ftTXt" }, "outputs": [], "source": [ "from jax import numpy as jnp\n", "import numpy as np\n", "from tensorflow_probability.python.internal import dtype_util\n", "from tensorflow_probability.substrates import jax as tfp\n", "# `parameter_properties` and `tfb` are used by the custom kernel defined below.\n", "from tensorflow_probability.substrates.jax.internal import parameter_properties\n", "\n", "tfb = tfp.bijectors\n", "tfd = tfp.distributions\n", "tfpk = tfp.math.psd_kernels" ] }, { "cell_type": "markdown", "metadata": { "id": "82kaHCwktsRs" }, "source": [ "Some examples of PSD kernel usage:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "B5h2fHYJSjc3" }, "outputs": [], "source": [ "# Construct a MaternFiveHalves kernel (with empty batch shape).\n", "amplitude = 2.\n", "length_scale = 0.5\n", "k = tfpk.MaternFiveHalves(\n", "    amplitude=amplitude, length_scale=length_scale)\n", "\n", "# Randomly sample some input data.\n", "num_features = 5\n", "num_observations = 12\n", "x = np.random.normal(size=[num_observations, num_features])\n", "\n", "# `matrix` computes pairwise kernel values for the Cartesian product over the\n", "# second-to-rightmost dimension of the inputs. 
Following the terminology in the\n", "# PSD kernel docstring, there is a single example dimension (and single feature\n", "# dimension).\n", "assert k.matrix(x, x).shape == (12, 12)\n", "\n", "# Calling `matrix` on inputs of shape [12, d] and [10, d] results in a kernel\n", "# matrix of shape (12, 10).\n", "y = np.random.normal(size=[10, num_features])\n", "assert k.matrix(x, y).shape == (12, 10)" ] }, { "cell_type": "markdown", "metadata": { "id": "KBKVTALlATge" }, "source": [ "Automatic relevance determination (ARD) kernels, which use a separate length scale for each feature, are implemented in TFP with the `FeatureScaled` kernel." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wHaEcbmMAfNz" }, "outputs": [], "source": [ "# One length scale per feature dimension.\n", "length_scale = np.random.uniform(size=[num_features])\n", "ard_kernel = tfpk.FeatureScaled(\n", "    tfpk.MaternFiveHalves(amplitude=np.float64(0.3)),\n", "    scale_diag=length_scale)" ] }, { "cell_type": "markdown", "metadata": { "id": "as-qo57xZDkp" }, "source": [ "Sums and products of PSD kernels are easy to construct via operator overloading (`+` and `*`)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ygjzAIEAH4gM" }, "outputs": [], "source": [ "matern = tfpk.MaternFiveHalves(amplitude=2.)\n", "squared_exponential = tfpk.ExponentiatedQuadratic(length_scale=0.1)\n", "sum_kernel = matern + squared_exponential\n", "\n", "# The kernel matrix of a sum kernel is the sum of the component kernel matrices.\n", "np.testing.assert_allclose(\n", "    sum_kernel.matrix(x, x),\n", "    matern.matrix(x, x) + squared_exponential.matrix(x, x))" ] }, { "cell_type": "markdown", "metadata": { "id": "2LBRj37hlWb2" }, "source": [ "### Exercise: Implement a squared exponential kernel\n", "As an exercise, try implementing a squared exponential PSD kernel:\n", "```\n", "k(x, y) = amplitude**2 * exp(-||x - y||**2 / (2 * length_scale**2))\n", "```\n", "\n", "In TFP library kernels (see TFP's [squared exponential kernel](https://github.com/tensorflow/probability/blob/main/tensorflow_probability/python/math/psd_kernels/exponentiated_quadratic.py)), there are other details to consider, like handling of different dtypes, accepting either `length_scale` or `inverse_length_scale`, and ensuring that kernel batch shapes broadcast correctly with inputs.\n", "\n", "For the purpose of the exercise we can ignore these, and `_apply` can be written as a straightforward implementation of the kernel function. (New PSD kernels added to TFP would have to treat these details more carefully, and existing kernels serve as good guides.)\n", "\n", "Try implementing `_apply` below (the solution is a couple of cells down)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "vQYQ8TIqaL-N" }, "outputs": [], "source": [ "class MyExponentiatedQuadratic(tfpk.AutoCompositeTensorPsdKernel):\n", "\n", "  def __init__(self,\n", "               amplitude,\n", "               length_scale):\n", "    self.amplitude = amplitude\n", "    self.length_scale = length_scale\n", "    super(MyExponentiatedQuadratic, self).__init__(\n", "        feature_ndims=1,\n", "        dtype=jnp.float32,\n", "        name='MyExponentiatedQuadratic',\n", "        validate_args=False)\n", "\n", "  @classmethod\n", "  def _parameter_properties(cls, dtype):\n", "    # All TFP objects have parameter properties, which contain information on\n", "    # the shape and domain of the parameters. The Softplus bijector is\n", "    # associated with both the amplitude and length scale parameters, and may be\n", "    # used to constrain these parameters to be positive. These bijectors are NOT\n", "    # automatically applied when the kernel is called -- users may apply them\n", "    # explicitly, e.g. when doing unconstrained parameter optimization.\n", "    return dict(\n", "        amplitude=parameter_properties.ParameterProperties(\n", "            default_constraining_bijector_fn=(\n", "                lambda: tfb.Softplus(low=dtype_util.eps(dtype)))),\n", "        length_scale=parameter_properties.ParameterProperties(\n", "            default_constraining_bijector_fn=(\n", "                lambda: tfb.Softplus(low=dtype_util.eps(dtype)))))\n", "\n", "  def _apply(self, x1, x2, example_ndims=0):\n", "    del example_ndims  # Can ignore this arg.\n", "    pass  # Replace with your implementation of k(x1, x2).\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ayY5ocKVK5RP" }, "source": [ "Make sure this kernel gives the same output as `ExponentiatedQuadratic` in the TFP library." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "TfE5QMvya-0v" }, "outputs": [], "source": [ "my_kernel = MyExponentiatedQuadratic(amplitude=2., length_scale=0.5)\n", "tfp_kernel = tfpk.ExponentiatedQuadratic(amplitude=2., length_scale=0.5)\n", "np.testing.assert_allclose(\n", "    my_kernel.matrix(x, y), tfp_kernel.matrix(x, y), rtol=1e-5)" ] }, { "cell_type": "markdown", "metadata": { "id": "pYHAZqu0QVHs" }, "source": [ "### Solution" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "M_qZuCnImDTx" }, "outputs": [], "source": [ "class MyExponentiatedQuadratic(tfpk.AutoCompositeTensorPsdKernel):\n", "\n", "  def __init__(self,\n", "               amplitude,\n", "               length_scale):\n", "    self.amplitude = amplitude\n", "    self.length_scale = length_scale\n", "    super(MyExponentiatedQuadratic, self).__init__(\n", "        feature_ndims=1,\n", "        dtype=jnp.float32,\n", "        name='MyExponentiatedQuadratic',\n", "        validate_args=False)\n", "\n", "  @classmethod\n", "  def _parameter_properties(cls, dtype):\n", "    return dict(\n", "        amplitude=parameter_properties.ParameterProperties(\n", "            default_constraining_bijector_fn=(\n", "                lambda: tfb.Softplus(low=dtype_util.eps(dtype)))),\n", "        length_scale=parameter_properties.ParameterProperties(\n", "            default_constraining_bijector_fn=(\n", "                lambda: tfb.Softplus(low=dtype_util.eps(dtype)))))\n", "\n", "  def _apply(self, x1, x2, example_ndims=0):\n", "    del example_ndims\n", "    # Squared Euclidean distance, summed over the single feature dimension.\n", "    pairwise_sq_distance = jnp.sum((x1 - x2)**2, axis=-1)\n", "    # k(x, y) = amplitude**2 * exp(-||x - y||**2 / (2 * length_scale**2))\n", "    return self.amplitude ** 2 * jnp.exp(\n", "        -0.5 * pairwise_sq_distance / self.length_scale ** 2)\n" ] } ], "metadata": { "colab": { "last_runtime": { "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook", "kind": "private" }, "private_outputs": true, "provenance": [ { "file_id": "1bhe9vVJps8t8IsIU4sbInYvcQPgBlLWn", "timestamp": 1667943943389 } ], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, 
"nbformat": 4, "nbformat_minor": 0 }