Implementing MCMC for a GMM with the same mean and different variances

I’m new to Pyro. I’m trying to set up MCMC for a toy problem: a Gaussian mixture model whose components share the same mean but have different variances (e.g., standard deviations sig1 = 1, sig2 = 0.1). The prior over the mean is uniform, so the posterior over the mean is also a GMM (truncated to the prior's support).
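Spelling that claim out for a single observation y, reading sig1 and sig2 as the standard deviations σ1 and σ2 with equal mixing weights (as in the Bernoulli(0.5) model further down), and using the symmetry N(y; μ, σ²) = N(μ; y, σ²):

```latex
p(\mu \mid y) \;\propto\; \mathbb{1}[-10 \le \mu \le 10]\,
  \Big( \tfrac{1}{2}\,\mathcal{N}(y;\mu,\sigma_1^2) + \tfrac{1}{2}\,\mathcal{N}(y;\mu,\sigma_2^2) \Big)
  \;=\; \mathbb{1}[-10 \le \mu \le 10]\,
  \Big( \tfrac{1}{2}\,\mathcal{N}(\mu;y,\sigma_1^2) + \tfrac{1}{2}\,\mathcal{N}(\mu;y,\sigma_2^2) \Big)
```

Each component integrates to 1 over μ, so when y lies well inside [-10, 10] the posterior is, up to negligible truncation, a two-component GMM centered at y with weights 1/2 and 1/2. That gives a reference density to compare the KDE of the MCMC samples against.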

I wasn't sure how to do it, so I tried the following code:

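import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import MCMC, NUTS
dim = 1  # assumed; dim is not defined in the original snippet

def model(obs=None):
    mu = pyro.sample("mu", dist.Uniform(-10*torch.ones(dim), high=10*torch.ones(dim)))
    sigma = pyro.param("sigma", lambda: torch.ones(()) if torch.rand(()) < 0.5 else 0.1*torch.ones(()), constraint=constraints.positive)
    return pyro.sample("obs", dist.Normal(mu, sigma), obs=obs)

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)  # assumed settings; the original call was truncated
mcmc.run(obs)  # obs: the observed tensor, not shown in the original
samples = mcmc.get_samples()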

But I got bad results: the KDE plot of the posterior samples doesn't look like the expected mixture.

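Never mind: I realized the lambda expression initializes sigma only once (pyro.param keeps the stored value, so every MCMC sample uses the same sigma). Making the mixture indicator a latent sample site instead seems to work well: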

def model(obs=None):
    mu = pyro.sample("theta", dist.Uniform(-10*torch.ones(dim), high=10*torch.ones(dim)))
    # discrete mixture indicator: 1 selects sigma = 1.0, 0 selects sigma = 0.1
    ber = pyro.sample("mix", dist.Bernoulli(0.5))
    sigma = 0.1 + 0.9 * ber  # same as ber*1.0 + (1-ber)*0.1
    return pyro.sample("obs", dist.Normal(mu, sigma), obs=obs)