PSD kernels
TFP’s PSD kernels compute positive semidefinite kernel functions. A PSD kernel instance is a required argument to TFP’s GaussianProcess distribution, so specifying a GP model coroutine will generally involve defining a PSD kernel as an intermediate step.
PSD kernel subclasses take hyperparameters, such as amplitude and length scale, as constructor args. They have three primary public methods, apply, matrix, and tensor, each of which computes the kernel function pairwise on inputs in a different way:
- apply computes the value of the kernel function at a pair of (batches of) input locations. It’s the only required method for subclasses: matrix and tensor are implemented in terms of apply (except when a more efficient method exists to compute pairwise kernel matrices).
- matrix computes the value of the kernel pairwise on two (batches of) lists of input examples. When the two collections are the same, the result is called the Gram matrix. matrix is the most important method for GPs.
- tensor generalizes matrix, taking rank-k1 and rank-k2 collections of input examples to a rank-(k1 + k2) collection of kernel values. (We mention tensor for completeness, but it isn’t relevant to GPs.)
PSD kernels have somewhat complex shape semantics, due to the need to define which input dimensions should be included in pairwise computations and which should be treated as batch dimensions (denoting independent sets of input points).
Imports
from jax import numpy as jnp
import numpy as np
from tensorflow_probability.python.internal import dtype_util
from tensorflow_probability.python.internal import parameter_properties
from tensorflow_probability.substrates import jax as tfp

tfb = tfp.bijectors
tfd = tfp.distributions
tfpk = tfp.math.psd_kernels
Some examples of PSD kernel usage:
# Construct a MaternFiveHalves kernel (with empty batch shape).
amplitude = 2.
length_scale = 0.5
k = tfpk.MaternFiveHalves(
    amplitude=amplitude, length_scale=length_scale)
# Randomly sample some input data.
num_features = 5
num_observations = 12
x = np.random.normal(size=[num_observations, num_features])
# `matrix` computes pairwise kernel values for the Cartesian product over the
# second-to-rightmost dimension of the inputs. Following the terminology in the
# PSD kernel docstring, there is a single example dimension (and single feature
# dimension).
assert k.matrix(x, x).shape == (12, 12)
# Calling `matrix` on inputs of shape [12, d] and [10, d] results in a kernel
# matrix of shape (12, 10)
y = np.random.normal(size=[10, num_features])
assert k.matrix(x, y).shape == (12, 10)
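To make the shape semantics concrete, here is a short sketch (reusing k and x from above) contrasting apply, which pairs inputs elementwise, with matrix, and showing how leading batch dimensions pass through:
# `apply` pairs inputs elementwise rather than taking a Cartesian product:
# matched leading shapes give one kernel value per pair of examples.
assert k.apply(x, x).shape == (12,)
# Dimensions to the left of the example dimension are batch dimensions:
# a batch of 3 input sets yields a batch of 3 kernel matrices.
xb = np.random.normal(size=[3, num_observations, num_features])
assert k.matrix(xb, xb).shape == (3, 12, 12)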
ARD (automatic relevance determination) kernels in TFP are implemented with the FeatureScaled kernel.
# Per-feature scales, giving each feature dimension its own length scale.
length_scale = np.random.uniform(size=[num_features])
ard_kernel = tfpk.FeatureScaled(
    tfpk.MaternFiveHalves(amplitude=np.float64(0.3)),
    scale_diag=length_scale)
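As a sanity check, FeatureScaled should be equivalent to dividing each input feature by scale_diag before applying the base kernel (a sketch, assuming that rescaling behavior):
# Sketch: compare FeatureScaled against manually rescaled inputs.
base_kernel = tfpk.MaternFiveHalves(amplitude=np.float64(0.3))
np.testing.assert_allclose(
    ard_kernel.matrix(x, y),
    base_kernel.matrix(x / length_scale, y / length_scale),
    rtol=1e-5)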
Sums and products of PSD kernels are easy to compute via operator overloading.
matern = tfpk.MaternFiveHalves(amplitude=2.)
squared_exponential = tfpk.ExponentiatedQuadratic(length_scale=0.1)
sum_kernel = matern + squared_exponential
np.testing.assert_allclose(
    sum_kernel.matrix(x, x),
    matern.matrix(x, x) + squared_exponential.matrix(x, x))
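Products work the same way, and combined kernels can be used anywhere a kernel is expected; for instance, tying back to the opening point, a kernel is the main ingredient of the GaussianProcess distribution. A sketch (casting the index points to float32 to match the kernels' dtype):
# Product kernels multiply pointwise, just as sum kernels add pointwise.
product_kernel = matern * squared_exponential
np.testing.assert_allclose(
    product_kernel.matrix(x, x),
    matern.matrix(x, x) * squared_exponential.matrix(x, x),
    rtol=1e-5)
# A PSD kernel in action: a GP over the 12 index points in `x`.
gp = tfd.GaussianProcess(
    kernel=sum_kernel, index_points=x.astype(np.float32))
assert tuple(gp.event_shape) == (12,)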
Exercise: Implement a squared exponential kernel
As an exercise, try implementing a squared exponential PSD kernel:
k(x, y) = amplitude**2 * exp(-||x - y||**2 / (2 * length_scale**2))
In TFP library kernels (see TFP’s squared exponential kernel), there are other details to consider, like handling different dtypes, accepting either length_scale or inverse_length_scale, and ensuring that kernel batch shapes broadcast correctly with inputs. For the purposes of this exercise we can ignore these, and _apply can be written as a straightforward implementation of the kernel function. (New PSD kernels added to TFP would have to treat these details more carefully; the existing kernels serve as good guides.)
Try implementing _apply below (the solution is a couple of cells down).
class MyExponentiatedQuadratic(tfpk.AutoCompositeTensorPsdKernel):

  def __init__(self,
               amplitude,
               length_scale):
    self.amplitude = amplitude
    self.length_scale = length_scale
    super(MyExponentiatedQuadratic, self).__init__(
        feature_ndims=1,
        dtype=jnp.float32,
        name='MyExponentiatedQuadratic',
        validate_args=False)

  @classmethod
  def _parameter_properties(cls, dtype):
    # All TFP objects have parameter properties, which contain information on
    # the shape and domain of the parameters. The Softplus bijector is
    # associated with both the amplitude and length scale parameters, and may
    # be used to constrain these parameters to be positive. These bijectors
    # are NOT automatically applied when the kernel is called -- users may
    # apply them explicitly, e.g. when doing unconstrained parameter
    # optimization.
    return dict(
        amplitude=parameter_properties.ParameterProperties(
            default_constraining_bijector_fn=(
                lambda: tfb.Softplus(low=dtype_util.eps(dtype)))),
        length_scale=parameter_properties.ParameterProperties(
            default_constraining_bijector_fn=(
                lambda: tfb.Softplus(low=dtype_util.eps(dtype)))))

  def _apply(self, x1, x2, example_ndims=0):
    del example_ndims  # Can ignore this arg.
    pass  # Your implementation here.
Make sure this kernel gives the same output as ExponentiatedQuadratic in the TFP library.
my_kernel = MyExponentiatedQuadratic(amplitude=2., length_scale=0.5)
tfp_kernel = tfpk.ExponentiatedQuadratic(amplitude=2., length_scale=0.5)
np.testing.assert_allclose(
    my_kernel.matrix(x, y), tfp_kernel.matrix(x, y), rtol=1e-5)
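(With the stub _apply above, this check will fail until you fill in an implementation.)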
Solution
class MyExponentiatedQuadratic(tfpk.AutoCompositeTensorPsdKernel):

  def __init__(self,
               amplitude,
               length_scale):
    self.amplitude = amplitude
    self.length_scale = length_scale
    super(MyExponentiatedQuadratic, self).__init__(
        feature_ndims=1,
        dtype=jnp.float32,
        name='MyExponentiatedQuadratic',
        validate_args=False)

  @classmethod
  def _parameter_properties(cls, dtype):
    return dict(
        amplitude=parameter_properties.ParameterProperties(
            default_constraining_bijector_fn=(
                lambda: tfb.Softplus(low=dtype_util.eps(dtype)))),
        length_scale=parameter_properties.ParameterProperties(
            default_constraining_bijector_fn=(
                lambda: tfb.Softplus(low=dtype_util.eps(dtype)))))

  def _apply(self, x1, x2, example_ndims=0):
    del example_ndims
    # Squared Euclidean distance over the (single) feature dimension.
    pairwise_sq_distance = jnp.sum((x1 - x2)**2, axis=-1)
    return self.amplitude**2 * jnp.exp(
        -0.5 * pairwise_sq_distance / self.length_scale**2)
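With the solution in place, the comparison above should pass. As a brief aside on the parameter properties defined earlier, here is a hypothetical sketch of recovering the default constraining bijector, as one might when optimizing parameters in unconstrained space (the names props and amplitude_bijector are ours):
# Sketch: map an unconstrained value through the default constraining
# bijector to get a valid (positive) amplitude. The kernel never applies
# these bijectors automatically.
props = MyExponentiatedQuadratic.parameter_properties(jnp.float32)
amplitude_bijector = props['amplitude'].default_constraining_bijector_fn()
assert amplitude_bijector(jnp.array(-3.)) > 0.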