Variance vanishing in Hierarchical model

Hello everybody, I am facing a problem that I am sure someone has run into before. It is quite general and related to hierarchical models. I am trying to train a simple toy model like this one (implemented in NumPyro, but I think the issue is framework-agnostic):

import numpy as np
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

data = np.arange(1, 11)
data = np.float32(data)

def model(X):
    # global location and scale parameters to be fitted
    mu = numpyro.param("mu", 0.)
    sigma = numpyro.param("sigma", 1., constraint=dist.constraints.positive)
    with numpyro.plate("data", len(X)):
        # one latent value per data point, drawn from the shared distribution
        a = numpyro.sample("a", dist.Normal(mu, sigma))

    numpyro.sample("obs", dist.Normal(a, 1), obs=X)

def guide(X):
    # intentionally empty: no variational distribution is defined for the latents
    pass
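
For completeness, the SVI setup I use for training is roughly the following (the optimizer, learning rate, and number of steps here are arbitrary placeholders, not the exact values I ran with):

from jax import random
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam

svi = SVI(model, guide, Adam(0.01), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 2000, data)
print(svi_result.params)  # contains the learned "mu" and "sigma"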

What happens is that after training, the mu parameter is spot on, but the variance collapses to zero (as if the distribution generating the latent variable were a Dirac delta). My understanding is that the problem comes from the fact that the latent variable 'a' is sampled randomly at every step, and therefore has a random ordering. So even if the parameters mu and sigma were exactly the 'right' ones for the data, once the likelihood is evaluated with that sample we can get a very low value, purely because of the ordering.
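
To illustrate the ordering argument with a quick standalone check (not part of the model, just a sketch): take latent values that happen to match the data exactly, and compare the observation likelihood when they are paired with the data in order versus in a random permutation.

from jax import random

X = jnp.arange(1.0, 11.0)

# latent values identical to the data, once in the matching order
# and once randomly permuted
a_matched = X
a_permuted = random.permutation(random.PRNGKey(0), X)

print(dist.Normal(a_matched, 1).log_prob(X).sum())   # high log-likelihood
print(dist.Normal(a_permuted, 1).log_prob(X).sum())  # much lower, same values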

For this easy toy model I was able to fix this pathological behaviour with a sorting line (taking advantage of the fact that the data is also sorted):

def model2(X):
    mu = numpyro.param("mu", 0.)
    sigma = numpyro.param("sigma", 1., constraint=dist.constraints.positive)
    with numpyro.plate("data", len(X)):
        a = numpyro.sample("a", dist.Normal(mu, sigma))

    # re-order the sampled latents so they line up with the (sorted) data
    a_sorted = jnp.sort(a)

    numpyro.sample("obs", dist.Normal(a_sorted, 1), obs=X)

With this I get the expected mu and sigma of the data (mean = 5.5, std = 2.87).
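
As a quick sanity check (same SVI setup as in the snippet above; learning rate and step count are again arbitrary), the recovered parameters should land close to the empirical mean and standard deviation of the data:

from jax import random
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam

svi = SVI(model2, guide, Adam(0.05), Trace_ELBO())
svi_result = svi.run(random.PRNGKey(0), 5000, data)
print(svi_result.params)          # learned "mu" and "sigma"
print(data.mean(), data.std())    # 5.5 and ~2.87 for comparison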

However, I am having a hard time doing the same for more complicated models.

  • Has anyone faced this problem? Is there a built-in Pyro function or workaround to solve it?
  • (side question) I often use SVI with a pass guide to perform MLE with Pyro. I find it a very intuitive and easy way to test your models before switching to more complicated ones. Am I somehow abusing this feature? Is this the best way to perform a simple MLE, or should it be implemented directly in JAX/torch?

Have a good day!
A

Is X fixed, or are you using mini-batches of data?

https://pyro.ai/examples/svi_part_ii.html

No mini-batches; the data is defined in the first lines:

data  = np.arange(1, 11)
data = np.float32(data)

I am using all data points at every iteration.