I’m still pretty new to NumPyro, so forgive me if this question has already been answered; I haven’t been able to find it. I’m working on a hierarchical model with multiple levels. I’ve been using nested plates to build it, but nested plates generate parameters for every pairwise combination of levels when I only need a subset of those combinations.

For example, suppose the hierarchy is defined by categories and subcategories, and each subcategory maps to exactly one category. Nested plates will still generate parameters for every pairwise combination of subcategory and category, even for pairs that don’t exist in the data. How can I avoid this?

```
import numpy as np

categories = np.array(['a', 'b', 'c']).repeat(10)
subcategories = np.concatenate((np.array([1,2]).repeat(5),np.array([3,4]).repeat(5),np.array([5,6]).repeat(5)))
print(repr(categories))
```

array(['a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'b', 'b', 'b',
'b', 'b', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c',
'c', 'c', 'c', 'c'], dtype='<U1')

```
print(repr(subcategories))
```

array([1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5,
5, 5, 5, 6, 6, 6, 6, 6])

```
import numpy as np
import numpyro
import numpyro.distributions as dist

def model(cat=None, subcat=None, data=None):
    mu = numpyro.sample("mu", dist.Normal(0, 1))
    with numpyro.plate("cat", len(np.unique(cat))):
        mu_cat = numpyro.sample("mu_cat", dist.Normal(mu, 1))
        # Nesting here allocates one mu_subcat per (subcategory, category) pair
        with numpyro.plate("subcat", len(np.unique(subcat))):
            mu_subcat = numpyro.sample("mu_subcat", dist.Normal(mu_cat, 1))
            with numpyro.plate("data", len(data)):
                # rest of model here ....
                pass
```
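
To put numbers on this for the toy data above, here is a quick count in plain NumPy. It is only a sketch of the mismatch, not NumPyro code, and the names `n_allocated` and `n_needed` are made up for illustration.

```python
import numpy as np

categories = np.array(['a', 'b', 'c']).repeat(10)
subcategories = np.concatenate([
    np.array([1, 2]).repeat(5),
    np.array([3, 4]).repeat(5),
    np.array([5, 6]).repeat(5),
])

n_cat = len(np.unique(categories))        # number of distinct categories
n_subcat = len(np.unique(subcategories))  # number of distinct subcategories

# Nested plates broadcast over both plate dimensions,
# so one parameter is created per (category, subcategory) combination.
n_allocated = n_cat * n_subcat

# (category, subcategory) pairs that actually occur in the data.
observed_pairs = set(zip(categories.tolist(), subcategories.tolist()))
n_needed = len(observed_pairs)

print(n_allocated, n_needed)  # 18 6
```

So even in this tiny example two thirds of the subcategory-level parameters correspond to pairs that never occur.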

The problem is that when the dimensions become large (which they are in my real problem), the parameters for these unused combinations quickly exhaust GPU memory. Any tips for dealing with this?

Thank you!