Pairwise dense mass matrix

I have a high-dimensional Bayesian NumPyro model with two sets of parameters I'm trying to estimate, "x" and "z". Each is a vector of the same length, ~10,000. I'm fitting with MCMC using NUTS.

I know that “x” and “z” are highly correlated, but only for matching pairs. That is, I know that x_1 is correlated with z_1, and x_2 with z_2, but not x_1 with z_2 or x_2 with z_1. Put another way, I know x_i is correlated with z_j for i==j only.

Because the dimensionality is so high, simply specifying dense_mass=[("x", "z")] crashes due to insufficient memory.
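For a rough sense of scale, here is a back-of-the-envelope count of what a joint dense block over "x" and "z" needs compared with storing only the matching pairs. It assumes float32 (the JAX default) and counts a single copy of the matrix; adaptation typically holds more than one array of this size, so the real footprint is likely several times larger. The helper name mass_matrix_bytes is made up for this illustration.

```python
def mass_matrix_bytes(n_pairs, bytes_per_float=4):
    """Bytes for one dense (2n x 2n) matrix vs. n independent 2x2 blocks."""
    dim = 2 * n_pairs                              # x and z stacked together
    dense = dim * dim * bytes_per_float            # full joint mass matrix
    pairwise = n_pairs * 2 * 2 * bytes_per_float   # one 2x2 block per (x_i, z_i)
    return dense, pairwise


dense, pairwise = mass_matrix_bytes(10_000)
print(f"dense: {dense / 1e9:.1f} GB, pairwise: {pairwise / 1e3:.0f} kB")
# prints "dense: 1.6 GB, pairwise: 160 kB"
```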

So my question is, how do I tell NUTS to estimate only the matching pairs from “x” and “z” in the mass matrix?

Here is a basic version of my model:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(a, b, y):
    x = numpyro.sample("x", dist.Normal(0, 1).expand([10000]))
    z = numpyro.sample("z", dist.Normal(0, 1).expand([10000]))
    lambda_ = jnp.exp(x * a + z * b) + 1e-16

Currently, we don't support this feature. I guess you could fit an MVN guide with SVI and then precondition your variables using the covariance matrix that SVI learns. Something like:

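import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints

def model(a, b, y):
    # Sample x and z as a single site: column 0 is x, column 1 is z.
    xz = numpyro.sample("xz", dist.Normal(0, 1).expand([10000, 2]).to_event())
    x, z = xz[:, 0], xz[:, 1]

def guide(a, b, y):
    # One 2-d Gaussian per (x_i, z_i) pair; zeros/identity are just assumed starting values.
    loc = numpyro.param("loc", jnp.zeros((10000, 2)))
    scale_tril = numpyro.param("scale_tril", jnp.broadcast_to(jnp.eye(2), (10000, 2, 2)), constraint=constraints.lower_cholesky)
    # to_event(1) folds the 10000 pairs into the event shape so it matches the model site.
    xz = numpyro.sample("xz", dist.MultivariateNormal(loc, scale_tril=scale_tril).to_event(1))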