# SVI: Normal in model, MultivariateNormal in guide

In my model I have a plate over batch with dim=-1 and a plate over components of a normal with dim=-2:

```python
def model():
    with pyro.plate('mini_batch', size_mini_batch, dim=-1):
        # sample other latents using mini_batch dim=-1
        with pyro.plate('components', n_components, dim=-2):
            d = dist.Normal(loc, scale)  # avoid shadowing the `dist` module
            a = pyro.sample('a', d, obs=obs)
    return a
```

The model produces samples of `a` with shape `[n_components, size_mini_batch]`.

Then in the guide I have:

```python
def guide():
    with pyro.plate('mini_batch', size_mini_batch, dim=-1):
        lam = net(mini_batch)  # mini_batch on dim 0
        # move mini_batch to dim -1; lam_reshape has shape [2 * n_components, size_mini_batch]
        lam_reshape = lam.transpose(0, -1)
        loc, log_scale = lam_reshape[:n_components], lam_reshape[n_components:2 * n_components]
        scale = torch.exp(log_scale)

        with pyro.plate('components', n_components, dim=-2):
            d = dist.Normal(loc, scale)
            a = pyro.sample("a", d, obs=obs)
    return a
```

I’d like to extend the `Normal` distribution in the guide to a `MultivariateNormal`, where the covariance matrix is no longer diagonal.

Usually in pure pytorch I’d put the batch dimension on the 0th dimension, and then sample like so:

`MultivariateNormal(mv_loc, mv_cov)`, with a `mv_loc.shape` of `[size_mini_batch, n_components]` and a `mv_cov.shape` of `[size_mini_batch, n_components, n_components]`. But the shape of the sample (`[size_mini_batch, n_components]`) doesn't match what's in the model (`[n_components, size_mini_batch]`).
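To make the mismatch concrete, here's a minimal pure-PyTorch sketch (with made-up sizes) showing where `MultivariateNormal` puts its batch and event dimensions, and why the sample comes out as `[size_mini_batch, n_components]`:

```python
import torch
from torch.distributions import MultivariateNormal

size_mini_batch, n_components = 5, 3

mv_loc = torch.zeros(size_mini_batch, n_components)
# a batch of identity covariance matrices
mv_cov = torch.eye(n_components).expand(size_mini_batch, n_components, n_components)

d = MultivariateNormal(mv_loc, covariance_matrix=mv_cov)
print(d.batch_shape)     # torch.Size([5])
print(d.event_shape)     # torch.Size([3])
print(d.sample().shape)  # torch.Size([5, 3]) -- the event dim is always last
```

The event dimension (`n_components`) is pinned to the rightmost position, so it can't land on dim -2 the way the model's `components` plate expects.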

What are my options here?

• What shape should `mv_loc` and `mv_cov` have so that it parallelizes over dim=-1? Is this possible?
• I’m fairly locked into the `mini_batch` dim=-1 with how my model code is vectorized (lots not shown here). However, was this a bad design choice? Can I just step out of the mini_batch plate for sampling `a`, use `dim=0` for it, and keep everything else the same?

Hi @geoffwoollard,
The easiest option might be to change your model to avoid the `components` plate and instead use `.to_event()`. Modifying your current model you could:

```python
def model_v2():
    # shapes as in your original model
    assert loc.shape == (n_components, size_mini_batch)
    assert scale.shape == (n_components, size_mini_batch)

    # now transpose so that components become the event dimension
    with pyro.plate("mini_batch", size_mini_batch, dim=-1):
        a = pyro.sample("a.T", dist.Normal(loc.T, scale.T).to_event(1), obs=obs.T).T
    return a
```

Then you can use a simple multivariate guide:

```python
def guide_v2():
    with pyro.plate("mini_batch", size_mini_batch, dim=-1):
        lam = net(mini_batch)  # mini_batch on dim 0
        assert lam.shape == (size_mini_batch, 2 * n_components)
        loc, log_scale = lam.reshape(size_mini_batch, 2, n_components).unbind(-2)
        # diagonal for now; net can be extended to output a full lower-triangular factor
        scale_tril = log_scale.exp().diag_embed()
        pyro.sample("a.T", dist.MultivariateNormal(loc, scale_tril=scale_tril))
```

Alternatively if you really want to support the “components” plate, you can follow the pattern of Pyro’s `AutoContinuous` guides like `AutoMultivariateNormal`: sample an auxiliary joint random variable with `is_auxiliary=True`, then reshape and sample the user-facing variable from a `Delta` distribution:

```python
def guide():
    with pyro.plate('mini_batch', size_mini_batch, dim=-1):
        lam = net(mini_batch)  # mini_batch on dim 0
        assert lam.shape == (size_mini_batch, 2 * n_components)
        loc, log_scale = lam.reshape(size_mini_batch, 2, n_components).unbind(-2)
        scale_tril = log_scale.exp().diag_embed()
        a_aux = pyro.sample(
            "a_aux",
            dist.MultivariateNormal(loc, scale_tril=scale_tril),
            infer={"is_auxiliary": True},  # <--- declare an auxiliary guide variable
        )
        assert a_aux.shape == (size_mini_batch, n_components)
        a = a_aux.T
        with pyro.plate('components', n_components, dim=-2):
            pyro.sample("a", dist.Delta(a))  # <--- sample from a Delta
```