Hi! I am trying to implement a mixture of Gaussians inside additional batch dimensions. In the below example I tried to add one group
dimension around the mixtures. However, I run into issues
Inside pyro.plate(group, 5, dim=-1) invalid shape of pyro.param(AutoDelta.loc, ..., event_dim=0): torch.Size([5, 4])
indicating a mismatch between the guide and the model. Unfortunately, however much I try, either the AutoDelta
or the elbo
evaluations of the model breaks due to inconsistent shapes.
I would appreciate some help on how to properly implement this. Thanks!
import torch
from torch import Tensor
from pyro.infer import autoguide as ag, config_enumerate
from pyro import distributions as D
import pyro
from pyro import poutine
import matplotlib.pyplot as plt
import numpy as np
num_groups = 5
Y1 = torch.tensor([0.16, 0.5, 0.3, 0.25, 0.2, 1.5, 1.6, 1.46, 1.7])
Y2 = torch.tensor([0.11, 0.21, 0.15, 0.18, 0.22, 1.9, 1.85, 1.76, 1.72])
Y = torch.cat((Y1, Y2))
Ys = Y.repeat(num_groups, 1) + torch.arange(num_groups)[:,None]
group = torch.cat((torch.zeros(len(Y1), dtype=int), torch.ones(len(Y2), dtype=int))).long()
groups = Y.repeat(num_groups, 1).long()
steps = torch.arange(num_groups).repeat_interleave(Y.size(0)).long()
def init_loc_fn(site, ngroup, ncomp):
if site["name"] == "weights":
return torch.ones(ngroup, ncomp) / ncomp
if site["name"] == "loc":
return torch.randn(ngroup, ncomp)
if site["name"] == "scale":
return torch.ones(ngroup, ncomp)
raise ValueError(site["name"])
@config_enumerate
def mog(Y: Tensor, group: Tensor, num_groups: int, num_components: int):
with pyro.plate("group", num_groups):
weights = pyro.sample("weights", D.Dirichlet(torch.ones(num_components)))
print(weights.shape)
with pyro.plate("components", num_components):
loc = pyro.sample("loc", D.Normal(1, 1.0))
scale = pyro.sample("scale", D.HalfNormal(1.0))
print(loc.shape)
print(scale.shape)
with pyro.plate("data", Y.size(0)):
assignment = pyro.sample(
"assignment",
D.Categorical(weights[group])
)
print(assignment.shape)
y = pyro.sample("Y", D.Normal(loc[group, assignment], scale[group, assignment]), obs=Y)
print(y.shape)
pyro.clear_param_store()
num_components = 4
guide = ag.AutoDelta(
poutine.block(mog, hide=["assignment"]),
init_loc_fn=lambda site: init_loc_fn(site, num_groups, num_components)
)
adam = pyro.optim.Adam({"lr": 0.05})
elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=2)
svi = pyro.infer.SVI(mog, guide, adam, elbo)
num_steps = 500
elbos = []
for step in range(num_steps):
loss = svi.step(Ys.ravel(), steps, num_groups, num_components)
elbos.append(loss)
map_params = guide.median()
plt.plot(elbos)