Hi everybody, I am relatively new to Pyro and I am trying to fit data from a simple generative model.
Basically, I have N samples, each with K count variables (K = 2 in this example). Some of the K variables differ by a latent integer factor, and each sample also has its own Gamma-distributed scaling factor.
import torch
import pyro
import pyro.distributions as dist

pyro.set_rng_seed(3)
N = 1000
gamma = dist.Gamma(6, 2).sample([N])                   # per-sample scaling factor, shape [N]
groups = torch.ones((2, N))                            # latent integer factor per variable and sample
groups[0, int(N * 0.6):N] = 3                          # last 40% of samples get factor 3 on variable 0
baselines = torch.tensor([30., 50.]).reshape([2, 1])   # observed per-variable baselines
data = dist.Poisson(gamma * groups * baselines).sample()  # counts, shape [2, N]
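As a quick sanity check on the simulated data:

print(data.shape)  # torch.Size([2, 1000]): K = 2 count variables for N = 1000 samples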
So a corresponding model, assuming the baselines are observed, could be:
import pyro
import pyro.distributions as dist
import torch
from pyro.ops.indexing import Vindex
from torch.distributions import constraints

def model(data, baselines):
    weights = pyro.sample('mixture_weights', dist.Dirichlet(torch.ones(2) / 2))
    with pyro.plate('segments', 2):
        with pyro.plate('groups', 2):
            # suppose here I don't know the number of groups and I set an upper bound of 4
            group = pyro.sample("cc", dist.Categorical(torch.ones(4) / 4))
    with pyro.plate('data', 1000):
        theta = pyro.sample('norm_factor', dist.Gamma(3, 1))
        assignment = pyro.sample('assignment', dist.Categorical(weights),
                                 infer={"enumerate": "parallel"})
        for i in pyro.plate('segments2', 2):
            # pick out each sample's integer factor for its cluster and segment i
            pyro.sample('obs_{}'.format(i),
                        dist.Poisson(Vindex(group)[..., assignment, i] * theta * baselines[i] + 1e-8),
                        obs=data[i, :])
And as a guide:
def guide(data, baselines):
    mix_weights = pyro.param("param_weights", lambda: torch.ones(2) / 2,
                             constraint=constraints.simplex)
    hidden_weights = pyro.param("param_hidden_weights", lambda: torch.ones(4) / 4,
                                constraint=constraints.simplex)
    gamma_MAP = pyro.param("param_gamma_MAP", lambda: torch.mean(data / baselines, axis=0),
                           constraint=constraints.positive)
    pyro.sample('mixture_weights', dist.Dirichlet(mix_weights))
    with pyro.plate('segments', 2):
        with pyro.plate('groups', 2):
            pyro.sample("cc", dist.Categorical(hidden_weights))
    with pyro.plate('data', N):
        pyro.sample('norm_factor', dist.Delta(gamma_MAP))
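For completeness, this is roughly how I run inference (a minimal sketch; the Adam learning rate and number of steps are just placeholder choices):

from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

pyro.clear_param_store()
svi = SVI(model, guide, Adam({"lr": 0.05}),
          loss=TraceEnum_ELBO(max_plate_nesting=2))  # 2 for the nested segments/groups plates

losses = []
for step in range(2000):
    losses.append(svi.step(data, baselines))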
The problem is that the loss does not converge and behaves very erratically (I suspect because of the variance introduced by sampling the discrete "cc" site in the guide). If I substitute the Categorical with a continuous variable (e.g. a LogNormal), inference works fine. Is there a way to reduce this variance, or to enumerate the inner Categorical as well?
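To make the second part of the question concrete, what I have in mind is something like the following (marking "cc" for enumeration in the model and dropping it from the guide, so that TraceEnum_ELBO would sum it out), though I am not sure this is valid here:

# in model(): enumerate the inner discrete site as well
with pyro.plate('segments', 2):
    with pyro.plate('groups', 2):
        group = pyro.sample("cc", dist.Categorical(torch.ones(4) / 4),
                            infer={"enumerate": "parallel"})
# and remove the corresponding pyro.sample("cc", ...) from guide()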