DiscreteHMCGibbs (with enumeration) cannot detect any discrete latent variables

I am using NumPyro 0.12.1 on a Mac.
I intend to use DiscreteHMCGibbs on a simple non-Markovian regime-switching time series model.
I saw in the docs that a discrete latent variable can be marginalized out of the model by setting `infer={'enumerate': 'parallel'}` on its sample site.
Although the model does contain discrete latent variables, DiscreteHMCGibbs reports that it does not.

A simplified (and unrealistic) version of the model I am working on is shown below:

```python
import jax
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.infer import HMC, MCMC, DiscreteHMCGibbs

def RS_AR1(y):
    theta = numpyro.sample("theta", dist.Normal().expand([2]).to_event(1))  # one AR coefficient per regime
    def transition_func(carry, y):
        prev_y = carry
        next_regime_prob = jax.nn.sigmoid(prev_y)  # P(regime 1) depends on the previous observation
        current_regime = numpyro.sample("current_regime", dist.Bernoulli(next_regime_prob),
                                        infer={"enumerate": "parallel"})
        current_theta = theta[current_regime]
        current_y = numpyro.sample("current_y", dist.Normal(current_theta * prev_y), obs=y)
        return current_y, None
    carry_init = y[0]
    scan(transition_func, carry_init, y[1:])

ar1_kernel = DiscreteHMCGibbs(HMC(RS_AR1))
mcmc_1 = MCMC(ar1_kernel, num_samples=10000, num_warmup=10000)
mcmc_1.run(random.PRNGKey(1), jnp.array(y))  # y: the observed 1-D series
```

In short, the AR(1) coefficient switches between two values, `theta[0]` and `theta[1]`, and the probability of being in regime 1 at each step is the sigmoid of the previous observation.
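To have concrete data for the model above, a small NumPy simulator can generate a series with exactly this structure. This is a sketch: `simulate_rs_ar1` is a hypothetical helper, the noise scale of 1 matches the default scale of `dist.Normal` in the model, and the first value `y0` is fixed because the model conditions on `y[0]`.

```python
import numpy as np


def simulate_rs_ar1(theta, T, y0=0.0, seed=0):
    """Hypothetical helper: simulate T observations of the regime-switching AR(1).

    Regime s_t is 1 with probability sigmoid(y_{t-1}), otherwise 0,
    and y_t ~ Normal(theta[s_t] * y_{t-1}, 1).
    """
    rng = np.random.default_rng(seed)
    y = np.empty(T)
    regimes = np.empty(T - 1, dtype=int)  # s_1, ..., s_{T-1}
    y[0] = y0  # the model conditions on the first observation
    for t in range(1, T):
        p1 = 1.0 / (1.0 + np.exp(-y[t - 1]))  # P(s_t = 1 | y_{t-1})
        s = int(rng.random() < p1)
        regimes[t - 1] = s
        y[t] = rng.normal(theta[s] * y[t - 1], 1.0)
    return y, regimes


y, regimes = simulate_rs_ar1(np.array([0.5, -0.8]), T=200)
```

The returned `y` can then be passed to `mcmc_1.run` as `jnp.array(y)`, and `regimes` gives the true regimes for checking any posterior over them.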

I got “AssertionError: Cannot detect any discrete latent variables in the model.”

How can I fix this? Any suggestions are welcome.

Notes:
1. If I remove the `infer=...` argument, the code runs, but `current_regime` is then not marginalized out.
2. I get the same error with NUTS as the inner kernel, and also with MixedHMC.
3. I also tried a Markov-switching model (with parallel enumeration), and it works well.

Thanks for any help.

DiscreteHMCGibbs performs Gibbs moves with respect to the discrete latent variables, i.e. those variables need to remain sampled sites in the model and cannot be enumerated out.
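To make the distinction concrete: when `current_regime` is left un-enumerated, DiscreteHMCGibbs keeps the regimes as part of the sampler state and updates them with Metropolis-within-Gibbs moves that target their distribution given `theta` and the data. In this model each regime depends only on the observed previous value, so that conditional factorizes over time steps. The sketch below computes it with NumPy; `regime_conditional` is a hypothetical helper for illustration, not NumPyro's implementation.

```python
import numpy as np


def regime_conditional(theta, y):
    """Hypothetical helper: P(s_t = k | theta, y) for k in {0, 1}, one row per step t >= 1."""
    prev_y, curr_y = y[:-1], y[1:]
    p1 = 1.0 / (1.0 + np.exp(-prev_y))                      # prior P(s_t = 1 | y_{t-1})
    prior = np.stack([1.0 - p1, p1], axis=-1)              # shape (T-1, 2)
    means = np.asarray(theta)[None, :] * prev_y[:, None]   # theta_k * y_{t-1}, shape (T-1, 2)
    lik = np.exp(-0.5 * (curr_y[:, None] - means) ** 2) / np.sqrt(2 * np.pi)
    joint = prior * lik                                    # p(s_t = k, y_t | y_{t-1}, theta)
    return joint / joint.sum(axis=-1, keepdims=True)       # normalize over the two regimes


probs = regime_conditional(np.array([0.5, -0.8]), np.array([0.2, -0.5, 1.0, 0.3]))
```

A Gibbs-type sampler draws regimes from (or moves towards) this conditional at every iteration, whereas enumeration never samples them at all.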


Thank you for your response.

Could you explain a bit more about the “marginalize out” part of the docs?
It does not seem to be the same thing as what you mentioned, and marginalizing is what I need.


For discrete latent variables, marginalizing and enumerating are the same thing: both involve explicitly enumerating all discrete states and summing them out.
