In my model I have a plate over the mini-batch with `dim=-1` and a plate over the components of a Normal with `dim=-2`:
```python
def model():
    with pyro.plate('mini_batch', size_mini_batch, dim=-1):
        # ... sample other latents using the mini_batch dim=-1 ...
        with pyro.plate('components', n_components, dim=-2):
            d = dist.Normal(loc, scale)  # renamed so it doesn't shadow the pyro.distributions `dist` module
            a = pyro.sample('a', d, obs=obs)
    return a
```
The model produces samples of `a` with shape `[n_components, size_mini_batch]`.
Then in the guide I have:
```python
def guide():
    with pyro.plate('mini_batch', size_mini_batch, dim=-1):
        lam = net(mini_batch)  # mini_batch on dim 0
        lam_reshape = lam.transpose(0, -1)  # move mini_batch to dim -1; shape [2 * n_components, size_mini_batch]
        loc, log_scale = lam_reshape[:n_components], lam_reshape[n_components:2 * n_components]
        scale = torch.exp(log_scale)
        with pyro.plate('components', n_components, dim=-2):
            d = dist.Normal(loc, scale)  # same shadowing fix as in the model
            a = pyro.sample('a', d, obs=obs)
    return a
```
I’d like to extend the `Normal` distribution in the guide to a `MultivariateNormal`, where the covariance matrix is no longer diagonal.
Usually in pure PyTorch I’d put the batch dimension on the 0th dimension and then sample like so: `MultivariateNormal(mv_loc, mv_cov)`, with a `mv_loc.shape` of `[size_mini_batch, n_components]` and a `mv_cov.shape` of `[size_mini_batch, n_components, n_components]`. But the shape of the sample (`[size_mini_batch, n_components]`) doesn’t match what’s in the model (`[n_components, size_mini_batch]`).
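To make the mismatch concrete, here’s a minimal shape check of the pure-PyTorch version (placeholder sizes, identity covariance just for illustration). The crux seems to be that with `MultivariateNormal`, `n_components` moves from `batch_shape` into `event_shape`:

```python
import torch
from torch.distributions import MultivariateNormal

size_mini_batch, n_components = 5, 3  # placeholder sizes

mv_loc = torch.zeros(size_mini_batch, n_components)
# One [n_components, n_components] covariance per batch element (identity for illustration):
mv_cov = torch.eye(n_components).expand(size_mini_batch, n_components, n_components)

mvn = MultivariateNormal(mv_loc, covariance_matrix=mv_cov)
print(mvn.batch_shape)     # torch.Size([5])    -> size_mini_batch
print(mvn.event_shape)     # torch.Size([3])    -> n_components, now an event dim
print(mvn.sample().shape)  # torch.Size([5, 3]) =  [size_mini_batch, n_components]
# ...whereas `a` in the model has shape [n_components, size_mini_batch].
```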
What are my options here?
- What shape should `mv_loc` and `mv_cov` have so that it parallelizes over `dim=-1`? Is this possible?
- I’m fairly locked into the mini_batch `dim=-1` with how my model code is vectorized (lots of it not shown here). However, was this a bad design choice? Can I just step out of the mini_batch plate for sampling `a`, use dim 0 for it, and keep everything else the same? (The sketch below shows roughly what I mean.)
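To make that second bullet concrete, here’s roughly what I’m imagining, with dummy parameters standing in for my `net` outputs (a sketch only, not tested against the rest of my model):

```python
import torch
import pyro
import pyro.distributions as dist

size_mini_batch, n_components = 5, 3  # placeholder sizes

def guide_sketch():
    # Dummy parameters standing in for the outputs of `net`:
    mv_loc = torch.zeros(size_mini_batch, n_components)
    mv_scale_tril = torch.eye(n_components).expand(
        size_mini_batch, n_components, n_components
    )
    # n_components is now part of event_shape, so only the mini_batch plate remains.
    # pyro.plate dims are negative and count batch dims only, so dim=-1 here still
    # lands the mini-batch on tensor dim 0 of `a`:
    with pyro.plate('mini_batch', size_mini_batch, dim=-1):
        a = pyro.sample('a', dist.MultivariateNormal(mv_loc, scale_tril=mv_scale_tril))
    return a  # shape [size_mini_batch, n_components], transposed vs. my current model
```

If I read the shape semantics right, the `components` plate disappears entirely here, and `a` comes out transposed relative to my current model, which is what prompts the question.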