
PSD kernels

TFP’s PSD kernels compute positive semidefinite kernel functions. A PSD kernel instance is a required argument to TFP’s GaussianProcess distribution, so specifying a GP model as a coroutine (e.g. with JointDistributionCoroutine) will generally involve defining a PSD kernel as an intermediate.

PSD kernel subclasses take hyperparameters, such as amplitude and length scale, as constructor args. They have three primary public methods: apply, matrix, and tensor, each of which computes the kernel function pairwise on inputs in different ways:

  • apply computes the value of the kernel function at a pair of (batches of) input locations. It’s the only required method for subclasses: matrix and tensor are implemented in terms of apply (except when a more efficient method exists to compute pairwise kernel matrices).

  • matrix computes the value of the kernel pairwise on two (batches of) lists of input examples. When the two collections are the same, the result is called the Gram matrix. matrix is the most important method for GPs.

  • tensor generalizes matrix, taking rank k1 and k2 collections of input examples to a rank k1 + k2 collection of kernel values. (We mention tensor for completeness, but it isn’t relevant to GPs.)

PSD kernels have somewhat complex shape semantics, due to the need to define which input dimensions should be included in pairwise computations and which should be treated as batch dimensions (denoting independent sets of input points).

Imports

from jax import numpy as jnp
import numpy as np
from tensorflow_probability.substrates import jax as tfp
from tensorflow_probability.substrates.jax.internal import dtype_util
from tensorflow_probability.substrates.jax.internal import parameter_properties

tfb = tfp.bijectors
tfd = tfp.distributions
tfpk = tfp.math.psd_kernels

Some examples of PSD kernel usage:

# Construct a MaternFiveHalves kernel (with empty batch shape).
amplitude = 2.
length_scale = 0.5
k = tfpk.MaternFiveHalves(
    amplitude=amplitude, length_scale=length_scale)

# Randomly sample some input data.
num_features = 5
num_observations = 12
x = np.random.normal(size=[num_observations, num_features])

# `matrix` computes pairwise kernel values for the Cartesian product over the
# second-to-rightmost dimension of the inputs. Following the terminology in the
# PSD kernel docstring, there is a single example dimension (and single feature
# dimension).
assert k.matrix(x, x).shape == (12, 12)

# Calling `matrix` on inputs of shape [12, d] and [10, d] results in a kernel
# matrix of shape (12, 10).
y = np.random.normal(size=[10, num_features])
assert k.matrix(x, y).shape == (12, 10)

ARD (automatic relevance determination) kernels, which use a separate length scale for each input feature, are implemented in TFP with the FeatureScaled kernel.

length_scale = np.random.uniform(size=[num_features])
ard_kernel = tfpk.FeatureScaled(
    tfpk.MaternFiveHalves(amplitude=np.float64(0.3)),
    scale_diag=length_scale)

Sums and products of PSD kernels are easy to compute, via operator overloading.

matern = tfpk.MaternFiveHalves(amplitude=2.)
squared_exponential = tfpk.ExponentiatedQuadratic(length_scale=0.1)
sum_kernel = matern + squared_exponential

np.testing.assert_allclose(
    sum_kernel.matrix(x, x),
    matern.matrix(x, x) + squared_exponential.matrix(x, x))

Exercise: Implement a squared exponential kernel

As an exercise, try implementing a squared exponential PSD kernel:

k(x, y) = amplitude**2 * exp(-||x - y||**2 / (2 * length_scale**2))

In TFP library kernels (see TFP’s squared exponential kernel), there are other details to consider, like handling of different dtypes, accepting either length_scale or inverse_length_scale, and ensuring that kernel batch shapes broadcast correctly with inputs.

For the purpose of the exercise we can ignore these, and _apply can be written as a straightforward implementation of the kernel function. (New PSD kernels added to TFP would have to handle these details more carefully; existing kernels serve as good guides.)

Try implementing _apply below (the solution is a couple cells down).

class MyExponentiatedQuadratic(tfpk.AutoCompositeTensorPsdKernel):

  def __init__(self,
               amplitude,
               length_scale):
    self.amplitude = amplitude
    self.length_scale = length_scale
    super(MyExponentiatedQuadratic, self).__init__(
        feature_ndims=1,
        dtype=jnp.float32,
        name='MyExponentiatedQuadratic',
        validate_args=False)

  @classmethod
  def _parameter_properties(cls, dtype):
    # All TFP objects have parameter properties, which contain information on
    # the shape and domain of the parameters. The Softplus bijector is
    # associated with both the amplitude and length scale parameters, and may be
    # used to constrain these parameters to be positive. These bijectors are NOT
    # automatically applied when the kernel is called; users may apply them
    # explicitly, e.g. when doing unconstrained parameter optimization.
    return dict(
        amplitude=parameter_properties.ParameterProperties(
            default_constraining_bijector_fn=(
                lambda: tfb.Softplus(low=dtype_util.eps(dtype)))),
        length_scale=parameter_properties.ParameterProperties(
            default_constraining_bijector_fn=(
                lambda: tfb.Softplus(low=dtype_util.eps(dtype)))))

  def _apply(self, x1, x2, example_ndims=0):
    del example_ndims  # Can ignore this arg.
    pass

Make sure this kernel gives the same output as ExponentiatedQuadratic in the TFP library.

my_kernel = MyExponentiatedQuadratic(amplitude=2., length_scale=0.5)
tfp_kernel = tfpk.ExponentiatedQuadratic(amplitude=2., length_scale=0.5)
np.testing.assert_allclose(my_kernel.matrix(x, y), tfp_kernel.matrix(x, y), rtol=1e-5)

Solution

class MyExponentiatedQuadratic(tfpk.AutoCompositeTensorPsdKernel):

  def __init__(self,
               amplitude,
               length_scale):
    self.amplitude = amplitude
    self.length_scale = length_scale
    super(MyExponentiatedQuadratic, self).__init__(
        feature_ndims=1,
        dtype=jnp.float32,
        name='MyExponentiatedQuadratic',
        validate_args=False)

  @classmethod
  def _parameter_properties(cls, dtype):
    return dict(
        amplitude=parameter_properties.ParameterProperties(
            default_constraining_bijector_fn=(
                lambda: tfb.Softplus(low=dtype_util.eps(dtype)))),
        length_scale=parameter_properties.ParameterProperties(
            default_constraining_bijector_fn=(
                lambda: tfb.Softplus(low=dtype_util.eps(dtype)))))

  def _apply(self, x1, x2, example_ndims=0):
    del example_ndims
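    # Squared Euclidean distance between x1 and x2 over the feature axis.
    pairwise_sq_distance = jnp.sum((x1 - x2)**2, axis=-1)
    # amplitude**2 * exp(-||x1 - x2||**2 / (2 * length_scale**2))
    return jnp.exp(-0.5 * pairwise_sq_distance / self.length_scale ** 2) * self.amplitude ** 2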