Matching sample shapes in model and guide: ProjectedNormal vs MixtureSameFamily

Re pyro.infer.SVI, in my model I’m doing

import torch
import pyro
import pyro.distributions as dist

def model(...):
    with pyro.plate('mini_batch', batch_size, dim=-1):
        conc = torch.tensor(1.)
        loc = torch.tensor([1., 0., 0., 0.])
        unit_vectors = pyro.sample('result_encoding', dist.ProjectedNormal(conc * loc))  # shape (batch_size, 4)
        result = conversion_func_a(unit_vectors)

The shape of unit_vectors does not have batch_size on dim=-1 of the full tensor; it is (batch_size, 4).
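For reference, this is because ProjectedNormal treats the trailing dimension of size 4 as its event dimension (a minimal check, outside any plate):

import torch
import pyro.distributions as dist

d = dist.ProjectedNormal(torch.tensor([1., 0., 0., 0.]))
print(d.batch_shape, d.event_shape)  # torch.Size([]) torch.Size([4])
# Inside pyro.plate('mini_batch', batch_size, dim=-1) the batch shape is
# expanded to (batch_size,), so samples have shape (batch_size, 4):
# batch_size sits on dim=-1 of the *batch* shape, to the left of the event dim.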

But this works out fine, because in the guide I'm also using a mixture of ProjectedNormals, and its return shape is the same: (batch_size, 4).
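For context, that guide looks roughly like this (a sketch; K, mix_logits, and locs are hypothetical stand-ins for my actual variational parameters):

def guide(...):
    K = 3  # hypothetical number of mixture components
    mix_logits = pyro.param('mix_logits', torch.zeros(K))
    locs = pyro.param('locs', torch.randn(K, 4))
    with pyro.plate('mini_batch', batch_size, dim=-1):
        mix = dist.Categorical(logits=mix_logits)
        comp = dist.ProjectedNormal(locs)  # batch_shape (K,), event_shape (4,)
        pyro.sample('result_encoding', dist.MixtureSameFamily(mix, comp))
        # sample shape: (batch_size, 4), matching the model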

However, I'm trying to have an alternative encoding for result, where the distribution in the guide comes from a Gaussian mixture model (GMM).

For the GMM, my model has

def model(...):
    with pyro.plate('mini_batch', batch_size, dim=-1):
        with pyro.plate('d6', 6, dim=-2):
            d6_dist = dist.Normal(0., 1.)
            d6 = pyro.sample('result_encoding', d6_dist)  # shape (6, batch_size)
            result = conversion_func_b(d6)

Now notice that batch_size is indeed on dim=-1.

I'm having a hard time matching this shape in the guide using a Gaussian mixture model built with torch.distributions.MixtureSameFamily (see the torch.distributions docs). The return shape of the sample statement in the guide is (batch_size, 6), which doesn't match the model, and an error is thrown.
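Concretely, the guide GMM looks something like this (again a sketch with hypothetical parameter names):

def guide(...):
    K = 3  # hypothetical number of components
    mix_logits = pyro.param('mix_logits', torch.zeros(K))
    locs = pyro.param('locs', torch.randn(K, 6))
    scales = pyro.param('scales', torch.ones(K, 6),
                        constraint=dist.constraints.positive)
    with pyro.plate('mini_batch', batch_size, dim=-1):
        mix = dist.Categorical(logits=mix_logits)
        comp = dist.Normal(locs, scales).to_event(1)  # batch (K,), event (6,)
        gmm = dist.MixtureSameFamily(mix, comp)       # batch (), event (6,)
        d6 = pyro.sample('result_encoding', gmm)      # shape (batch_size, 6)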

I tried:

  1. Guide: changing the GMM in the guide so that it returns (6, batch_size). I can't get this to work: MixtureSameFamily doesn't have flexible enough options to control how the mixture weights and the component distributions line up across dims; I'm forced into the conventions implied in the source.
  2. Model: getting the return of pyro.sample('result_encoding', d6_dist) to have shape (batch_size, 6) (like the ProjectedNormal). I would have to swap the dims in the plates like so:
def model(...):
    with pyro.plate('mini_batch', batch_size, dim=-2):
        with pyro.plate('d6', 6, dim=-1):
            d6_dist = dist.Normal(0., 1.)
            d6 = pyro.sample('result_encoding', d6_dist)  # shape (batch_size, 6)
            result = conversion_func_b(d6)

(2. continued) But I'd like to avoid this, because I share code in the model between the ProjectedNormal and the d6/Normal encodings, and I want the mini_batch plate in the model to stay on dim=-1.

Any suggestions?

I don't really understand. What is the purpose of this in the model?

That's so the 6 parameters (R6) are each independently Gaussian-sampled. I'm using pytorch3d.transforms.rotation_6d_to_matrix to convert the R6 encoding to a 3x3 rotation matrix, and pytorch3d.transforms.quaternion_to_matrix to convert the unit 4-vector from the ProjectedNormal.
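Roughly, the two conversion functions are (a sketch; note both pytorch3d functions operate on the trailing dimension, which is another reason the (batch_size, 6) layout is the natural one):

from pytorch3d.transforms import quaternion_to_matrix, rotation_6d_to_matrix

def conversion_func_a(unit_vectors):  # (..., 4) unit quaternions
    return quaternion_to_matrix(unit_vectors)  # (..., 3, 3)

def conversion_func_b(d6):  # (..., 6) R6 rotation encoding
    return rotation_6d_to_matrix(d6)  # (..., 3, 3)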

Hi @geoffwoollard, I'd recommend studying the tensor shapes tutorial in detail. In particular, there is more to distribution shapes than the tuple (batch_size, 6): you also need to get the batch | event boundary correct. I believe you'll want, in both the model and the guide, batch_shape = (batch_size,) and event_shape = (6,). To do that, use .to_event(1) rather than pyro.plate("d6", 6, dim=...). The difference between pyro.plate and .to_event() is described in the tensor shapes tutorial.
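A minimal sketch of that change in the model (the MixtureSameFamily guide above then matches, since it also has batch_shape (batch_size,) and event_shape (6,)):

def model(...):
    with pyro.plate('mini_batch', batch_size, dim=-1):
        # 6 iid standard-normal coordinates, declared as one event of shape (6,)
        d6_dist = dist.Normal(torch.zeros(6), torch.ones(6)).to_event(1)
        d6 = pyro.sample('result_encoding', d6_dist)  # shape (batch_size, 6)
        result = conversion_func_b(d6)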