Distributed Computation for Numpyro SVI

Dear Numpyro,

I am looking to solve a hierarchical modelling problem using SVI, in which the conditional independence structure of the model is exploited to train it in a distributed way.
For example, consider the simple model:

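import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(x, y):
    # Hierarchical (global) parameters shared by all datasets
    beta_mu = numpyro.sample("beta_mu", dist.Normal())
    beta_log_sigma = numpyro.sample("beta_log_sigma", dist.Normal())
    beta_sigma = jnp.exp(beta_log_sigma)

    # One coefficient per dataset (column of x), non-centred parameterisation
    with numpyro.plate("dataset_plate", size=x.shape[1]):
        betas_std = numpyro.sample("betas_std", dist.Normal())
        betas = betas_std * beta_sigma + beta_mu
        numpyro.sample("pred", dist.Normal(x * betas), obs=y)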

In my situation, the dataset_plate is very large and the sub-models it contains are very heterogeneous, so I would like to leverage the conditional independence between elements of the dataset_plate (given the hierarchical parameters) to make model training feasible.

Is there a way in Numpyro to do something like:

  1. Construct a model of the hierarchical parameters
  2. Conditional on the hierarchical parameters, construct a model of the data-level parameters
  3. Optimise by alternating between updating 1) given the current parameters of 2), and updating 2) given the current parameters of 1)?

Thanks