I was going through the example on sparse regression, which is based on the following paper:
Raj Agrawal, Jonathan H. Huggins, Brian Trippe, Tamara Broderick (2019), "The Kernel Interaction Trick: Fast Bayesian Discovery of Pairwise Interactions in High Dimensions", https://arxiv.org/abs/1905.06501.
Based on the equations presented in the paper, it seems to me that the kernel is implemented incorrectly.
Current implementation of the kernel:
def dot(X, Z):
    return jnp.dot(X, Z[..., None])[..., 0]

# The kernel that corresponds to our quadratic regressor.
def kernel(X, Z, eta1, eta2, c, jitter=1.0e-4):
    eta1sq = jnp.square(eta1)
    eta2sq = jnp.square(eta2)
    k1 = 0.5 * eta2sq * jnp.square(1.0 + dot(X, Z))
    k2 = -0.5 * eta2sq * dot(jnp.square(X), jnp.square(Z))
    k3 = (eta1sq - eta2sq) * dot(X, Z)
    k4 = jnp.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        k4 += jitter * jnp.eye(X.shape[0])
    return k1 + k2 + k3 + k4
Note that inside the model function, the X and Z arguments passed to the kernel already correspond to kappa * X.
Given Proposition 6.1 on page 8 of the paper, I would rewrite this as follows:
def kernel(X, Z, kappa, eta1, eta2, c, jitter=1.0e-4):
    eta1sq = jnp.square(eta1)
    eta2sq = jnp.square(eta2)
    k1 = 0.5 * eta2sq * jnp.square(1.0 + dot(kappa * X, kappa * Z))
    k2 = -0.5 * eta2sq * dot(kappa * jnp.square(X), kappa * jnp.square(Z))
    k3 = (eta1sq - eta2sq) * dot(kappa * X, kappa * Z)
    k4 = jnp.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        k4 += jitter * jnp.eye(X.shape[0])
    return k1 + k2 + k3 + k4
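If I read it correctly, the two versions differ only in the k2 term: calling the current kernel on pre-scaled inputs squares kappa * X elementwise, so kappa effectively enters k2 to the fourth power, whereas the proposed version keeps it squared. Here is a minimal sketch (with arbitrary test values of my own choosing, not from the example) that compares the two side by side; `kernel_current` and `kernel_proposed` are just my names for the two variants above:

```python
import jax.numpy as jnp
from jax import random

def dot(X, Z):
    return jnp.dot(X, Z[..., None])[..., 0]

def kernel_current(X, Z, eta1, eta2, c, jitter=1.0e-4):
    # Current implementation; the caller passes pre-scaled inputs kappa * X.
    eta1sq, eta2sq = jnp.square(eta1), jnp.square(eta2)
    k1 = 0.5 * eta2sq * jnp.square(1.0 + dot(X, Z))
    k2 = -0.5 * eta2sq * dot(jnp.square(X), jnp.square(Z))
    k3 = (eta1sq - eta2sq) * dot(X, Z)
    k4 = jnp.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        k4 += jitter * jnp.eye(X.shape[0])
    return k1 + k2 + k3 + k4

def kernel_proposed(X, Z, kappa, eta1, eta2, c, jitter=1.0e-4):
    # Proposed rewrite: kappa is applied inside the kernel, so k2 uses
    # kappa * X**2 rather than (kappa * X)**2.
    eta1sq, eta2sq = jnp.square(eta1), jnp.square(eta2)
    k1 = 0.5 * eta2sq * jnp.square(1.0 + dot(kappa * X, kappa * Z))
    k2 = -0.5 * eta2sq * dot(kappa * jnp.square(X), kappa * jnp.square(Z))
    k3 = (eta1sq - eta2sq) * dot(kappa * X, kappa * Z)
    k4 = jnp.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        k4 += jitter * jnp.eye(X.shape[0])
    return k1 + k2 + k3 + k4

N, P = 4, 3
key_x, key_k = random.split(random.PRNGKey(0))
X = random.normal(key_x, (N, P))
kappa = jnp.abs(random.normal(key_k, (P,)))  # arbitrary positive scales
eta1, eta2, c = 1.0, 0.5, 1.0

K_cur = kernel_current(kappa * X, kappa * X, eta1, eta2, c)
K_prop = kernel_proposed(X, X, kappa, eta1, eta2, c)
print("max abs difference:", jnp.max(jnp.abs(K_cur - K_prop)))
```

With kappa identically one, the two versions coincide; with any other kappa, the difference in the k2 term shows up as a nonzero gap between the two Gram matrices.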