Global latent variables - numpyro crash

The following model gives an error I don’t understand.

The goal is to create a Dirichlet process where the latent parameters are global.

Here is a toy example. This version works:

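import jax.numpy as jnp
import numpyro
from numpyro.distributions import Beta, Categorical, Dirichlet, Gamma, Multinomial

# kConditions, nHypotheses, N, alpha and mix_weights are defined elsewhere (not shown)
def model(data):
    conc = numpyro.sample('conc', Gamma(jnp.ones(kConditions), 10))

    # beta isn't used in the data plate while the z line below stays commented out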
    with numpyro.plate("beta_plate", nHypotheses-1):
        beta = numpyro.sample("beta", Beta(1, alpha))

    with numpyro.plate("prob_plate", nHypotheses):
        probs = numpyro.sample("probs", Dirichlet(conc))

    with numpyro.plate("data", N):
        #z = numpyro.sample("z", Categorical(mix_weights(beta)))
        return numpyro.sample("obs", Multinomial(probs=probs[0]), obs=data)
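The post does not show mix_weights. A minimal sketch, assuming it is the usual stick-breaking construction for a truncated Dirichlet process (this definition is an assumption, not the code from the original model):

```python
import jax.numpy as jnp


def mix_weights(beta):
    # Assumed stick-breaking construction: weight k is beta[k] times the
    # stick left over after the first k breaks; the last weight takes
    # whatever remains, so the weights sum to 1.
    remaining = jnp.cumprod(1 - beta, axis=-1)
    return jnp.pad(beta, (0, 1), constant_values=1) * jnp.pad(
        remaining, (1, 0), constant_values=1
    )


weights = mix_weights(jnp.array([0.5, 0.5]))
print(weights, weights.sum())
```

With nHypotheses - 1 values of beta this returns nHypotheses weights, which matches the size of beta_plate above.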

If I uncomment the z = line, the model breaks. I haven’t spent much time with Pyro internals, so I don’t understand the error:

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params)
    255     # Handle possible vectorization
    256     if rng_key.ndim == 1:
--> 257         (init_params, pe, z_grad), is_valid = _find_valid_params(rng_key, exit_early=True)
    258     else:
    259         (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in _find_valid_params(rng_key, exit_early)
    243             # Early return if valid params found. This is only helpful for single chain,
    244             # where we can avoid compiling body_fn in while_loop.
--> 245             _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
    246             if not_jax_tracer(is_valid):
    247                 if device_get(is_valid):

~/miniconda3/lib/python3.7/site-packages/numpyro/infer/util.py in potential_energy(model, model_args, model_kwargs, params, enum)
    155     substituted_model = substitute(model, substitute_fn=partial(_unconstrain_reparam, params))
    156     # no param is needed for log_density computation because we already substitute
--> 157     log_joint, model_trace = log_density_(substituted_model, model_args, model_kwargs, {})
    158     return - log_joint
    159 

~/miniconda3/lib/python3.7/site-packages/numpyro/contrib/funsor/infer_util.py in log_density(model, model_args, model_kwargs, params)
    122     model = substitute(model, data=params)
    123     with plate_to_enum_plate():
--> 124         model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    125     log_factors = []
    126     time_to_factors = defaultdict(list)  # log prob factors

~/miniconda3/lib/python3.7/site-packages/numpyro/handlers.py in get_trace(self, *args, **kwargs)
    156         :return: `OrderedDict` containing the execution trace.
    157         """
--> 158         self(*args, **kwargs)
    159         return self.trace
    160 

~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in __call__(self, *args, **kwargs)
     66     def __call__(self, *args, **kwargs):
     67         with self:
---> 68             return self.fn(*args, **kwargs)
     69 
     70 

~/miniconda3/lib/python3.7/site-packages/numpyro/primitives.py in __exit__(self, *args, **kwargs)
     55 
     56     def __exit__(self, *args, **kwargs):
---> 57         assert _PYRO_STACK[-1] is self
     58         _PYRO_STACK.pop()
     59 

AssertionError: 

If I add print statements after the conc = and probs = lines, I see that the error is thrown after probs becomes a Traced<ConcreteArray.

Thanks in advance.

Hi @akotlar, this error typically happens when plate statements are missing in a model with discrete latent variables. In your model, the first statement should be enclosed in a plate statement:

conc = numpyro.sample('conc', Gamma(jnp.ones(kConditions), 10))

I expect that adding a plate there will resolve the issue; if not, please let me know.

In an upcoming release, we will add some validation mechanisms so that this kind of issue can be detected early. Could you help me by opening a feature request on GitHub, so we can make sure that the validation covers your use case?

Ah, so this rule applies to models with discrete random variables (Categorical here), but not to those with only continuous ones? I noticed that the eight schools example draws samples outside of any plate context.

Thanks, @fehiepsi, for explaining. Could you elaborate a bit more on why plates are needed even for global variables when the model contains discrete latent variables? And when exactly?
The discrete latent variables example, http://num.pyro.ai/en/latest/examples/annotation.html, also has some models with the global variable

pi = numpyro.sample("pi", dist.Dirichlet(jnp.ones(num_classes)))

which is also not within a plate. Why does it work in this case but not in the example of @akotlar?

Hi @FlorianWilhelm. For enumeration, we delegate all the math needed to marginalize out the discrete latent variables to funsor, which needs to know the named dimensions of log_prob at each sample site. We use plate to declare those “named” dimensions. This requirement applies to both local and global variables.

In the annotation example,

dist.Dirichlet(jnp.ones(num_classes))

has batch_shape = () (i.e. log_prob will be a scalar given a Dirichlet sample). In @akotlar's model, the site conc has batch_shape = (kConditions,), so we need a plate statement to give a name to that batch dimension.

Ahh, now I understand! Thanks, @fehiepsi, also for your great work on NumPyro. I really do think that it will turn out to be a game-changer for ML modeling over the years. It’s quite tough to wrap your head around NumPyro, but it’s well worth it :)
