In my model I have a plate over batch with dim=-1 and a plate over components of a normal with dim=-2:
```python
import pyro
import pyro.distributions as dist

def model():
    with pyro.plate('mini_batch', size_mini_batch, dim=-1):
        # sample other latents using the mini_batch dim=-1
        with pyro.plate('components', n_components, dim=-2):
            d = dist.Normal(loc, scale)  # renamed from `dist` to avoid shadowing the distributions module
            a = pyro.sample('a', d, obs=obs)
    return a
```
The model produces samples of a with shape `[n_components, size_mini_batch]`.
Then in the guide I have:
```python
with pyro.plate('mini_batch', size_mini_batch, dim=-1):
    lam = net(mini_batch)               # mini_batch on dim 0
    lam_reshape = lam.transpose(0, -1)  # switch mini_batch dim to -1; shape [2*n_components, size_mini_batch]
    loc, log_scale = lam_reshape[:n_components], lam_reshape[n_components:2 * n_components]
    scale = torch.exp(log_scale)
    with pyro.plate('components', n_components, dim=-2):
        d = dist.Normal(loc, scale)
        a = pyro.sample("a", d, obs=obs)
return a
```
I’d like to extend the `Normal` distribution in the guide to a `MultivariateNormal`, where the covariance matrix is no longer diagonal. Usually in pure pytorch I’d put the batch dimension on the 0th dimension and then sample like so: `MultivariateNormal(mv_loc, mv_cov)`, with `mv_loc` of shape `[size_mini_batch, n_components]` and `mv_cov` of shape `[size_mini_batch, n_components, n_components]`. But the shape of the sample (`[size_mini_batch, n_components]`) doesn’t match what’s in the model (`[n_components, size_mini_batch]`).
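For concreteness, here is a minimal shape check of the pure-pytorch version I mean (the sizes are placeholders, just to show the shapes):

```python
import torch
from torch.distributions import MultivariateNormal

size_mini_batch, n_components = 8, 3   # placeholder sizes for the shape check
mv_loc = torch.zeros(size_mini_batch, n_components)
mv_cov = torch.eye(n_components).expand(size_mini_batch, n_components, n_components)

mvn = MultivariateNormal(mv_loc, covariance_matrix=mv_cov)
print(mvn.batch_shape, mvn.event_shape)  # torch.Size([8]) torch.Size([3])
print(mvn.sample().shape)                # torch.Size([8, 3]) -> [size_mini_batch, n_components],
                                         # not [n_components, size_mini_batch] as in the model
```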
What are my options here?
- What shape should `mv_cov` have so that it parallelizes over `dim=-1`? Is this possible?
- I’m fairly locked into the mini_batch `dim=-1` with how my model code is vectorized (a lot is not shown here). However, was this a bad design choice? Can I just step out of the mini_batch plate for sampling `a`, use `dim=0` for it, and keep everything else the same? (A sketch of the kind of thing I mean is below.)
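To make the second bullet concrete, here is roughly the kind of guide I have in mind for `a`, with mini_batch ending up on dim 0. This is only a sketch: `build_cov` is a hypothetical helper that turns the net output into a valid covariance matrix, and whether this plate structure is even compatible with the model's plates at site `a` is exactly what I'm asking.

```python
def guide(mini_batch):
    batch_plate = pyro.plate('mini_batch', size_mini_batch, dim=-1)
    with batch_plate:
        pass  # sample the other latents here, exactly as before (mini_batch on dim=-1)

    lam = net(mini_batch)                       # [size_mini_batch, 2*n_components], mini_batch on dim 0
    mv_loc = lam[:, :n_components]              # [size_mini_batch, n_components]
    mv_cov = build_cov(lam[:, n_components:])   # hypothetical helper -> [size_mini_batch, n_components, n_components]

    # With a MultivariateNormal the component dimension is the event dim, so the
    # plate's dim=-1 refers to batch dim -1 and `a` comes out with mini_batch on
    # tensor dim 0, i.e. shape [size_mini_batch, n_components].
    with batch_plate:
        a = pyro.sample('a', dist.MultivariateNormal(mv_loc, covariance_matrix=mv_cov))
    return a
```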