Hello devs
I'd like to set up nested plates with the following structure:
with numpyro.plate("A1", features[0].shape[0], dim=-1):
    a_hyperprior = numpyro.sample("a_hyperprior", dist.HalfNormal(2))
    # ... plates "A2" through "A(n-1)", each nested one level deeper ...
        with numpyro.plate("An", features[n-1].shape[0], dim=-n):
            a = numpyro.sample("a", dist.HalfNormal(a_hyperprior))
where n is inferred from the model's input, features, which is a 2D NumPy array, with n = features.shape[0].
Is it possible to write code that sets up the above nested structure dynamically, depending on n?
Additionally, how would I go about accessing the variable a? Right now, with n = 2 fixed, I access it via a[features[1], features[0]], which for general n translates to a[features[n-1], features[n-2], …, features[0]].
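A minimal sketch of the general indexing, assuming (as in the n = 2 case) that each row features[i] holds integer indices for the axis created by plate "A{i+1}": reversing the rows and passing them as a tuple gives a[features[n-1], ..., features[0]] via NumPy advanced indexing. The helper name index_a is hypothetical. The same expression should also work on the JAX array sampled for a, since JAX follows NumPy's advanced-indexing rules.

```python
import numpy as np


def index_a(a, features):
    # Equivalent to a[features[n-1], ..., features[0]]: reversing the rows
    # puts features[n-1] on the leftmost axis (dim -n) and features[0] on
    # the rightmost axis (dim -1), matching the plate layout.
    return a[tuple(features[::-1])]


# n = 2 example: axis -1 has size 4 (plate A1), axis -2 has size 3 (plate A2).
a = np.arange(12).reshape(3, 4)
features = np.array([[0, 3, 1],   # indices along axis -1
                     [2, 0, 1]])  # indices along axis -2
print(index_a(a, features))  # same as a[features[1], features[0]] -> [8 3 5]
```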