Trying the following model gives an error I don't understand.
The goal is to create a Dirichlet process where the latent parameters are global.
Toy example.
This version works:
def model(data):
    """Toy stick-breaking model whose latent parameters are global.

    Sample sites, in order:
      * ``conc``  -- Gamma(ones(kConditions), 10) concentration vector.
      * ``beta``  -- nHypotheses - 1 Beta(1, alpha) stick-breaking proportions.
      * ``probs`` -- nHypotheses Dirichlet(conc) probability vectors.
      * ``obs``   -- N Multinomial observations conditioned on ``data``.

    Only the first row of ``probs`` enters the likelihood; ``beta`` is
    sampled but not referenced by any later site.

    Returns the value of the observed ``"obs"`` site.
    """
    conc_prior = Gamma(jnp.ones(kConditions), 10)
    conc = numpyro.sample('conc', conc_prior)

    # Stick-breaking proportions: sampled here but not consumed by the
    # data plate below.
    n_sticks = nHypotheses - 1
    with numpyro.plate("beta_plate", n_sticks):
        beta = numpyro.sample("beta", Beta(1, alpha))

    with numpyro.plate("prob_plate", nHypotheses):
        probs = numpyro.sample("probs", Dirichlet(conc))

    first_component = probs[0]
    with numpyro.plate("data", N):
        # Per-datum assignment, currently disabled:
        #     z = numpyro.sample("z", Categorical(mix_weights(beta)))
        # NOTE(review): enabling it adds a discrete latent site. As written, z
        # would never reach the likelihood, which always uses component 0.
        # Presumably it should select the component (e.g. probs[z]); confirm.
        likelihood = Multinomial(probs=first_component)
        return numpyro.sample("obs", likelihood, obs=data)
If I uncomment the `z = ...` line, this breaks.
I haven't spent much time with Pyro internals, so I don't understand the error:
~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params)
255 # Handle possible vectorization
256 if rng_key.ndim == 1:
--> 257 (init_params, pe, z_grad), is_valid = _find_valid_params(rng_key, exit_early=True)
258 else:
259 (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)
~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
243 # Early return if valid params found. This is only helpful for single chain,
244 # where we can avoid compiling body_fn in while_loop.
--> 245 _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
246 if not_jax_tracer(is_valid):
247 if device_get(is_valid):
...
~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
155 substituted_model = substitute(model, substitute_fn=partial(_unconstrain_reparam, params))
156 # no param is needed for log_density computation because we already substitute
--> 157 log_joint, model_trace = log_density_(substituted_model, model_args, model_kwargs, {})
158 return - log_joint
159
~/miniconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/infer_util.py in log_density(model, model_args, model_kwargs, params)
122 model = substitute(model, data=params)
123 with plate_to_enum_plate():
--> 124 model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
125 log_factors = []
126 time_to_factors = defaultdict(list) # log prob factors
~/miniconda3/lib/python3.7/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
156 :return: `OrderedDict` containing the execution trace.
157 """
--> 158 self(*args, **kwargs)
159 return self.trace
160
~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
66 def __call__(self, *args, **kwargs):
67 with self:
---> 68 return self.fn(*args, **kwargs)
69
70
~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in __exit__(self, *args, **kwargs)
55
56 def __exit__(self, *args, **kwargs):
---> 57 assert _PYRO_STACK[-1] is self
58 _PYRO_STACK.pop()
59
AssertionError:
If I add print statements after the `conc = ...` and `probs = ...` lines,
I see that the error is thrown after `probs`
becomes a `Traced<ConcreteArray>`.
Thanks in advance.