How do I set up nested plates whose depth depends on the input shape?

Hello devs

I have nested plates with the following structure:

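with numpyro.plate("A1", features[0].shape[0], dim=-1):
    # site renamed: NumPyro requires every sample site to have a unique name
    a_hyperprior = numpyro.sample("a_hyperprior", dist.HalfNormal(2))
    # ... plates A2 through A(n-1) nested here ...
    with numpyro.plate("An", features[n-1].shape[0], dim=-n):
        a = numpyro.sample("a", dist.HalfNormal(a_hyperprior))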

where n is inferred from the model's input, features, which is a 2D NumPy array, with n = features.shape[0].

Is it possible to write code that sets up the above nested structure depending on n?

Additionally, how would I go about accessing the variable a? Right now, with n = 2 fixed, I access it via a[features[1], features[0]], which in general translates to a[features[n-1], features[n-2], …, features[0]].

You can use plate_stack for this.

how would I go about accessing the variable a?

I’m not sure I interpreted your question correctly, but I think you can loop over the features to build the index.
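A minimal sketch of that loop, assuming features is the 2D integer array from the question and a has one axis per row of features (axis 0 for features[n-1], the last axis for features[0]). gather is a hypothetical helper name, and the arrays are made up for illustration. The same tuple-of-arrays indexing should also work on the JAX array returned by numpyro.sample.

```python
import numpy as np


def gather(a, features):
    # Hypothetical helper: loop over the rows of features from last to first,
    # so the result equals a[features[n-1], ..., features[0]].
    n = features.shape[0]
    index = tuple(features[i] for i in reversed(range(n)))
    return a[index]


a = np.arange(25).reshape(5, 5)
features = np.array([[0, 1, 4],
                     [2, 3, 0]])  # n = 2
result = gather(a, features)
# Matches the hand-written n = 2 indexing from the question
assert np.array_equal(result, a[features[1], features[0]])
```

Since the index is built from features.shape[0], the same call covers any n without writing out each features[i] by hand.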