Hi Pyro Devs!
I have made a simple Gaussian mixture model to see if I can understand how the discrete-site enumeration works:
def gmm(data, num_components=3):
mus = numpyro.sample('mus', dist.Normal(jnp.zeros(num_components),
jnp.ones(num_components) * 100.).to_event(1))
sigmas = numpyro.sample('sigmas', dist.HalfNormal(jnp.zeros(num_components) * 100.).to_event(1))
mixture_probs = numpyro.sample('mixture_probs', dist.Dirichlet(
jnp.ones(num_components) / num_components))
with numpyro.plate('data', len(data), dim=-1):
z = numpyro.sample('z', dist.Categorical(mixture_probs))
numpyro.sample('ll', dist.Normal(mus[z], sigmas[z]), obs=data)
When I however run it as follows
def main(_args):
data = generate_data()
nuts = NUTS(gmm)
init_rng_key = PRNGKey(1273)
mcmc = MCMC(nuts, 100, 1000)
res = mcmc.run(init_rng_key, data)
mcmc.print_summary()
I get an error:
Traceback (most recent call last):
File "/Users/asal/Documents/SourceControl/jeffreys/enum_gmm.py", line 59, in <module>
main(sys.argv)
File "/Users/asal/Documents/SourceControl/jeffreys/enum_gmm.py", line 48, in main
res = mcmc.run(init_rng_key, data)
File "/usr/local/Caskroom/miniconda/base/envs/jeffreys/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 446, in run
states_flat, last_state = partial_map_fn(map_args)
File "/usr/local/Caskroom/miniconda/base/envs/jeffreys/lib/python3.8/site-packages/numpyro/infer/mcmc.py", line 312, in _single_chain_mcmc
init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
File "/usr/local/Caskroom/miniconda/base/envs/jeffreys/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 455, in init
init_params = self._init_state(rng_key_init_model, model_args, model_kwargs, init_params)
File "/usr/local/Caskroom/miniconda/base/envs/jeffreys/lib/python3.8/site-packages/numpyro/infer/hmc.py", line 408, in _init_state
init_params, potential_fn, postprocess_fn, model_trace = initialize_model(
File "/usr/local/Caskroom/miniconda/base/envs/jeffreys/lib/python3.8/site-packages/numpyro/infer/util.py", line 430, in initialize_model
raise RuntimeError("Cannot find valid initial parameters. Please check your model again.")
RuntimeError: Cannot find valid initial parameters. Please check your model again.
Any idea what could cause this type of error and how to debug it?
Thanks!