Pyro batched mixture model

Hi! I am trying to implement a mixture of Gaussians with additional batch dimensions. In the example below, I tried to add one group dimension around the mixtures. However, I run into the following error:

Inside pyro.plate(group, 5, dim=-1) invalid shape of pyro.param(AutoDelta.loc, ..., event_dim=0): torch.Size([5, 4])

This indicates a mismatch between the guide and the model. Unfortunately, however much I try, either the AutoDelta guide or the ELBO evaluation of the model breaks because of inconsistent shapes.

I would appreciate some help on how to properly implement this. Thanks!

```python
import torch
from torch import Tensor
from pyro.infer import autoguide as ag, config_enumerate
from pyro import distributions as D
import pyro
from pyro import poutine
import matplotlib.pyplot as plt
import numpy as np

num_groups = 5
Y1 = torch.tensor([0.16, 0.5, 0.3, 0.25, 0.2, 1.5, 1.6, 1.46, 1.7])
Y2 = torch.tensor([0.11, 0.21, 0.15, 0.18, 0.22, 1.9, 1.85, 1.76, 1.72])
Y = torch.cat((Y1, Y2))
Ys = Y.repeat(num_groups, 1) + torch.arange(num_groups)[:,None]
group = torch.cat((torch.zeros(len(Y1), dtype=int), torch.ones(len(Y2), dtype=int))).long()
groups = Y.repeat(num_groups, 1).long()
steps = torch.arange(num_groups).repeat_interleave(Y.size(0)).long()

def init_loc_fn(site, ngroup, ncomp):
    if site["name"] == "weights":
        return torch.ones(ngroup, ncomp) / ncomp
    if site["name"] == "loc":
        return torch.randn(ngroup, ncomp)
    if site["name"] == "scale":
        return torch.ones(ngroup, ncomp)

    raise ValueError(site["name"])

@config_enumerate
def mog(Y: Tensor, group: Tensor, num_groups: int, num_components: int):

    with pyro.plate("group", num_groups):
        weights = pyro.sample("weights", D.Dirichlet(torch.ones(num_components)))
        print(weights.shape)

        with pyro.plate("components", num_components):
            loc = pyro.sample("loc", D.Normal(1, 1.0))
            scale = pyro.sample("scale", D.HalfNormal(1.0))

        print(loc.shape)
        print(scale.shape)

    with pyro.plate("data", Y.size(0)):
        assignment = pyro.sample(
            "assignment",
            D.Categorical(weights[group])
        )
        print(assignment.shape)

        y = pyro.sample("Y", D.Normal(loc[group, assignment], scale[group, assignment]), obs=Y)
        print(y.shape)

pyro.clear_param_store()
num_components = 4
guide = ag.AutoDelta(
    poutine.block(mog, hide=["assignment"]),
    init_loc_fn=lambda site: init_loc_fn(site, num_groups, num_components)
)
adam = pyro.optim.Adam({"lr": 0.05})
elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=2)
svi = pyro.infer.SVI(mog, guide, adam, elbo)
num_steps = 500

elbos = []
for step in range(num_steps):
    loss = svi.step(Ys.ravel(), steps, num_groups, num_components)
    elbos.append(loss)

map_params = guide.median()
plt.plot(elbos)
```

It looks like your plates are in an unintended order: pyro.plate("group") uses dim -1 and pyro.plate("components") uses dim -2, so batch shapes come out as [component, group]. This is because Pyro’s automatic plate dim allocation assigns dim=-1 to the first (outermost) plate, dim=-2 to the second, dim=-3 to the third, and so on. In your model I’d recommend setting the plate dims manually, as you intend:

```python
with pyro.plate("group", num_groups, dim=-2):
    with pyro.plate("components", num_components, dim=-1):
        ...
```

Another trick I find helpful is to derive the shapes of initial tensors from site["fn"].shape() in custom init_loc_fn() implementations. I suspect your manual shape specification may have made debugging harder. You could try this version, which might have made your plate bug surface earlier:

```python
def init_loc_fn(site, ngroup, ncomp):
    shape = site["fn"].shape()  # EDITED
    if site["name"] == "weights":
        return torch.ones(shape) / ncomp
    if site["name"] == "loc":
        return torch.randn(shape)
    if site["name"] == "scale":
        return torch.ones(shape)
    raise ValueError(site["name"])
```

Good luck!

Thank you very much for responding, @fritzo.

You are right that the plate ordering in the OP is wrong. I had already tried reordering them, but was hesitant to go into too much detail about what I had tried previously.

Swapping the plate order as you suggested exposes the issue I alluded to with my comment that AutoDelta and the ELBO give inconsistent shapes. In the AutoDelta pass, weights[group].shape == [90, 4], while in the TraceEnum_ELBO pass, weights[group].shape == [90, 5, 4], which causes Shape mismatch inside plate('data') at site assignment dim -1, 90 vs 5.

The dimensions in the error are confusing to me. I can see that a batched Categorical should have the dimensions [5, 90, 4] (groups, observations, components), so something else has to change here, but I can’t for the life of me figure out what. Any advice?

P.S. The shape = site["shape"] trick sounds great, but there is no “shape” key in site for me on Pyro 1.8.1.

Sorry, I meant shape = site["fn"].shape(). I’ll edit the comment above.