Potential mistakes in the implementation of sparse regression examples

I was going through the example on sparse regression, which is based on the following paper:

Raj Agrawal, Jonathan H. Huggins, Brian Trippe, Tamara Broderick (2019), “The Kernel Interaction Trick: Fast Bayesian Discovery of Pairwise Interactions in High Dimensions”, https://arxiv.org/abs/1905.06501.

Based on the equations presented in the paper, it seems to me that the kernel is implemented incorrectly.
Here is the current implementation of the kernel:

import jax.numpy as jnp

def dot(X, Z):
    # equivalent to X @ Z.T: matrix of inner products between rows of X and rows of Z
    return jnp.dot(X, Z[..., None])[..., 0]

# The kernel that corresponds to our quadratic regressor.
def kernel(X, Z, eta1, eta2, c, jitter=1.0e-4):
    eta1sq = jnp.square(eta1)
    eta2sq = jnp.square(eta2)
    k1 = 0.5 * eta2sq * jnp.square(1.0 + dot(X, Z))
    k2 = -0.5 * eta2sq * dot(jnp.square(X), jnp.square(Z))
    k3 = (eta1sq - eta2sq) * dot(X, Z)
    k4 = jnp.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        k4 += jitter * jnp.eye(X.shape[0])
    return k1 + k2 + k3 + k4

Note that inside the model function, the X and Z arguments passed to the kernel already correspond to kappa * X.
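For context, the call site inside the model looks roughly like this (paraphrased, so the exact variable names may differ slightly from the example):

# inside the model: the inputs are rescaled by kappa before the kernel is applied
kX = kappa * X
k = kernel(kX, kX, eta1, eta2, c)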
Given Proposition 6.1 on page 8 of the paper, I would rewrite this as follows:

def kernel(X, Z, kappa, eta1, eta2, c, jitter=1.0e-4):
    eta1sq = jnp.square(eta1)
    eta2sq = jnp.square(eta2)
    k1 = 0.5 * eta2sq * jnp.square(1.0 + dot(kappa * X, kappa * Z))
    k2 = -0.5 * eta2sq * dot(kappa * jnp.square(X), kappa * jnp.square(Z))
    k3 = (eta1sq - eta2sq) * dot(kappa * X, kappa * Z)
    k4 = jnp.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        k4 += jitter * jnp.eye(X.shape[0])
    return k1 + k2 + k3 + k4

The original Stan implementation linked in the paper is also slightly different:

matrix[N, N] K1 = diag_post_multiply(X, kappa) * X';
matrix[N, N] K2 = diag_post_multiply(X2, kappa) * X2';
matrix[N, N] K = .5 * square(eta_2) * square(K1 + 1.0)
                 + (square(alpha) - .5 * square(eta_2)) * K2
                 + (square(eta_1) - square(eta_2)) * K1
                 + square(c) - .5 * square(eta_2);
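If I read the Stan functions correctly (diag_post_multiply(X, kappa) scales the columns of X by the entries of kappa, and X' is the transpose), then K1 and K2 translate to roughly the following, where X2 is presumably the elementwise square of X:

# K1[n, m] = sum_i kappa[i] * X[n, i] * X[m, i]
K1 = (X * kappa) @ X.T
# K2 is the same construction applied to the elementwise squares
K2 = (X2 * kappa) @ X2.T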

I am not familiar with Stan’s notation, but in Python my guess would be something like this (with alpha = 0):

def kernel(X, Z, kappa, eta1, eta2, c, jitter=1.0e-4):
    eta1sq = jnp.square(eta1)
    eta2sq = jnp.square(eta2)
    K1 = dot(kappa * X, Z)
    K2 = dot(kappa * jnp.square(X), jnp.square(Z))
    K4 = jnp.square(c) - 0.5 * eta2sq
    if X.shape == Z.shape:
        K4 += jitter * jnp.eye(X.shape[0])
    return 0.5 * eta2sq * jnp.square(1.0 + K1) - 0.5 * eta2sq * K2 + (eta1sq - eta2sq) * K1 + K4

This is again different from my interpretation of Proposition 6.1, but also different from the Pyro/NumPyro example; a quick check is below.
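As a sanity check (a sketch with hypothetical names: kernel_prop61 is my Proposition 6.1 rewrite above, kernel_stan is the Stan-based guess), the two candidates disagree on random inputs:

import jax.numpy as jnp
from jax import random

key1, key2 = random.split(random.PRNGKey(0))
X = random.normal(key1, (5, 3))
kappa = random.uniform(key2, (3,))  # generic positive scales

args = (X, X, kappa, 1.0, 0.5, 1.0)  # eta1, eta2, c
# the jitter terms are identical and cancel in the difference
print(jnp.max(jnp.abs(kernel_prop61(*args) - kernel_stan(*args))))  # clearly nonzero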

@pyroman It is a typo in the paper (see Appendices C1 and C2).

@fehiepsi If you are certain that C1 and C2 present the correct form of the equations, then I will accept that :slight_smile: Then I am only confused about why the Stan implementation is different, since that is what they used in the paper.

I think it is either a bug in the implementation or the authors are using a different formula; I’m not sure. It would be best to reach out to the authors for more information. :slight_smile:

cc @martinjankowiak

I wrote an email to the first author. Hopefully he will see it and comment here.

@pyroman As a general rule of thumb, you should treat arXiv papers as preprints; authors often do not update the arXiv version to reflect the final published version.

The formulae should be correct; it’s pretty easy to derive them from scratch (as I did).
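For reference, here is one way the derivation goes (my sketch, with the prior scales read off the kernel code above). The kernel corresponds to a quadratic regressor with independent Gaussian weights:

f(x) = b + \sum_i \beta_i x_i + \sum_{i<j} \beta_{ij} x_i x_j,
\qquad b \sim \mathcal{N}(0, c^2), \quad \beta_i \sim \mathcal{N}(0, \eta_1^2), \quad \beta_{ij} \sim \mathcal{N}(0, \eta_2^2)

k(x, z) = \mathbb{E}[f(x) f(z)]
        = c^2 + \eta_1^2 \langle x, z \rangle + \eta_2^2 \sum_{i<j} x_i x_j z_i z_j
        = c^2 + \eta_1^2 \langle x, z \rangle + \tfrac{\eta_2^2}{2} \big( \langle x, z \rangle^2 - \sum_i x_i^2 z_i^2 \big),

using \sum_{i<j} a_i a_j = \tfrac{1}{2} \big( (\sum_i a_i)^2 - \sum_i a_i^2 \big) with a_i = x_i z_i. Expanding \tfrac{\eta_2^2}{2} (1 + \langle x, z \rangle)^2 and collecting terms reproduces k1 + k2 + k3 + k4 from the kernel code, with x and z standing for the rescaled inputs kappa * X.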

Thanks @martinjankowiak, but the published version is the same as v2 on arXiv. For me, the inconsistencies between the main-text math, the appendix math, the Pyro/NumPyro implementation, and the Stan implementation are still rather concerning.

@pyroman Sorry for the delay, and thank you for your question. Proposition 6.1 is consistent with the Pyro code (and also consistent with Appendices C1 and C2). Explanation below:

Let f(X, Z, eta1, eta2, c) and g(X, Z, kappa, eta1, eta2, c) denote the Pyro kernel code and the code from your first comment, respectively. Then f(kX, kZ, eta1, eta2, c) = g(X, Z, k, eta1, eta2, c). On line 78 of the Pyro code, kappa is multiplied by X (kX). On line 81, the kernel takes kX as input. Hence, the Pyro code is the same as the kernel you specified.

The Stan code writes the kernel differently than the Pyro code. However, the resulting kernel is still algebraically the same (alpha in the Stan code corresponds to eta3 in Proposition 6.1). The inconsistencies in notation came from writing the code and paper at different points in time :slight_smile: Nevertheless, the Pyro and Stan code should output the same kernel matrices.


Thanks @rajagrawal!