Potential mistakes in the implementation of the sparse regression example

I was going through the sparse regression example, which is based on the following paper:

Raj Agrawal, Jonathan H. Huggins, Brian Trippe, Tamara Broderick (2019), “The Kernel Interaction Trick: Fast Bayesian Discovery of Pairwise Interactions in High Dimensions”, https://arxiv.org/abs/1905.06501.

Based on the equations presented in the paper, it seems to me that the kernel is implemented incorrectly.
The current implementation of the kernel:

```python
import jax.numpy as jnp


def dot(X, Z):
    # Pairwise inner products: dot(X, Z)[i, j] = sum_p X[i, p] * Z[j, p]
    return jnp.dot(X, Z[..., None])[..., 0]


# The kernel that corresponds to our quadratic regressor.
def kernel(X, Z, eta1, eta2, c, jitter=1.0e-4):
    eta1sq = jnp.square(eta1)
    eta2sq = jnp.square(eta2)
    k1 = 0.5 * eta2sq * jnp.square(1.0 + dot(X, Z))
    k2 = -0.5 * eta2sq * dot(jnp.square(X), jnp.square(Z))
    k3 = (eta1sq - eta2sq) * dot(X, Z)
    k4 = jnp.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        # jnp.eye needs an integer size, not the shape tuple
        k4 += jitter * jnp.eye(X.shape[0])
    return k1 + k2 + k3 + k4
```
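A quick numerical check of the kernel above, as a sketch: it assumes (as the note below says) that the model passes kappa * X for both arguments, uses NumPy in place of jax.numpy so it runs without JAX, and drops the jitter term. The helper `features` is hypothetical; it builds explicit regressors (an intercept c, main effects eta1 * kappa_i * x_i, and pairwise interactions eta2 * kappa_i * kappa_j * x_i * x_j for i < j). If the kernel is the inner product of these features, its Gram matrix should equal `Phi @ Phi.T`.

```python
import itertools

import numpy as np


def dot(X, Z):
    # Same helper as in the example, with NumPy instead of jax.numpy:
    # dot(X, Z)[i, j] = sum_p X[i, p] * Z[j, p], i.e. X @ Z.T
    return np.dot(X, Z[..., None])[..., 0]


def kernel(X, Z, eta1, eta2, c):
    # The example's kernel without jitter; X and Z are already kappa * X
    eta1sq, eta2sq = eta1 ** 2, eta2 ** 2
    k1 = 0.5 * eta2sq * np.square(1.0 + dot(X, Z))
    k2 = -0.5 * eta2sq * dot(np.square(X), np.square(Z))
    k3 = (eta1sq - eta2sq) * dot(X, Z)
    k4 = c ** 2 - 0.5 * eta2sq
    return k1 + k2 + k3 + k4


def features(X, kappa, eta1, eta2, c):
    # Hypothetical explicit feature map: intercept, scaled main effects,
    # and one column per pair i < j of scaled pairwise interactions
    kX = kappa * X
    cols = [np.full(X.shape[0], c)]
    cols += [eta1 * kX[:, i] for i in range(X.shape[1])]
    cols += [eta2 * kX[:, i] * kX[:, j]
             for i, j in itertools.combinations(range(X.shape[1]), 2)]
    return np.stack(cols, axis=1)


rng = np.random.default_rng(0)
X = rng.normal(size=(5, 4))
kappa = rng.uniform(0.1, 2.0, size=4)
eta1, eta2, c = 0.7, 1.3, 0.5

K_kernel = kernel(kappa * X, kappa * X, eta1, eta2, c)
Phi = features(X, kappa, eta1, eta2, c)
K_features = Phi @ Phi.T
print(np.allclose(K_kernel, K_features))  # True
```

In this check the squared self-terms x_i^2 z_i^2 from k1 are cancelled exactly by k2, so only the intercept, main effects and i < j interactions survive.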

Note that inside the model function, the X and Z arguments correspond to kappa * X.
Given Proposition 6.1 on page 8 of the paper, I would rewrite this as follows:

```python
def kernel(X, Z, kappa, eta1, eta2, c, jitter=1.0e-4):
    eta1sq = jnp.square(eta1)
    eta2sq = jnp.square(eta2)
    k1 = 0.5 * eta2sq * jnp.square(1.0 + dot(kappa * X, kappa * Z))
    k2 = -0.5 * eta2sq * dot(kappa * jnp.square(X), kappa * jnp.square(Z))
    k3 = (eta1sq - eta2sq) * dot(kappa * X, kappa * Z)
    k4 = jnp.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        k4 += jitter * jnp.eye(X.shape[0])
    return k1 + k2 + k3 + k4
```
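One way to compare this rewrite with the example's kernel, as a sketch (NumPy instead of jax.numpy, jitter dropped; the names `current_kernel` and `proposed_kernel` are just illustrative). The only term that differs is k2: the example effectively uses kappa_i^4 * x_i^2 * z_i^2, while the rewrite uses kappa_i^2 * x_i^2 * z_i^2, so the two kernels should differ by 0.5 * eta2^2 * sum_i (kappa_i^4 - kappa_i^2) * x_i^2 * z_i^2.

```python
import numpy as np


def dot(X, Z):
    # X @ Z.T, written like the helper in the example
    return np.dot(X, Z[..., None])[..., 0]


def current_kernel(X, Z, eta1, eta2, c):
    # Example kernel without jitter; callers pass kappa * X
    eta1sq, eta2sq = eta1 ** 2, eta2 ** 2
    return (0.5 * eta2sq * np.square(1.0 + dot(X, Z))
            - 0.5 * eta2sq * dot(np.square(X), np.square(Z))
            + (eta1sq - eta2sq) * dot(X, Z)
            + c ** 2 - 0.5 * eta2sq)


def proposed_kernel(X, Z, kappa, eta1, eta2, c):
    # The Proposition 6.1 rewrite without jitter; X and Z are unscaled
    eta1sq, eta2sq = eta1 ** 2, eta2 ** 2
    return (0.5 * eta2sq * np.square(1.0 + dot(kappa * X, kappa * Z))
            - 0.5 * eta2sq * dot(kappa * np.square(X), kappa * np.square(Z))
            + (eta1sq - eta2sq) * dot(kappa * X, kappa * Z)
            + c ** 2 - 0.5 * eta2sq)


rng = np.random.default_rng(1)
X = rng.normal(size=(5, 4))
kappa = rng.uniform(0.1, 2.0, size=4)
eta1, eta2, c = 0.7, 1.3, 0.5

diff = (proposed_kernel(X, X, kappa, eta1, eta2, c)
        - current_kernel(kappa * X, kappa * X, eta1, eta2, c))
# Predicted gap: 0.5 * eta2^2 * sum_i (kappa_i^4 - kappa_i^2) x_i^2 z_i^2
gap = 0.5 * eta2 ** 2 * dot((kappa ** 4 - kappa ** 2) * X ** 2, X ** 2)
print(np.allclose(diff, gap))  # True
```

So the two versions agree only when every kappa_i is 0 or 1; otherwise the rewrite keeps a weighted x_i^2 term that the example's kernel cancels.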

The original Stan implementation linked from the paper is also slightly different:

```stan
matrix[N, N] K1 = diag_post_multiply(X, kappa) * X';
matrix[N, N] K2 = diag_post_multiply(X2, kappa) * X2';
matrix[N, N] K = .5 * square(eta_2) * square(K1 + 1.0) + (square(alpha) - .5 * square(eta_2)) * K2 + (square(eta_1) - square(eta_2)) * K1 + square(c) - .5 * square(eta_2);
```

I am not familiar with Stan’s notation, but in Python my guess would be something like this (with alpha = 0):

```python
def kernel(X, Z, kappa, eta1, eta2, c, jitter=1.0e-4):
    eta1sq = jnp.square(eta1)
    eta2sq = jnp.square(eta2)
    K1 = dot(kappa * X, Z)
    K2 = dot(kappa * jnp.square(X), jnp.square(Z))
    K4 = jnp.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        K4 += jitter * jnp.eye(X.shape[0])
    return 0.5 * eta2sq * jnp.square(1.0 + K1) - 0.5 * eta2sq * K2 + (eta1sq - eta2sq) * K1 + K4
```
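A sketch for checking this translation numerically, with NumPy instead of jax.numpy and no jitter. In Stan, `diag_post_multiply(m, v)` is `m * diag(v)`, so K1 is X diag(kappa) Xᵀ; the sketch assumes `X2` in the Stan code is the elementwise square of X. The function names are illustrative only. The second check is purely hypothetical: it asks what happens if the Stan variable `kappa` held the squared scales kappa**2 (the snippet alone does not say which convention it uses).

```python
import numpy as np


def dot(X, Z):
    # X @ Z.T, as in the example's helper
    return np.dot(X, Z[..., None])[..., 0]


def stan_kernel(X, kappa, eta1, eta2, c, alpha=0.0):
    # NumPy transcription of the Stan snippet; assumes X2 = X ** 2.
    # diag_post_multiply(m, v) in Stan is m @ diag(v).
    X2 = np.square(X)
    K1 = X @ np.diag(kappa) @ X.T
    K2 = X2 @ np.diag(kappa) @ X2.T
    return (0.5 * eta2 ** 2 * np.square(K1 + 1.0)
            + (alpha ** 2 - 0.5 * eta2 ** 2) * K2
            + (eta1 ** 2 - eta2 ** 2) * K1
            + c ** 2 - 0.5 * eta2 ** 2)


def translated_kernel(X, Z, kappa, eta1, eta2, c):
    # The Python guess above, without jitter
    K1 = dot(kappa * X, Z)
    K2 = dot(kappa * np.square(X), np.square(Z))
    K4 = c ** 2 - 0.5 * eta2 ** 2
    return (0.5 * eta2 ** 2 * np.square(1.0 + K1) - 0.5 * eta2 ** 2 * K2
            + (eta1 ** 2 - eta2 ** 2) * K1 + K4)


def prop61_kernel(X, Z, kappa, eta1, eta2, c):
    # The Proposition 6.1 rewrite from earlier in the thread, without jitter
    K1 = dot(kappa * X, kappa * Z)
    K2 = dot(kappa * np.square(X), kappa * np.square(Z))
    return (0.5 * eta2 ** 2 * np.square(1.0 + K1) - 0.5 * eta2 ** 2 * K2
            + (eta1 ** 2 - eta2 ** 2) * K1 + c ** 2 - 0.5 * eta2 ** 2)


rng = np.random.default_rng(2)
X = rng.normal(size=(5, 4))
kappa = rng.uniform(0.1, 2.0, size=4)
eta1, eta2, c = 0.7, 1.3, 0.5

# The translation reproduces the Stan formula when both use the same kappa
print(np.allclose(stan_kernel(X, kappa, eta1, eta2, c),
                  translated_kernel(X, X, kappa, eta1, eta2, c)))  # True
# Hypothetical: if Stan's kappa held kappa**2, it would match the rewrite
print(np.allclose(stan_kernel(X, kappa ** 2, eta1, eta2, c),
                  prop61_kernel(X, X, kappa, eta1, eta2, c)))  # True
```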

This is again different from my interpretation of Proposition 6.1, and also different from the Pyro/NumPyro example.

@pyroman It is a typo in the paper (see Appendices C1 and C2).

@fehiepsi If you are certain that C1 and C2 present the correct form of the equations, then I will accept that. I am only confused about why the Stan implementation is different, as this is what they used in the paper.

I think it is a bug in the implementation, or the authors are using a different formula; I’m not sure. It would be better to reach out to the authors for more information. I wrote an email to the first author; hopefully he will see it and comment here.

@pyroman As a general rule of thumb, you should treat arXiv papers as preprints. Often authors do not update the arXiv version to reflect the final published version.

The formulae should be correct. It’s pretty easy to derive them from scratch (as I did).
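A sketch of that derivation, assuming (as in the NumPyro example) that the kernel is evaluated on scaled inputs $\tilde x = \kappa \circ x$ and $\tilde z = \kappa \circ z$. Expanding the square, the $\tilde x_i^2 \tilde z_i^2$ terms cancel against the second term:

```latex
\begin{aligned}
k(\tilde x, \tilde z)
&= \tfrac{1}{2}\eta_2^2 \,(1 + \tilde x^\top \tilde z)^2
 - \tfrac{1}{2}\eta_2^2 \sum_i \tilde x_i^2 \tilde z_i^2
 + (\eta_1^2 - \eta_2^2)\, \tilde x^\top \tilde z
 + c^2 - \tfrac{1}{2}\eta_2^2 \\
&= \tfrac{1}{2}\eta_2^2 \Big(1 + 2\,\tilde x^\top \tilde z
 + \sum_i \tilde x_i^2 \tilde z_i^2
 + 2 \sum_{i<j} \tilde x_i \tilde x_j \tilde z_i \tilde z_j\Big)
 - \tfrac{1}{2}\eta_2^2 \sum_i \tilde x_i^2 \tilde z_i^2
 + (\eta_1^2 - \eta_2^2)\, \tilde x^\top \tilde z
 + c^2 - \tfrac{1}{2}\eta_2^2 \\
&= c^2 + \eta_1^2 \sum_i \kappa_i^2 x_i z_i
 + \eta_2^2 \sum_{i<j} \kappa_i^2 \kappa_j^2 \, x_i x_j z_i z_j .
\end{aligned}
```

This is the covariance of $f(x) = \theta_0 + \sum_i \theta_i x_i + \sum_{i<j} \theta_{ij} x_i x_j$ with independent zero-mean Gaussian coefficients of variances $c^2$, $\eta_1^2 \kappa_i^2$ and $\eta_2^2 \kappa_i^2 \kappa_j^2$: an intercept, main effects and pairwise interactions, with no squared self-terms.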

Thanks @martinjankowiak, but the published version is the same as v2 on arXiv. For me, the inconsistencies between the math in the main text, the math in the appendix, the Pyro/NumPyro implementations and the Stan implementation are still rather concerning.