Dear NumPyro,
I am looking to solve a hierarchical modelling problem with SVI, exploiting the conditional independence structure of the model so that it can be trained in a distributed way.
For example, consider the simple model:
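```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def model(x, y):
    # Hierarchical (global) parameters shared by all datasets
    beta_mu = numpyro.sample("beta_mu", dist.Normal())
    beta_log_sigma = numpyro.sample("beta_log_sigma", dist.Normal())
    beta_sigma = jnp.exp(beta_log_sigma)
    # One coefficient per dataset (column of x); datasets are
    # conditionally independent given beta_mu and beta_sigma
    with numpyro.plate("dataset_plate", x.shape[1]):
        betas_std = numpyro.sample("betas_std", dist.Normal())
        betas = betas_std * beta_sigma + beta_mu  # non-centred parameterisation
        numpyro.sample("pred", dist.Normal(x * betas), obs=y)
```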
In my situation, `dataset_plate` is very large and the sub-models it contains are very heterogeneous, so I would like to exploit the conditional independence between the elements of `dataset_plate` to make training feasible.
Is there a way in NumPyro to do something like the following?
1. Construct a model of the hierarchical (global) parameters.
2. Conditional on the hierarchical parameters, construct a model of the dataset-level parameters.
3. Optimise by alternating between updating (1) given the current parameters of (2), and updating (2) given the current parameters of (1).
Thanks