Multivariate observed - dimensions get mixed up in the posterior predictive?

First of all, thank you for this amazing library!

I have noticed a behavior that I cannot explain.
I’m trying to build a hierarchical Poisson regression, and the PPC I’m getting has the right shape, but the dimensions are mixed up (I can see values from different series mixed together within the same row).

Is this expected behavior, and should I always realign the PPC myself?


Minimal example:
Let’s use a 3-dimensional observed variable, with each series having a very different rate parameter (10, 50, 100):

# generate data
import numpy as np
from scipy import stats

size = (10,)
rate_list = [10, 50, 100]
y_train = np.vstack([stats.poisson.rvs(rate, size=size) for rate in rate_list])
y_train.shape  # gets you (3, 10)

A simple model where we estimate the parameters separately for each series (independent thanks to the “series” plate):

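# simple model
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model_poi(obs=None):
    obs_dim = 10
    series_dim = 3
    # one log-rate per series: alpha has shape (3, 1)
    with numpyro.plate('series', series_dim, dim=-2):
        alpha = numpyro.sample('alpha', dist.Normal(0, 3))
    rate = numpyro.deterministic("rate", jnp.exp(alpha))
    # the time plate expands the (3, 1) Poisson batch to (3, 10)
    with numpyro.plate("time", obs_dim, dim=-1):
        numpyro.sample("obs", dist.Poisson(rate), obs=obs)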

Parameters are estimated with HMC and they are correct. The PPC has the right shape as well:

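# get posterior predictive
from jax import random
from numpyro.infer import Predictive

# posterior_samples is assumed to come from mcmc.get_samples() (1000 flat draws)
predictive = Predictive(model_poi, posterior_samples, return_sites=["obs"], batch_ndims=1)
ppc = predictive(random.PRNGKey(1), obs=None)
ppc['obs'].shape  # gets you correct shape of (1000, 3, 10)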

However, when you look at individual draws, the series are no longer laid out row-wise (along the middle dimension, of size 3). Instead, each row cycles through series1, series2, series3, series1, and so on.

# show one draw
ppc['obs'][0]
DeviceArray([[  6,  68, 110,   6,  43, 112,   2,  63,  99,  12],
             [ 55, 100,  16,  58,  98,  13,  68,  97,  10,  57],
             [107,   9,  56,  91,  10,  59,  85,   6,  54,  92]],            dtype=int64)

My expectation is that each series would be represented by a separate row, like this:
DeviceArray([[  6,   6,   2,  12, …],
             [ 68,  43,  63,  55, …],
             [110, 112,  99, 100, …]], dtype=int64)
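To see how a layout like the draw above can arise, here is a pure-NumPy illustration (not NumPyro's internal code; the names are made up for the example). If the per-series draws are generated in time-major order, shape (10, 3), and then reshaped instead of transposed to (3, 10), every row cycles through series 1, 2, 3, which is the pattern in the draw above.

```python
import numpy as np

rng = np.random.default_rng(0)
rates = np.array([10, 50, 100])
n_time = 10

# Draws generated time-major: shape (n_time, n_series), column j uses rates[j].
draws_time_major = rng.poisson(rates, size=(n_time, rates.size))

# Wrong: reshape keeps memory order, so each row cycles through the series.
interleaved = draws_time_major.reshape(rates.size, n_time)

# Right: transpose so that row j holds the n_time draws of series j.
aligned = draws_time_major.T

print(interleaved.mean(axis=1))  # each row mixes all three rates
print(aligned.mean(axis=1))      # roughly 10, 50, 100
```

Comparing per-row means against the per-row means of y_train is a cheap way to spot this kind of misalignment in a PPC.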

My question: Is this expected behavior, and should I always realign the PPC myself?

Some additional observations:

  • I can get the expected behavior by setting Predictive(…, batch_ndims=2), which gives me the same shape, but with the dimensions aligned correctly. (Per the documentation, I’d have expected a shape like (num_chains x N x ...), but instead I get (num_chains*N x ...). That matches what I see in the implementation; it’s just not what the documentation led me to expect.)
  • This issue goes away when I have covariates that force the right shape on the latent variables, which Predictive() can then read (in dim=-1).

Yes, I think so. The output of Predictive should have shape (num_samples, 3, 10), and ppc['obs'][0] should have shape (3, 10). I don’t see the difference w.r.t. your expectation. Could you please elaborate on this point? Edit: oh, I see your point. The values are not arranged correctly. As far as I can see, this is a bug in Predictive. Could you create a GitHub issue for this? I’ll address it soon.

I’d have expected to get a shape like (num_chains x N x ...) but instead I get (num_chains*N x ...).

If you get the posterior samples grouped by chain with mcmc.get_samples(group_by_chain=True), they will have extra batch dimensions (num_chains x num_samples), and the expected output shape of Predictive(…, batch_ndims=2) will be (num_chains x num_samples x 3 x 10).

So this is a bug in ExpandedDistribution:

import numpyro.distributions as dist
from jax import random
import numpy as np

dist.Poisson(np.array([[1], [10], [100]])).expand([3, 10]).sample(random.PRNGKey(0))

returns

DeviceArray([[  2,  12,  88,   1,  12,  93,   0,  10, 113,   3],
             [  9, 117,   1,  10,  95,   2,   8,  88,   1,  11],
             [ 99,   1,   8,  95,   0,  12, 112,   1,   9, 105]],            dtype=int32)

We will fix the bug and make a patch release shortly.

Wow, you’re amazing! That was super fast - thank you!


Out of curiosity - how does such a model handle the multivariate observed data (when running MCMC)?
I’ve tried to find the answer in the source code, but I get lost in messengers and trees.

I’m asking because I have tried putting an outer “series” plate around the final observed sampling statement (which sits inside the inner “time” plate), since the batch dimensions are independent. However, I haven’t noticed any difference, even for larger, more complex models.

In other words, is there any difference in how this model would be handled:
(I suspect this is the more “correct” version if dim=-2 is independent, but I cannot find any practical difference)

def model_poi_v2(obs=None):
    obs_dim=10
    series_dim=3
    with numpyro.plate('series',series_dim,dim=-2):
        alpha=numpyro.sample('alpha',dist.Normal(0,3))
    rate = numpyro.deterministic("rate",jnp.exp(alpha))
    with numpyro.plate('series',series_dim,dim=-2):
        with numpyro.plate("time", obs_dim, dim=-1):
            numpyro.sample("obs", dist.Poisson(rate), obs=obs)

Yes, you are right that model_poi_v2 is the better version. You can rewrite it as:

def model_poi_v2(obs=None):
    obs_dim=10
    series_dim=3
    with numpyro.plate('series',series_dim,dim=-2):
        alpha=numpyro.sample('alpha',dist.Normal(0,3))
        rate = numpyro.deterministic("rate",jnp.exp(alpha))
        with numpyro.plate("time", obs_dim, dim=-1):
            numpyro.sample("obs", dist.Poisson(rate), obs=obs)

If you make sure that array shapes are passed down correctly and your model does not have discrete latent variables (for those, we require plates to enumerate the discrete latent values correctly), then you don’t need to use plates in MCMC. But using plates is good practice. Personally, I find them very helpful for interpreting the dependencies in complicated models. (FYI, in the master branch you can use render_model to get a graphical version of your model.)

In your case, alpha has shape (3, 1), so dist.Poisson(rate) has batch shape (3, 1). Under the time plate, it is expanded to (3, 10). So you still get the correct shape with model_poi.

edit: the bug is fixed in this PR

That’s a great tip - I haven’t noticed that! Thank you.