Mixture Model with Categorical Mixture Distributions

Hi, everybody. I am relatively new to Pyro and I am trying to fit data from a simple generative model.
Basically, I have N samples with K count variables (K = 2 in this example). Some of the K variables differ by a latent integer factor. Each sample also has its own Gamma-distributed factor.


pyro.set_rng_seed(3)
N = 1000
# Per-sample scale factor, shared by both count variables.
gamma = dist.Gamma(6, 2).sample([N])
# Integer multiplier per (count variable, sample): all ones, except that the
# first count variable is tripled for the last 40% of the samples.
groups = torch.ones((2, N))
groups[0, int(N * 0.6):] = 3
# Baseline rate per count variable, as a (2, 1) column so it broadcasts over samples.
baselines = torch.tensor([[30], [50]])
data = dist.Poisson(gamma * groups * baselines).sample()

A corresponding model, assuming the baselines are observed, could be:

import pyro
import pyro.distributions as dist
import torch
from pyro.ops.indexing import Vindex
from torch.distributions import constraints

def model(data, baselines, enumerate_groups=False):
    """Poisson mixture model with a discrete per-(cluster, segment) multiplier.

    Each observed count ``data[i, n]`` is Poisson with rate
    ``level * theta[n] * baselines[i]``. Here ``theta[n]`` is a per-sample Gamma
    factor. ``level`` is an integer in ``{0, 1, 2, 3}``, chosen by the sample's
    mixture assignment and the segment ``i``.

    Args:
        data: observed counts indexed as ``data[segment, sample]``.
        baselines: per-segment baseline rates indexed as ``baselines[segment]``
            (the data snippet builds them with shape ``(n_segments, 1)``).
        enumerate_groups: if False (default), sample the multipliers as the
            original plated ``"cc"`` site, which the guide must provide. If
            True, sample one un-plated site ``cc_{g}_{s}`` per (cluster,
            segment) cell, marked for parallel enumeration. ``TraceEnum_ELBO``
            then sums them out exactly instead of relying on sampled
            (high-variance) gradients. The guide must then NOT sample them.
    """
    # Sizes come from the data instead of being hard-coded (was 2 and 1000).
    n_segments, n_obs = data.shape[0], data.shape[-1]
    n_clusters = 2  # length of the Dirichlet mixture weights below
    n_levels = 4    # upper bound on the number of distinct multipliers

    weights = pyro.sample('mixture_weights', dist.Dirichlet(torch.ones(n_clusters) / 2))
    level_prior = dist.Categorical(torch.ones(n_levels) / n_levels)

    if enumerate_groups:
        # A plated site would be enumerated as ONE (n_levels, 1, 1) tensor shared
        # by every plate cell, so it could not be indexed per cell. Separate sites
        # each get their own enumeration dim and can be marginalised exactly.
        levels = [
            [
                pyro.sample('cc_{}_{}'.format(g, s), level_prior,
                            infer={"enumerate": "parallel"})
                for s in range(n_segments)
            ]
            for g in range(n_clusters)
        ]
    else:
        with pyro.plate('segments', n_segments):
            with pyro.plate('groups', n_clusters):
                # Batch dims: -2 = cluster ('groups'), -1 = segment ('segments').
                group = pyro.sample("cc", level_prior)

    with pyro.plate('data', n_obs):
        theta = pyro.sample('norm_factor', dist.Gamma(3, 1))
        assignment = pyro.sample('assignment', dist.Categorical(weights),
                                 infer={"enumerate": "parallel"})
        for i in pyro.plate('segments2', n_segments):
            if enumerate_groups:
                # Stack the per-cluster levels of segment i on a trailing dim, then
                # pick the one selected by each sample's assignment.
                cells = torch.stack(
                    torch.broadcast_tensors(*[levels[g][i] for g in range(n_clusters)]),
                    dim=-1,
                )
                level = Vindex(cells)[..., assignment]
            else:
                level = Vindex(group)[..., assignment, i]
            # Level 0 gives a zero rate; the 1e-8 keeps the Poisson rate positive.
            pyro.sample('obs_{}'.format(i),
                        dist.Poisson(level * theta * baselines[i] + 1e-8),
                        obs=data[i, :])
  

And as a guide:


def guide(data, baselines, enumerate_groups=False):
    """Mean-field guide for ``model``.

    Args:
        data: observed counts indexed as ``data[segment, sample]``.
        baselines: per-segment baselines; ``data / baselines`` must broadcast
            (``(n_segments, 1)`` in the data snippet).
        enumerate_groups: must match the value passed to ``model``. When True,
            the discrete multipliers are enumerated in the model, so the guide
            does not sample them.
    """
    # Sizes come from the data instead of the global ``N``.
    n_segments, n_obs = data.shape[0], data.shape[-1]
    n_clusters = 2
    n_levels = 4

    # Dirichlet concentrations only need to be positive. The previous simplex
    # constraint pinned their sum to 1, so q(mixture_weights) could never sharpen.
    mix_weights = pyro.param("param_weights", lambda: torch.ones(n_clusters) / 2,
                             constraint=constraints.positive)
    pyro.sample('mixture_weights', dist.Dirichlet(mix_weights))

    if not enumerate_groups:
        # One categorical per (cluster, segment) cell. A single shared (n_levels,)
        # vector forced the same posterior on every cell. NOTE: Pyro estimates
        # gradients of this non-reparameterizable site with score-function terms,
        # which is the source of the high variance; enumerate_groups=True
        # removes the site entirely.
        hidden_weights = pyro.param(
            "param_hidden_weights",
            lambda: torch.ones(n_clusters, n_segments, n_levels) / n_levels,
            constraint=constraints.simplex,
        )
        with pyro.plate('segments', n_segments):
            with pyro.plate('groups', n_clusters):
                pyro.sample("cc", dist.Categorical(hidden_weights))

    # Initialise each sample's scale at its baseline-normalised count, averaged
    # over segments.
    gamma_MAP = pyro.param("param_gamma_MAP",
                           lambda: torch.mean(data / baselines, axis=0),
                           constraint=constraints.positive)
    with pyro.plate('data', n_obs):
        pyro.sample('norm_factor', dist.Delta(gamma_MAP))

The problem is that the loss does not converge and behaves very erratically (I think because of the variance introduced by sampling the Categorical). If I replace the Categorical with a continuous variable (such as a LogNormal), inference works fine. Is there a way to reduce the variance, or to enumerate the inner Categorical as well?