Problem implementing a mixture model using enumeration

Hello,

I would like to write a mixture model in Pyro using parallel enumeration, in which one of my plates is not directly represented in the observed data. However, when computing svi.loss, I get an error that I cannot explain: “ValueError: Error while packing tensors”. I think it is related either to the plate that is not represented in the data or to the dimension Pyro wants to use for parallelization.

Below I show the code for a simplified version of the model as well as the lines to create a model/guide and reproduce the error I receive:


import pyro
import torch
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoDelta
from pyro import poutine
import pyro.distributions as dist

def check_model(model, *args):
    pyro.clear_param_store()
    trace = poutine.trace(model).get_trace(*args)
    print(trace.format_shapes())
    return pyro.render_model(model, model_args=args)

@config_enumerate
def full_model(data):
    plate1 = pyro.plate("plate1", 5, dim=-1)
    plate2 = pyro.plate("plate2", data.shape[-2], dim=-2)
    plate3 = pyro.plate("plate3", data.shape[-3], dim=-3)
    
    with plate3:
        A = pyro.sample("A", dist.Normal(torch.tensor([0., 0.2, 0.4, 0.6, 0.8]), 
                                         torch.tensor([1., 0.8, 0.6, 0.4, 0.2])))
        probs = pyro.sample('probs', dist.Dirichlet(torch.tensor([30., 70.])))
        assignment = pyro.sample('assignment', dist.Categorical(probs), infer={"enumerate": "parallel"})           
    
    with plate2:
        B = pyro.sample("B", dist.Uniform(0., 1.))
    
    prediction = torch.where(assignment==1, 
                             (A*B).sum(-1).unsqueeze(-1), 
                             (A*-B).sum(-1).unsqueeze(-1))
    
    with plate3, plate2:
        X = pyro.sample("X", dist.Gamma(torch.exp(prediction),
                                        torch.tensor(1.)), obs=data)

my_data = torch.ones((100, 10000, 1))
check_model(full_model, my_data)
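# 'lrd' (learning-rate decay) is a ClippedAdam option; plain Adam rejects it once the optimizer is built
optim = pyro.optim.ClippedAdam({'lr': 0.5, 'lrd': 0.01, 'betas': (0.80, 0.99)})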
pyro.clear_param_store()
full_guide = AutoDelta(poutine.block(full_model, hide=["assignment"]))
loss_func = pyro.infer.TraceEnum_ELBO(max_plate_nesting=3)
svi = SVI(full_model, full_guide, optim, loss=loss_func)

svi.loss(full_model, full_guide, my_data)

The error I receive is:

ValueError: Error while packing tensors at site 'A':
  Invalid tensor shape.
  Allowed dims: -3
  Actual shape: (100, 1, 5)
  Try adding shape assertions for your model's sample values and distribution parameters.
  Trace Shapes:                
   Param Sites:                
    AutoDelta.A     100     1 5
AutoDelta.probs 100   1     1 2
    AutoDelta.B         10000 1
  Sample Sites:                
         A dist 100   1     5 |
          value 100   1     5 |
       log_prob 100   1     5 |

There seems to be a problem with the shape of A. I am not even sure whether parallel enumeration is possible in this case.

It might be distantly related to the following forum question, although their solution has not worked for me:

I would greatly appreciate any suggestions you can provide. Thank you in advance!

Alex

P.S. Very impressed with Pyro, great work!

I also encountered a similar problem. Any comments?

I think you need to add .to_event(1) to the Normal distribution here:

i.e.:

    with plate3:
        A = pyro.sample("A", dist.Normal(torch.tensor([0., 0.2, 0.4, 0.6, 0.8]), 
                                         torch.tensor([1., 0.8, 0.6, 0.4, 0.2])).to_event(1))

One way to think about it is that every dimension of a distribution needs to be either a batch dimension or an event dimension. with plate3 declares a batch dim (dim=-3) of size 100 and automatically expands your distribution if necessary, as you can see from the shapes in the error message. Here the five-element loc and scale give the Normal a batch shape of (5,), which lands at dim -1; since plate3 is the only active plate, -3 is the only allowed batch dim, hence the packing error. Event dims are declared using the .to_event method on distributions. The “Tensor shapes in Pyro” tutorial has a more in-depth explanation of batch shapes and event shapes in Pyro. Hope this helps.

Thank you for the quick reply! I realize that in creating my toy model above, I made a mistake in communicating my problem. The size of plate1 should actually match the last dimension of variable A, which is not an event dimension; both should have size 5. I’ve corrected this in the original post.

With that, my original problem is still unsolved, as I don’t think I want to declare that dimension of A as an event dimension with to_event. I can get the toy model running by using to_event(1) and then indexing later on, but I am not sure this is good practice, or whether it will translate to more complex models:

@config_enumerate
def full_model(data):
    plate1 = pyro.plate("plate1", 5, dim=-1)
    plate2 = pyro.plate("plate2", data.shape[-2], dim=-2)
    plate3 = pyro.plate("plate3", data.shape[-3], dim=-3)
    
    with plate3:
        A = pyro.sample("A", dist.Normal(torch.tensor([0., 0.2, 0.4, 0.6, 0.8]), 
                                         torch.tensor([1., 0.8, 0.6, 0.4, 0.2])).to_event(1))
        
        probs = pyro.sample('probs', dist.Dirichlet(torch.tensor([30., 70.])))
        assignment = pyro.sample('assignment', dist.Categorical(probs), infer={"enumerate": "parallel"})           
    
    with plate2:
        B = pyro.sample("B", dist.Uniform(0., 1.))
    
    prediction = torch.where(assignment==1, 
                             (A[:, :, 0, :]*B).sum(-1).unsqueeze(-1), 
                             (A[:, :, 0, :]*-B).sum(-1).unsqueeze(-1))
    
    with plate3, plate2:
        X = pyro.sample("X", dist.Gamma(torch.exp(prediction),
                                        torch.tensor(1.)), obs=data)

Thank you for your help and patience!
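A sketch of why the indexing in the workaround above is needed, using random stand-in tensors with the shapes from the post (not actual model samples). With .to_event(1) inside plate3, A has batch shape (100, 1, 1) plus event shape (5,), so multiplying it directly by B of shape (10000, 1) broadcasts to the wrong layout. Indexing from the right with an ellipsis is shown as a hedged alternative to A[:, :, 0, :]; it gives the same result here and does not depend on how many dims sit on the left.

```python
import torch

# Stand-ins with the shapes from the post (random values, not model samples).
# A sampled with .to_event(1) inside plate3 -> batch (100, 1, 1) + event (5,)
A = torch.randn(100, 1, 1, 5)
# B sampled inside plate2 (size 10000, dim=-2) -> shape (10000, 1)
B = torch.rand(10000, 1)

# Without indexing, B aligns with A's last two dims: the wrong layout.
print((A * B).shape)  # torch.Size([100, 1, 10000, 5])

# Drop the size-1 dim just left of the event dim, counting from the right.
A_idx = A[..., 0, :]
print(A_idx.shape)  # torch.Size([100, 1, 5])

# Now the size-5 dim sits at -1 and B broadcasts along dim -2 as intended.
print((A_idx * B).shape)  # torch.Size([100, 10000, 5])

# Same values as the positional indexing used in the post.
assert torch.equal(A_idx, A[:, :, 0, :])
```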

Then I think you should use plate1 when sampling variable A:

    with plate1:
        A = pyro.sample("A", dist.Normal(torch.tensor([0., 0.2, 0.4, 0.6, 0.8]), 
                                         torch.tensor([1., 0.8, 0.6, 0.4, 0.2])))