I’m still fairly new to NumPyro, so forgive me if this has already been answered; I haven’t been able to find it. I’m working on a hierarchical model with multiple levels and have been using nested plates to build it. The problem is that nested plates generate parameters for every pairwise combination of levels, when I only need a subset of those combinations.
For example, suppose the hierarchy is defined by categories and subcategories, and each subcategory belongs to exactly one category. Nested plates will still generate a parameter for every (category, subcategory) pair, even for pairs that never occur in the data. How can this be avoided?
```python
>>> import numpy as np
>>> categories = np.array(['a', 'b', 'c']).repeat(10)
>>> subcategories = np.concatenate((np.array([1, 2]).repeat(5), np.array([3, 4]).repeat(5), np.array([5, 6]).repeat(5)))
>>> categories
array(['a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'a', 'b', 'b', 'b',
       'b', 'b', 'b', 'b', 'b', 'b', 'b', 'c', 'c', 'c', 'c', 'c', 'c',
       'c', 'c', 'c', 'c'], dtype='<U1')
>>> subcategories
array([1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 5, 5,
       5, 5, 5, 6, 6, 6, 6, 6])
```
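To make the size difference concrete on this toy data, here is a NumPy-only sketch that encodes the labels as integers, derives each subcategory's parent category, and counts the subcategory-level parameters under the two layouts. The names `cat_idx`, `subcat_idx` and `subcat_to_cat` are hypothetical, and the code assumes (and checks) that every subcategory belongs to exactly one category.

```python
import numpy as np

categories = np.array(['a', 'b', 'c']).repeat(10)
subcategories = np.concatenate((np.array([1, 2]).repeat(5),
                                np.array([3, 4]).repeat(5),
                                np.array([5, 6]).repeat(5)))

# Integer codes 0..K-1 for every observation (np.unique sorts the labels)
cat_levels, cat_idx = np.unique(categories, return_inverse=True)
subcat_levels, subcat_idx = np.unique(subcategories, return_inverse=True)
n_cat, n_subcat = len(cat_levels), len(subcat_levels)

# Parent category of each subcategory (hypothetical helper array).
# Assumes each subcategory belongs to exactly one category; the assert checks it.
subcat_to_cat = np.zeros(n_subcat, dtype=int)
subcat_to_cat[subcat_idx] = cat_idx
assert np.array_equal(subcat_to_cat[subcat_idx], cat_idx)

# Size of the subcategory-level parameter under each layout
n_nested = n_cat * n_subcat  # nested plates: one per (subcategory, category) pair
n_indexed = n_subcat         # one per subcategory that actually exists
print(subcat_to_cat, n_nested, n_indexed)
```

With 3 categories and 6 subcategories the nested layout needs 18 subcategory-level parameters instead of 6; for large hierarchies the gap grows with the product of the level sizes rather than their sum.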
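```python
import numpy as np
import numpyro
import numpyro.distributions as dist

def model(cat=None, subcat=None, data=None):
    mu = numpyro.sample("mu", dist.Normal(0, 1))
    with numpyro.plate("cat", len(np.unique(cat))):
        mu_cat = numpyro.sample("mu_cat", dist.Normal(mu, 1))
        # Nested plate: one parameter per (subcategory, category) pair,
        # so mu_subcat has shape (n_subcat, n_cat)
        with numpyro.plate("subcat", len(np.unique(subcat))):
            mu_subcat = numpyro.sample("mu_subcat", dist.Normal(mu_cat, 1))
    with numpyro.plate("data", len(data)):
        pass  # rest of model here ....
```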
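The problem is that when the dimensions are large (as they are in my problem), the number of subcategory-level parameters grows with the product of the number of categories and subcategories, and I quickly run out of GPU memory. Any tips for dealing with this?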
Thank you!