Nested Plates have too many parameters

I’m still pretty new to Numpyro, so forgive me if this question has already been answered; I haven’t been able to find it. I’m working on a hierarchical model with multiple levels and have been using nested plates to express it, but the nested plates generate parameters for all pairwise combinations of the levels when I only need a subset of those combinations.

For example, suppose your hierarchy is defined by categories and subcategories, and each subcategory maps to exactly one category. Nested plates will still generate parameters for every pairwise combination of subcategory and category, even for pairs that don’t exist in your data. How can we avoid this?

import numpy as np

categories = np.array(['a', 'b', 'c']).repeat(10)
subcategories = np.concatenate(
    (np.array([1, 2]).repeat(5), np.array([3, 4]).repeat(5), np.array([5, 6]).repeat(5))
)
print(categories)

array(['a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'b', 'b', 'b',
       'b', 'b', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c',
       'c', 'c', 'c', 'c'], dtype='<U1')

print(subcategories)

array([1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5,
5, 5, 5, 6, 6, 6, 6, 6])

import numpyro
import numpyro.distributions as dist

def model(cat=None, subcat=None, data=None):
    mu = numpyro.sample("mu", dist.Normal(0, 1))
    with numpyro.plate("cat", len(np.unique(cat))):
        mu_cat = numpyro.sample("mu_cat", dist.Normal(mu, 1))
        # this nested plate allocates a (num_subcats, num_cats) grid of
        # parameters, one for every pairwise combination
        with numpyro.plate("subcat", len(np.unique(subcat))):
            mu_subcat = numpyro.sample("mu_subcat", dist.Normal(mu_cat, 1))

    with numpyro.plate("data", len(data)):
        ...  # rest of model here

The problem is that when the dimensions become large (as they are in my problem), you quickly run out of GPU memory. Any tips for dealing with this?

Thank you!

Welcome to Numpyro!

Can you maybe get away without nested plates? Instead, you could pass category ids and subcategory ids and use regular plates.

Would something like this work?

import jax.numpy as jnp
import numpy as np
import numpyro
import jax

cat = np.array(["a", "a", "b", "b", "b", "b"])
cat_id = jnp.array([0, 0, 1, 1, 1, 1])
subcat = np.array(["a1", "a1", "b1", "b1", "b2", "b2"])
subcat_id = jnp.array([0, 0, 1, 1, 2, 2])

mu_cat = jnp.array([0.0, 1.0])
mu_subcat = jnp.array([0.1, 0.2, 0.3])

y = numpyro.distributions.Normal(
        loc=mu_cat[cat_id] + mu_subcat[subcat_id],
        scale=1.0,
    ).sample(jax.random.PRNGKey(0))

def model(num_cats, num_subcats, N, cat_id, subcat_id, y):

    with numpyro.plate("cat", num_cats):
        mu_cat = numpyro.sample("mu_cat", numpyro.distributions.Normal(0, 1))

    with numpyro.plate("subcat", num_subcats):
        mu_subcat = numpyro.sample("mu_subcat", numpyro.distributions.Normal(0, 1))

    with numpyro.plate("data", N):
        numpyro.sample(
            "y",
            numpyro.distributions.Normal(
                loc=mu_cat[cat_id] + mu_subcat[subcat_id],
                scale=1.0,
            ),
            obs=y,
        )

Thanks for the quick reply! I actually took this approach to begin with, but the model converged to 0 for all of the category-level parameters and relied only on the subcategory information. The problem is that I’m interested in extracting the category-level information as well as the subcategory-level information. Ideally the subcategory distributions should be centered around the means of the category distributions, and I believe this formulation says the two are independent when they really aren’t.

I guess the big idea here is that I’m trying to have the subcategory distributions shrink towards the category distributions, which in turn shrink towards the global distribution.
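
To make the structure I’m after concrete, here is a sketch (the subcat_to_cat index array, which maps each subcategory to its single parent category, is something I would build from my data):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# hypothetical index array for the toy example above:
# subcategories 1..6 map to categories a, a, b, b, c, c
subcat_to_cat = jnp.array([0, 0, 1, 1, 2, 2])

def model(num_cats, subcat_to_cat, subcat_id, y=None):
    # global mean
    mu = numpyro.sample("mu", dist.Normal(0, 1))

    # category means shrink towards the global mean
    with numpyro.plate("cat", num_cats):
        mu_cat = numpyro.sample("mu_cat", dist.Normal(mu, 1))

    # subcategory means shrink towards their parent category's mean;
    # indexing with subcat_to_cat only creates parameters for the
    # (subcat, cat) pairs that actually exist
    with numpyro.plate("subcat", len(subcat_to_cat)):
        mu_subcat = numpyro.sample(
            "mu_subcat", dist.Normal(mu_cat[subcat_to_cat], 1)
        )

    with numpyro.plate("data", len(subcat_id)):
        numpyro.sample("y", dist.Normal(mu_subcat[subcat_id], 1), obs=y)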

I see. The problem you are describing seems to be with the model you want to estimate, not with the estimation technique (using nested vs. unnested plates would make no difference). The model may not be identified.

Consider this. Let mu_subcat := mu_cat + delta_subcat, where delta_subcat is the subcategory’s deviation from the category mean, and assume the data are drawn from a distribution parametrized directly by mu_subcat. If the prior variance of delta_subcat is very large (relative to the prior variance of mu_cat), then the model will tend to attribute almost all variation to the subcategories. Conversely, if the prior variance of delta_subcat is set close to zero, then all subcategories within a given category will look very similar.
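
In NumPyro code the decomposition might look like this (a sketch with fixed, hypothetical prior scales sigma_cat and sigma_subcat; their ratio is exactly what drives the attribution):

import numpyro
import numpyro.distributions as dist

def model(subcat_to_cat, subcat_id, y=None, sigma_cat=1.0, sigma_subcat=1.0):
    num_cats = int(subcat_to_cat.max()) + 1

    with numpyro.plate("cat", num_cats):
        mu_cat = numpyro.sample("mu_cat", dist.Normal(0.0, sigma_cat))

    # deviation of each subcategory from its category mean
    with numpyro.plate("subcat", len(subcat_to_cat)):
        delta = numpyro.sample("delta_subcat", dist.Normal(0.0, sigma_subcat))

    # mu_subcat := mu_cat + delta_subcat
    mu_subcat = numpyro.deterministic("mu_subcat", mu_cat[subcat_to_cat] + delta)

    with numpyro.plate("data", len(subcat_id)):
        numpyro.sample("y", dist.Normal(mu_subcat[subcat_id], 1.0), obs=y)

With sigma_subcat much larger than sigma_cat, the posterior tends to push mu_cat towards zero and load the variation onto delta_subcat, which is consistent with what you observed.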

It’s easy to see why this is a problem if you consider OLS. Could you run a regression of personal income on state fixed effects and county fixed effects? No, because of perfect multicollinearity: the county dummies within each state sum to that state’s dummy. It is possible to run a Bayesian version of this regression (thanks to the priors), but the priors will determine what fraction of the variation is attributed to states vs. counties.
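
To see the multicollinearity concretely, here is a toy check with made-up data showing that stacking state and county dummies produces a rank-deficient design matrix:

import numpy as np

# toy design: 2 states, each with 2 counties, one observation per county
state = np.array([0, 0, 1, 1])
county = np.array([0, 1, 2, 3])

X_state = np.eye(2)[state]            # one-hot state dummies, shape (4, 2)
X_county = np.eye(4)[county]          # one-hot county dummies, shape (4, 4)
X = np.hstack([X_state, X_county])    # shape (4, 6)

# the county dummies within a state sum to that state's dummy,
# so OLS cannot separate the two sets of coefficients
print(np.linalg.matrix_rank(X))       # 4 < 6 columns, i.e. perfect multicollinearity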

Does this make sense?


Yes it does, and thank you for pointing that out.

The results from the posterior predictive make sense, but you’re right: because of multicollinearity the parameters don’t necessarily converge to what I’m expecting.

I’m wondering if it would make sense to first fit the model at the category level to obtain the higher-level effects and then feed those into a subcategory model as priors. That should preserve the interpretability of the effects and give me the proper shrinkage, no? Any problems with that approach?
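
To make sure I’m describing it clearly, here’s a sketch of the two-stage idea, reusing the toy arrays from your example (everything here is made-up illustration, not my real model):

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# toy data: cats a, b; subcats a1, b1, b2
cat_id = jnp.array([0, 0, 1, 1, 1, 1])
subcat_id = jnp.array([0, 0, 1, 1, 2, 2])
subcat_to_cat = jnp.array([0, 1, 1])  # parent category of each subcategory
num_cats = 2
y = dist.Normal(0.0, 1.0).sample(jax.random.PRNGKey(1), (6,))

# stage 1: category-level model only
def cat_model(cat_id, num_cats, y=None):
    mu = numpyro.sample("mu", dist.Normal(0, 1))
    with numpyro.plate("cat", num_cats):
        mu_cat = numpyro.sample("mu_cat", dist.Normal(mu, 1))
    with numpyro.plate("data", len(cat_id)):
        numpyro.sample("y", dist.Normal(mu_cat[cat_id], 1), obs=y)

mcmc = MCMC(NUTS(cat_model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0), cat_id, num_cats, y=y)
post = mcmc.get_samples()
cat_loc = post["mu_cat"].mean(axis=0)   # posterior means
cat_scale = post["mu_cat"].std(axis=0)  # posterior standard deviations

# stage 2: plug the stage-1 posterior in as the subcategory prior
def subcat_model(subcat_to_cat, subcat_id, y=None):
    with numpyro.plate("subcat", len(subcat_to_cat)):
        mu_subcat = numpyro.sample(
            "mu_subcat",
            dist.Normal(cat_loc[subcat_to_cat], cat_scale[subcat_to_cat]),
        )
    with numpyro.plate("data", len(subcat_id)):
        numpyro.sample("y", dist.Normal(mu_subcat[subcat_id], 1), obs=y)

One caveat I can already see: stage 2 only gets a Gaussian summary of the stage-1 posterior rather than the full posterior, and the subcategory data can no longer inform the category-level estimates.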

I think the issue here is that the hyper-priors are getting biased by the redundant (extra) priors, which are not reliable. Ideally, you’d want hyper-priors that are not biased by the extra priors (which you’re not fitting to the observed data anyway).

I believe a plausible fix to this problem is to mask those redundant priors. @martinjankowiak @fehiepsi, could you please check this thread and let us know what you think?
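
For concreteness, here is roughly what masking could look like with numpyro.handlers.mask (a sketch; the valid matrix marking which (subcat, cat) pairs exist is hand-built for the toy example above). Note that the full (num_subcats, num_cats) grid of parameters is still allocated, so this would address the redundant-prior bias rather than the GPU memory usage:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import handlers

# validity mask for the original toy example:
# valid[s, c] is True iff subcategory s belongs to category c
valid = jnp.array([
    [True,  False, False],   # subcat 1 -> cat a
    [True,  False, False],   # subcat 2 -> cat a
    [False, True,  False],   # subcat 3 -> cat b
    [False, True,  False],   # subcat 4 -> cat b
    [False, False, True],    # subcat 5 -> cat c
    [False, False, True],    # subcat 6 -> cat c
])

def model():
    mu = numpyro.sample("mu", dist.Normal(0, 1))
    with numpyro.plate("cat", valid.shape[1]):
        mu_cat = numpyro.sample("mu_cat", dist.Normal(mu, 1))
        with numpyro.plate("subcat", valid.shape[0], dim=-2):
            # the (num_subcats, num_cats) grid is still allocated, but the
            # masked entries contribute zero to the joint log density
            with handlers.mask(mask=valid):
                mu_subcat = numpyro.sample("mu_subcat", dist.Normal(mu_cat, 1))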

I believe that if we can solve this issue, it will be extremely helpful for integrating Bayesian hypothesis testing into Numpyro.