Cannot find valid initial parameters when using NUTS for simple Gaussian Mixture Model in NumPyro

Hi Pyro Devs!

I made a simple Gaussian mixture model to understand how discrete-site enumeration works:


```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def gmm(data, num_components=3):
    # Priors over component means, component scales and mixture weights.
    mus = numpyro.sample('mus', dist.Normal(jnp.zeros(num_components),
                                            jnp.ones(num_components) * 100.).to_event(1))
    sigmas = numpyro.sample('sigmas', dist.HalfNormal(jnp.zeros(num_components) * 100.).to_event(1))
    mixture_probs = numpyro.sample('mixture_probs', dist.Dirichlet(
        jnp.ones(num_components) / num_components))
    with numpyro.plate('data', len(data), dim=-1):
        # Discrete component assignment for each data point.
        z = numpyro.sample('z', dist.Categorical(mixture_probs))
        numpyro.sample('ll', dist.Normal(mus[z], sigmas[z]), obs=data)
```
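Enumerating the discrete site `z` amounts to summing it out of the model, so each point contributes `log p(x_i) = log sum_k pi_k * Normal(x_i | mu_k, sigma_k)`. Here is a minimal NumPy sketch of that marginal likelihood. It only illustrates what parallel enumeration over the K components computes; it is not NumPyro's internal implementation, and the helper names (`normal_logpdf`, `logsumexp`, `gmm_marginal_loglik`) are hypothetical.

```python
import numpy as np


def normal_logpdf(x, mu, sigma):
    """Elementwise log-density of Normal(mu, sigma)."""
    return -0.5 * np.log(2 * np.pi) - np.log(sigma) - 0.5 * ((x - mu) / sigma) ** 2


def logsumexp(a, axis=-1):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def gmm_marginal_loglik(data, mixture_probs, mus, sigmas):
    """Log-likelihood of `data` with the discrete assignment z summed out.

    Every value of z (one per component) is scored at once, giving a
    (num_data, num_components) table that is then reduced over components.
    """
    x = np.asarray(data, dtype=float)[:, None]  # shape (N, 1), broadcasts against (K,)
    log_joint = np.log(mixture_probs) + normal_logpdf(x, mus, sigmas)  # shape (N, K)
    return logsumexp(log_joint, axis=-1).sum()


# Usage with the component parameters used in `generate_data` (hypothetical data points).
probs = np.array([0.5, 0.35, 0.15])
mus = np.array([-100.0, 0.0, 34.0])
sigmas = np.array([100.0, 1.0, 0.3])
print(gmm_marginal_loglik([-100.0, 0.0, 34.0], probs, mus, sigmas))
```

Working in log space with `logsumexp` avoids underflow when a point is far from some components (e.g. `x = -100` under the component with `sigma = 0.3`).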

However, when I run it as follows:

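```python
from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS


def main(_args):
    data = generate_data()
    nuts = NUTS(gmm)
    init_rng_key = PRNGKey(1273)
    mcmc = MCMC(nuts, num_warmup=100, num_samples=1000)
    res = mcmc.run(init_rng_key, data)
    mcmc.print_summary()
```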

I get an error:

```text
Traceback (most recent call last):
  File "/Users/asal/Documents/SourceControl/jeffreys/enum_gmm.py", line 59, in <module>
    main(sys.argv)
  File "/Users/asal/Documents/SourceControl/jeffreys/enum_gmm.py", line 48, in main
    res = mcmc.run(init_rng_key, data)
  File "/usr/local/Caskroom/miniconda/base/envs/jeffreys/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 446, in run
    states_flat, last_state = partial_map_fn(map_args)
  File "/usr/local/Caskroom/miniconda/base/envs/jeffreys/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 312, in _single_chain_mcmc
    init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
  File "/usr/local/Caskroom/miniconda/base/envs/jeffreys/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 455, in init
    init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
  File "/usr/local/Caskroom/miniconda/base/envs/jeffreys/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 408, in _init_state
    init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
  File "/usr/local/Caskroom/miniconda/base/envs/jeffreys/lib/python3.8/site-packages/numpyro/infer/util.py", line 430, in initialize_model
    raise RuntimeError("Cannot find valid initial parameters. Please check your model again.")
RuntimeError: Cannot find valid initial parameters. Please check your model again.
```

Any idea what could cause this type of error and how to debug it?

Thanks!

FYI, generate_data is very simple:


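```python
import numpy as np
from scipy import stats

def generate_data(num_samples=1000):
    # One-hot multinomial draws -> component index per point, with weights (0.5, 0.35, 0.15).
    zs = stats.multinomial.rvs(1, (0.5, 0.35, 0.15), num_samples).argmax(-1)
    mus = np.array([-100., 0., 34.])
    sigmas = np.array([100., 1., 0.3])
    return stats.norm.rvs(loc=mus[zs], scale=sigmas[zs], size=num_samples)
```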

@ahmadsalim you need to use a positive scale for the `HalfNormal` distribution: `jnp.zeros(num_components) * 100.` is still all zeros. To debug issues like this, you can use

```python
with numpyro.validation_enabled():
    res = mcmc.run(init_rng_key, data)
```

which raises an error naming the distribution and parameter that received an invalid value. :slight_smile:
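Why a zero scale surfaces as this particular error: as far as I understand it, NumPyro's initialization retries random starting points until the potential energy (the negative log joint) and its gradient are finite, and a half-normal with scale 0 has a NaN log-density at every candidate value, so no starting point is ever accepted. The sketch below evaluates a hand-written half-normal log-density with NumPy rather than calling NumPyro; `halfnormal_logpdf` and the candidate values in `sigmas_init` are hypothetical.

```python
import numpy as np


def halfnormal_logpdf(x, scale):
    """Hand-written HalfNormal(scale) log-density for x >= 0 (illustrative sketch)."""
    x = np.asarray(x, dtype=float)
    scale = np.asarray(scale, dtype=float)
    # Silence the divide-by-zero / inf - inf warnings that a zero scale triggers.
    with np.errstate(divide="ignore", invalid="ignore"):
        return 0.5 * np.log(2.0 / np.pi) - np.log(scale) - x**2 / (2.0 * scale**2)


# Hypothetical candidate initial values for `sigmas` (any positive values would do).
sigmas_init = np.array([0.5, 1.0, 2.0])

bad = halfnormal_logpdf(sigmas_init, np.zeros(3) * 100.)  # scale 0: +inf - inf -> NaN
good = halfnormal_logpdf(sigmas_init, np.ones(3) * 100.)  # scale 100: finite

print(bad)
print(np.isfinite(good).all())
```

With scale 0 the `-log(scale)` term is `+inf` and the quadratic term is `inf`, so their difference is NaN regardless of `x`; with a positive scale every term is finite.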


Thanks a lot! I totally missed that I used `jnp.zeros` instead of `jnp.ones` :smiley:
Multiplying by 100 makes no sense for zeros either, of course.

Yeah, this is hard to spot just by reading the model. I just made this PR, which can point out where things go wrong. :wink:


Very cool!