Re `pyro.infer.SVI`: in my model I'm doing
```python
def model(...):
    with pyro.plate('mini_batch', batch_size, dim=-1):
        conc = torch.tensor(1.)
        loc = torch.tensor([1., 0., 0., 0.])
        unit_vectors = pyro.sample('result_encoding', dist.ProjectedNormal(conc * loc))  # shape (batch_size, 4)
        result = conversion_func_a(unit_vectors)
```
The shape of `unit_vectors` does not have `batch_size` on `dim=-1`; it is `(batch_size, 4)`. But this works out fine, because in the guide I'm also using a mixture of `ProjectedNormal`s, and the return shape is the same: `(batch_size, 4)`.
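For reference, here's a torch-only shape sketch of why that still lines up with the plate: a distribution with a vector event shape carries its event on the last dim, so `dim=-1` of the *batch* shape is `batch_size` even though the sampled tensor's last dim is 4. I'm using `MultivariateNormal` as a stand-in for `ProjectedNormal` here, since both have `event_shape (4,)`:

```python
import torch
from torch.distributions import MultivariateNormal

batch_size = 8

# Stand-in for ProjectedNormal: event_shape (4,), batched over batch_size.
d = MultivariateNormal(torch.zeros(batch_size, 4), torch.eye(4))

# batch_shape is (batch_size,), event_shape is (4,)
print(d.batch_shape, d.event_shape)

# The sample tensor is batch_shape + event_shape = (batch_size, 4):
x = d.sample()
print(x.shape)
```

So the plate's `dim=-1` indexes into the batch shape, and the trailing 4 is the event dim, not a batch dim.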
However, I'm trying to have an alternative encoding for `result`, where the distribution in the guide comes from a Gaussian mixture model (GMM).
For the GMM, my model has

```python
def model(...):
    with pyro.plate('mini_batch', batch_size, dim=-1):
        with pyro.plate('d6', 6, dim=-2):
            d6_dist = dist.Normal(0., 1.)
            d6 = pyro.sample('result_encoding', d6_dist)  # shape (6, batch_size)
            result = conversion_func_b(d6)
```
Now notice that `batch_size` is indeed on `dim=-1`. I'm having a hard time matching this shape in the guide using a Gaussian mixture model built with `MixtureSameFamily` (Probability distributions - torch.distributions — PyTorch 1.13 documentation). The return shape of the sample statement in the guide is `(batch_size, 6)`, which doesn't match the model, and an error is thrown.
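To make the mismatch concrete, here's a torch-only sketch of the shape convention I'm running into: a GMM whose components have `event_shape (6,)` puts the 6 on the last dim of the sample, giving `(batch_size, 6)` (the component means/scales below are placeholders, not my actual parameters):

```python
import torch
from torch.distributions import Categorical, Independent, MixtureSameFamily, Normal

batch_size, K = 8, 3  # K mixture components

# Mixture weights: batch_shape (batch_size,)
mix = Categorical(logits=torch.zeros(batch_size, K))

# Components: Normal over a 6-dim event, batch_shape (batch_size, K), event_shape (6,)
comp = Independent(
    Normal(torch.zeros(batch_size, K, 6), torch.ones(batch_size, K, 6)), 1
)

gmm = MixtureSameFamily(mix, comp)
x = gmm.sample()
print(x.shape)  # the 6 lands on the last dim, not dim=-2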
I tried:

1. Guide: changing the GMM in the guide so that it returns `(6, batch_size)`. I can't get this to work, as `MixtureSameFamily` doesn't have flexible enough options for specifying how the mixture weights and the component distributions should line up; I'm forced into the conventions implied in the source.
2. Model: getting the return of `pyro.sample('result_encoding', d6_dist)` to be shape `(batch_size, 6)` (like the `ProjectedNormal`). I would have to swap the dims in the plates like so:
```python
def model(...):
    with pyro.plate('mini_batch', batch_size, dim=-2):
        with pyro.plate('d6', 6, dim=-1):
            d6_dist = dist.Normal(0., 1.)
            d6 = pyro.sample('result_encoding', d6_dist)  # shape (batch_size, 6)
            result = conversion_func_b(d6)
```
(2., continued) But I'd like to avoid this, because I'm sharing model code between the `ProjectedNormal` and the d6/`Normal` encodings, and I want the `mini_batch` plate in the model to stay on `dim=-1`.
Any suggestions?