Global latent variables - numpyro crash

Trying the following model gives an error I don’t understand:

The goal is to create a Dirichlet process in which the latent parameters are global.

Toy example.

This works:

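import jax.numpy as jnp
import numpyro
from numpyro.distributions import Beta, Categorical, Dirichlet, Gamma, Multinomial

# kConditions, nHypotheses, alpha, N and mix_weights are defined elsewhere (not shown)
def model(data):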
    conc = numpyro.sample('conc', Gamma(jnp.ones(kConditions), 10))
    
    # isn't used in the data plate
    with numpyro.plate("beta_plate", nHypotheses-1):
        beta = numpyro.sample("beta", Beta(1, alpha))

    with numpyro.plate("prob_plate", nHypotheses):
        probs = numpyro.sample("probs", Dirichlet(conc))

    with numpyro.plate("data", N):
        #z = numpyro.sample("z", Categorical(mix_weights(beta)))
        return numpyro.sample("obs", Multinomial(probs=probs[0]), obs=data)
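The commented-out z line calls a mix_weights helper that isn't shown. For context, a truncated Dirichlet process is usually built by stick-breaking: nHypotheses-1 Beta draws are turned into nHypotheses mixture weights that sum to 1. A minimal JAX sketch of such a helper, assuming it follows that stick-breaking convention (this is a hypothetical reconstruction, not the original code):

```python
import jax.numpy as jnp


def mix_weights(beta):
    # Hypothetical stick-breaking helper: beta has shape (..., K-1) and the
    # result has shape (..., K). Weight k is beta_k times the product of
    # (1 - beta_j) for all j < k; the last weight is whatever stick remains.
    pad_leading = [(0, 0)] * (beta.ndim - 1)  # only pad the last axis
    remaining = jnp.cumprod(1.0 - beta, axis=-1)
    return (jnp.pad(beta, pad_leading + [(0, 1)], constant_values=1.0)
            * jnp.pad(remaining, pad_leading + [(1, 0)], constant_values=1.0))
```

For example, mix_weights(jnp.array([0.5, 0.5])) returns [0.5, 0.25, 0.25].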

If I uncomment the z = ... line, the model breaks. I haven’t spent much time with Pyro internals, so I don’t understand the error:

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params)
    255     # Handle possible vectorization
    256     if rng_key.ndim == 1:
--> 257         (init_params, pe, z_grad), is_valid = _find_valid_params(rng_key, exit_early=True)
    258     else:
    259         (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
    243             # Early return if valid params found. This is only helpful for single chain,
    244             # where we can avoid compiling body_fn in while_loop.
--> 245             _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
    246             if not_jax_tracer(is_valid):
    247                 if device_get(is_valid):
~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
    155     substituted_model = substitute(model, substitute_fn=partial(_unconstrain_reparam, params))
    156     # no param is needed for log_density computation because we already substitute
--> 157     log_joint, model_trace = log_density_(substituted_model, model_args, model_kwargs, {})
    158     return - log_joint
    159 

~/miniconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/infer_util.py in log_density(model, model_args, model_kwargs, params)
    122     model = substitute(model, data=params)
    123     with plate_to_enum_plate():
--> 124         model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    125     log_factors = []
    126     time_to_factors = defaultdict(list)  # log prob factors

~/miniconda3/lib/python3.7/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    156         :return: `OrderedDict` containing the execution trace.
    157         """
--> 158         self(*args, **kwargs)
    159         return self.trace
    160 

~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     66     def __call__(self, *args, **kwargs):
     67         with self:
---> 68             return self.fn(*args, **kwargs)
     69 
     70 

~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in __exit__(self, *args, **kwargs)
     55 
     56     def __exit__(self, *args, **kwargs):
---> 57         assert _PYRO_STACK[-1] is self
     58         _PYRO_STACK.pop()
     59 

AssertionError: 

If I add print statements after the conc = and probs = lines, I see that the error is thrown after probs becomes a Traced<ConcreteArray.

Thanks in advance.

Hi @akotlar, this error typically happens when a model with discrete latent variables is missing plate statements. In your model, the first statement should be enclosed in a plate statement:

conc = numpyro.sample('conc', Gamma(jnp.ones(kConditions), 10))

I expect that adding a plate there will resolve the issue; if not, please let me know.

In an upcoming release, we will add validation so that this kind of issue can be detected early. Could you open a feature request (FR) on GitHub so we can make sure that the validation covers your use case?


Ah, so this rule applies to models with discrete random variables (Categorical here), but not to models with only continuous random variables? I noticed that the eight schools example draws samples outside of a plate context.