Pyro/NumPyro - SVI - how to create a variable number of parameters?

The SVI tutorial (SVI Part I: An Introduction to Stochastic Variational Inference in Pyro, Pyro Tutorials 1.8.4 documentation) shows how to create exactly two parameters:

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def guide(data):
    # register the two variational parameters with Pyro.
    alpha_q = pyro.param("alpha_q", torch.tensor(15.0),
                         constraint=constraints.positive)
    beta_q = pyro.param("beta_q", torch.tensor(15.0),
                        constraint=constraints.positive)
    # sample latent_fairness from the distribution Beta(alpha_q, beta_q)
    pyro.sample("latent_fairness", dist.Beta(alpha_q, beta_q))

How does one create a variable number of parameters? The tutorial and its follow-ups don't cover this simple, common use case.

For instance, suppose I have N coins and I want to estimate each coin's probability of heads, with a Beta guide for each coin. That means 2N variational parameters for the N variational Beta distributions. How do I do this?

I think the solution can be found in another tutorial: Dirichlet Process Mixture Models in Pyro (Pyro Tutorials 1.8.4 documentation).

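The answer is something like the following (from that tutorial's guide, where T is the truncation level, i.e. the number of mixture components, and N is the number of data points):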

import torch, pyro
from torch.distributions import constraints
from pyro.distributions import Dirichlet, MultivariateNormal, Uniform
kappa = pyro.param('kappa', lambda: Uniform(0, 2).sample([T-1]), constraint=constraints.positive)
tau = pyro.param('tau', lambda: MultivariateNormal(torch.zeros(2), 3 * torch.eye(2)).sample([T]))
phi = pyro.param('phi', lambda: Dirichlet(1/T * torch.ones(T)).sample([N]), constraint=constraints.simplex)

I do not know whether the lambdas are necessary.

The second argument to pyro.param is the initial value, a PyTorch tensor (or a callable that returns one), and it can have any shape you like:

pyro.param("my_param", torch.ones(2, 3, 4))