Difficulty translating Stan model with custom likelihood into NumPyro

I’m have difficulty using Numpyro to apply NUTS to the following generative model:

\theta:=(\sigma_\beta , \sigma_{X\beta}, \sigma_e) \stackrel{iid}{\sim}\mathrm{priors}
\beta\vert\theta \stackrel{iid}{\sim} N(0, m^{-1/2}\sigma_\beta)
X\vert\beta,\theta \stackrel{iid}{\sim} MVN(0_m, I + a_\theta \beta \beta^T)
y\vert X,\beta,\theta \sim MVN(X\beta, \sigma^2_eI_n)

where a_\theta:=(\sigma^2_{X\beta} - \sigma^2_\beta)\sigma^{-4}_\beta.

Though it would be possible to implement this using built-in distributions, this involves putting a likelihood on the matrix-valued random variable X. However, the (unnormalized) joint likelihood of X,\beta\vert\theta can be efficiently computed as
\ell_{X,\beta\vert\theta}\propto -\frac{1}{2} \left(m \beta^T\beta/\sigma_\beta^2 - \frac{a_\theta \beta^T X^T X \beta}{(1+a_\theta\beta^T\beta)} + n\log(1+a\beta)+m\log(\sigma_\beta^2 / m)\right).

I can implement this model and sample from it efficiently using Stan:

functions {
  real X_beta_lpdf(vector b, real s_beta, real s_Xbeta, matrix XtX, int M, int N) {
    real btb = dot_self(b);               // b.T @ b
    real btXtXtb = quad_form(XtX, b);     // b.T @ X.T @ X @ b
    real a = (s_Xbeta^2 - s_beta^2) / s_beta^4;
    real trace_terms = M*btb/s_beta^2 - a*btXtXtb/(1+a*btb);
    real log_dets = N*log(1 + a*btb) + M*log(s_beta^2 / M);
    return -.5*(trace_terms + log_dets);
  }
}

data {
  int<lower=1> N;         // dimension of y
  int<lower=1> M;         // dimension of beta
  vector[N] y;            // y (outcome variable)
  matrix[N,M] X;          // matrix of predictors
  matrix[M,M] XtX;        // precomputed X.T @ X
}
parameters {
  real<lower=0> s_beta;   // standard deviation of beta*m
  real<lower=0> s_Xbeta;  // standard deviation of X beta
  real<lower=0> s_e;      // standard deviation of residual variance
  vector[M] beta;         // vector of latent effects
}
transformed parameters {
  // variance components
  real v_beta = s_beta^2;
  real v_Xbeta = s_Xbeta^2;
  real v_e = s_e^2;
  // linear predictor
  /* vector[N] Xb = X*beta;  // X @ beta */
}

model {
  // priors
  s_beta ~ cauchy(0,1);
  s_Xbeta ~ cauchy(0,1);
  s_e ~ cauchy(0,1);
  // joint likelihood on X, beta
  target += X_beta_lpdf(beta | s_beta, s_Xbeta, XtX, M, N);
  // residual likelihood
  /* target += normal_lpdf(y | Xb, s_e); */
  target += normal_id_glm_lpdf(y | X, 0, beta, s_e);
}

This model works well and produces accurate estimates of \theta when applied to toy data:

import numpy as np

m=200
n=2000

## true parameter values
v_beta=.5
v_Xbeta=.6
v_e=.5

## generate toy data
a = (v_Xbeta - v_beta) / v_beta**2
beta = np.random.randn(m).reshape((m,1))*np.sqrt(v_beta/m)
X = np.random.multivariate_normal(np.zeros(m), np.eye(m) + a*np.outer(beta,beta), n)
e = np.random.randn(n).reshape((n,1))*np.sqrt(v_e)

y = X @ beta + e

## what we want to estimate:
## realized values of v_beta, v_Xbeta, v_e
m*np.var(beta), np.var(X@beta), np.var(e)

However, my attempts at translating this into NumPyro haven’t gone so well:

import jax
import numpyro
import jax.random as random
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax.numpy as jnp
numpyro.enable_x64()

class joint_X_b(dist.Distribution):
    arg_constraints = {
        "s_beta": dist.constraints.positive,
        "s_Xbeta": dist.constraints.positive,
    }
    support = dist.constraints.real_vector

    def __init__(self, s_beta=None, s_Xbeta=None, XtX=None, M=None, N=None):
        assert XtX.shape[0] == M
        self.s_beta=s_beta
        self.s_Xbeta=s_Xbeta
        self.XtX=XtX
        self.M=M
        self.N=N
        batch_shape = (jnp.shape(XtX)[0],1)
        event_shape = ()
        super().__init__(
            batch_shape=batch_shape,
            event_shape=event_shape,
            validate_args=None
        )
    
    def sample(self, key, sample_shape=()):
        raise NotImplementedError
    
    def log_prob(self, value):
        assert value.shape[0] == self.M
        btb = jnp.dot(value.T,value)
        btXtXb = jnp.dot(value.T,jnp.dot(self.XtX, value))
        a = (self.s_Xbeta**2-self.s_beta**2)/self.s_beta**4
        N=self.N; M=self.M; s_beta=self.s_beta
        trace_terms = M*btb/s_beta**2 - a*btXtXb/(1+a*btb)
        log_dets = N*jnp.log(1 + a*btb) + M*jnp.log(s_beta**2)-M*jnp.log(M)
        normalizing = (M+1)*N*jnp.log(2*jnp.pi)
        return -.5 * (trace_terms + log_dets  + normalizing)       

def npy_model(y=None, X=None, XtX=None, M=None, N=None):
    s_beta = numpyro.sample('s_beta', dist.HalfCauchy(1))
    s_Xbeta = numpyro.sample('s_Xbeta', dist.HalfCauchy(1))
    s_e = numpyro.sample('s_e', dist.HalfCauchy(1))
    v_beta = numpyro.deterministic('v_beta', s_beta**2)
    v_Xbeta= numpyro.deterministic('v_Xbeta', s_Xbeta**2)
    v_e = numpyro.deterministic('v_e', s_e**2)
    beta = numpyro.sample('beta', joint_X_b(s_beta, s_Xbeta, XtX, M, N))
    Xb = jnp.dot(X, beta) 
    numpyro.sample('y', dist.Normal(Xb,jnp.ones(N)*s_e),obs=y)


nuts_kernel = NUTS(npy_model)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
rng_key = random.PRNGKey(0)

mcmc.run(rng_key, y=y-np.mean(y), X=X, XtX = X.T @ X, M=X.shape[1], N=X.shape[0])

This yields severe overestimates of \sigma^2_\beta and \sigma^2_{X\beta} (e.g., >1 vs ~.5) and underestimates \sigma^2_e by a factor of 2. Further, the number of effective samples of \sigma^2_\beta and \sigma^2_{X\beta} hover around 10 - 30 after 1000 MCMC iterations. I don’t have these problems with my Stan implementation so I suspect I’m doing something wrong here.

Is there something flawed about my understanding of how to implement a custom distribution in NumPyro? Or is Stan just doing some automatic tuning that makes a huge difference? Thanks.

I believe this simply came down to an error with my use of the built in dist.Normal.
Substituting numpyro.sample('y', dist.Normal(0.,s_e),obs=y-Xb) for numpyro.sample('y', dist.Normal(Xb,jnp.ones(N)*s_e),obs=y) appears to have been the solution.

for future reference you could have used a factor statement, which is the numpyro equivalent of directly adding to the accumulating log density in stan, and is probably simpler

1 Like

Good to know, thanks