# Difficulty translating Stan model with custom likelihood into NumPyro

I’m have difficulty using Numpyro to apply NUTS to the following generative model:

\theta:=(\sigma_\beta , \sigma_{X\beta}, \sigma_e) \stackrel{iid}{\sim}\mathrm{priors}
\beta\vert\theta \stackrel{iid}{\sim} N(0, m^{-1/2}\sigma_\beta)
X\vert\beta,\theta \stackrel{iid}{\sim} MVN(0_m, I + a_\theta \beta \beta^T)
y\vert X,\beta,\theta \sim MVN(X\beta, \sigma^2_eI_n)

where a_\theta:=(\sigma^2_{X\beta} - \sigma^2_\beta)\sigma^{-4}_\beta.

Though it would be possible to implement this using built-in distributions, this involves putting a likelihood on the matrix-valued random variable X. However, the (unnormalized) joint likelihood of X,\beta\vert\theta can be efficiently computed as
\ell_{X,\beta\vert\theta}\propto -\frac{1}{2} \left(m \beta^T\beta/\sigma_\beta^2 - \frac{a_\theta \beta^T X^T X \beta}{(1+a_\theta\beta^T\beta)} + n\log(1+a\beta)+m\log(\sigma_\beta^2 / m)\right).

I can implement this model and sample from it efficiently using Stan:

functions {
real X_beta_lpdf(vector b, real s_beta, real s_Xbeta, matrix XtX, int M, int N) {
real btb = dot_self(b);               // b.T @ b
real btXtXtb = quad_form(XtX, b);     // b.T @ X.T @ X @ b
real a = (s_Xbeta^2 - s_beta^2) / s_beta^4;
real trace_terms = M*btb/s_beta^2 - a*btXtXtb/(1+a*btb);
real log_dets = N*log(1 + a*btb) + M*log(s_beta^2 / M);
return -.5*(trace_terms + log_dets);
}
}

data {
int<lower=1> N;         // dimension of y
int<lower=1> M;         // dimension of beta
vector[N] y;            // y (outcome variable)
matrix[N,M] X;          // matrix of predictors
matrix[M,M] XtX;        // precomputed X.T @ X
}
parameters {
real<lower=0> s_beta;   // standard deviation of beta*m
real<lower=0> s_Xbeta;  // standard deviation of X beta
real<lower=0> s_e;      // standard deviation of residual variance
vector[M] beta;         // vector of latent effects
}
transformed parameters {
// variance components
real v_beta = s_beta^2;
real v_Xbeta = s_Xbeta^2;
real v_e = s_e^2;
// linear predictor
/* vector[N] Xb = X*beta;  // X @ beta */
}

model {
// priors
s_beta ~ cauchy(0,1);
s_Xbeta ~ cauchy(0,1);
s_e ~ cauchy(0,1);
// joint likelihood on X, beta
target += X_beta_lpdf(beta | s_beta, s_Xbeta, XtX, M, N);
// residual likelihood
/* target += normal_lpdf(y | Xb, s_e); */
target += normal_id_glm_lpdf(y | X, 0, beta, s_e);
}

This model works well and produces accurate estimates of \theta when applied to toy data:

import numpy as np

m=200
n=2000

## true parameter values
v_beta=.5
v_Xbeta=.6
v_e=.5

## generate toy data
a = (v_Xbeta - v_beta) / v_beta**2
beta = np.random.randn(m).reshape((m,1))*np.sqrt(v_beta/m)
X = np.random.multivariate_normal(np.zeros(m), np.eye(m) + a*np.outer(beta,beta), n)
e = np.random.randn(n).reshape((n,1))*np.sqrt(v_e)

y = X @ beta + e

## what we want to estimate:
## realized values of v_beta, v_Xbeta, v_e
m*np.var(beta), np.var(X@beta), np.var(e)

However, my attempts at translating this into NumPyro haven’t gone so well:

import jax
import numpyro
import jax.random as random
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
import jax.numpy as jnp
numpyro.enable_x64()

class joint_X_b(dist.Distribution):
arg_constraints = {
"s_beta": dist.constraints.positive,
"s_Xbeta": dist.constraints.positive,
}
support = dist.constraints.real_vector

def __init__(self, s_beta=None, s_Xbeta=None, XtX=None, M=None, N=None):
assert XtX.shape[0] == M
self.s_beta=s_beta
self.s_Xbeta=s_Xbeta
self.XtX=XtX
self.M=M
self.N=N
batch_shape = (jnp.shape(XtX)[0],1)
event_shape = ()
super().__init__(
batch_shape=batch_shape,
event_shape=event_shape,
validate_args=None
)

def sample(self, key, sample_shape=()):
raise NotImplementedError

def log_prob(self, value):
assert value.shape[0] == self.M
btb = jnp.dot(value.T,value)
btXtXb = jnp.dot(value.T,jnp.dot(self.XtX, value))
a = (self.s_Xbeta**2-self.s_beta**2)/self.s_beta**4
N=self.N; M=self.M; s_beta=self.s_beta
trace_terms = M*btb/s_beta**2 - a*btXtXb/(1+a*btb)
log_dets = N*jnp.log(1 + a*btb) + M*jnp.log(s_beta**2)-M*jnp.log(M)
normalizing = (M+1)*N*jnp.log(2*jnp.pi)
return -.5 * (trace_terms + log_dets  + normalizing)

def npy_model(y=None, X=None, XtX=None, M=None, N=None):
s_beta = numpyro.sample('s_beta', dist.HalfCauchy(1))
s_Xbeta = numpyro.sample('s_Xbeta', dist.HalfCauchy(1))
s_e = numpyro.sample('s_e', dist.HalfCauchy(1))
v_beta = numpyro.deterministic('v_beta', s_beta**2)
v_Xbeta= numpyro.deterministic('v_Xbeta', s_Xbeta**2)
v_e = numpyro.deterministic('v_e', s_e**2)
beta = numpyro.sample('beta', joint_X_b(s_beta, s_Xbeta, XtX, M, N))
Xb = jnp.dot(X, beta)
numpyro.sample('y', dist.Normal(Xb,jnp.ones(N)*s_e),obs=y)

nuts_kernel = NUTS(npy_model)
mcmc = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
rng_key = random.PRNGKey(0)

mcmc.run(rng_key, y=y-np.mean(y), X=X, XtX = X.T @ X, M=X.shape[1], N=X.shape[0])

This yields severe overestimates of \sigma^2_\beta and \sigma^2_{X\beta} (e.g., >1 vs ~.5) and underestimates \sigma^2_e by a factor of 2. Further, the number of effective samples of \sigma^2_\beta and \sigma^2_{X\beta} hover around 10 - 30 after 1000 MCMC iterations. I don’t have these problems with my Stan implementation so I suspect I’m doing something wrong here.

Is there something flawed about my understanding of how to implement a custom distribution in NumPyro? Or is Stan just doing some automatic tuning that makes a huge difference? Thanks.

I believe this simply came down to an error with my use of the built in dist.Normal.
Substituting numpyro.sample('y', dist.Normal(0.,s_e),obs=y-Xb) for numpyro.sample('y', dist.Normal(Xb,jnp.ones(N)*s_e),obs=y) appears to have been the solution.

for future reference you could have used a factor statement, which is the numpyro equivalent of directly adding to the accumulating log density in stan, and is probably simpler

1 Like

Good to know, thanks