Hi, I am new to pyro and would like to create a Multivariate Poisson Mixture Model using pyro.
Based on the GMM tutorial, I attempted this:
import os
from collections import defaultdict
import torch
import numpy as np
import scipy.stats
from torch.distributions import constraints
import matplotlib.pyplot as plt
%matplotlib inline
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete, Trace_ELBO
from pyro.infer import Predictive
from pyro.ops.indexing import Vindex
smoke_test = "CI" in os.environ
assert pyro.__version__.startswith('1.9.1')
K = 4 # Fixed number of components.
n_obs = 100
n_features = 5
real_assignment = torch.randint(high=K, size=(100,1))
real_means = torch.tensor([1,5,10,20])
data = torch.tensor(scipy.stats.poisson.rvs(mu=real_means[real_assignment], size=(n_obs,n_features))).float()
@config_enumerate
def model(data):
weights = pyro.sample("weights", dist.Dirichlet(0.5 * torch.ones(K)))
with pyro.plate("components", K):
means = pyro.sample("means", dist.Uniform(1.0, 20.0))
with pyro.plate("data", len(data)):
assignment = pyro.sample("assignment", dist.Categorical(weights))
expanded_means = means[assignment][:,None].expand(-1,data.shape[1])
pyro.sample("obs", dist.Poisson(expanded_means).to_event(1), obs=data)
optim = pyro.optim.Adam({"lr": 0.1, "betas": [0.8, 0.99]})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
global_guide = AutoDelta(
poutine.block(model, expose=["weights", "means"]))
svi = SVI(model, global_guide, optim, loss=elbo)
losses = []
for i in range(200):
loss = svi.step(data)
losses.append(loss)
print("." if i % 100 else "\n", end="")
print(global_guide(data))
In this model on event is 5 draws of the same Poisson distribution.
I want the latent categorical variables to be marginalised out and if I understood correctly this can be done using TraceEnum_ELBO
.
Get this error when running the code :
RuntimeError: expand(torch.FloatTensor{[4, 1, 1]}, size=[-1, 5]): the number of sizes provided (2) must be greater or equal to the number of dimensions in the tensor (3)
Trace Shapes:
Param Sites:
Sample Sites:
weights dist | 4
value | 4
components dist |
value 4 |
means dist 4 |
value 4 |
data dist |
value 100 |
assignment dist 100 |
value 4 1 |
Can someone help me please ?