[Pyro] Runtime error when using HMC/MCMC together with sequential enumeration

I am trying to implement the CRBD (constant-rate birth–death) model described in this paper, which contains both continuous and discrete random variables, and to apply HMC to it.

However, I get a runtime error that says:

ValueError: Continuous inference cannot handle discrete sample site './isSpeciation'. Consider enumerating that variable as documented in https://pyro.ai/examples/enumeration.html . If you are already enumerating, take care to hide this site when constructing an autoguide, e.g. guide = AutoNormal(poutine.block(model, hide=['./isSpeciation'])).

Am I using Pyro’s HMC the wrong way, or is Pyro’s HMC not compatible with {'enumerate': 'sequential'}?

What should I do to apply HMC together with {'enumerate': 'sequential'} to this CRBD model?

Thanks for helping.


import argparse
import pyro.distributions as dist
import pyro
import torch
from pyro.infer import MCMC, NUTS
import sys
# gosExtince calls itself once per speciation event, so a deep simulated
# tree can exceed the interpreter's default recursion limit; raise it for
# the whole process before any model runs.
sys.setrecursionlimit(32767)

def gosExtince(prefix, time, la, mu):
    """Simulate one lineage forward for ``time`` units and report its fate.

    Draws a waiting time until the lineage's next event. If it is longer
    than ``time``, the function returns False. Otherwise a Bernoulli draw
    decides the event:

    - On 1, the lineage splits. The function recurses into two daughter
      lineages over the remaining time and returns True only if both of
      them return True.
    - On 0, it returns True.

    Read together with the function name, True presumably means "the
    lineage and all of its descendants went extinct".

    Args:
        prefix: Site-name prefix. It is extended with ``/x`` and ``/y`` on
            each split, so every ``pyro.sample`` site in the recursion has a
            unique name.
        time: Remaining time available to this lineage.
        la: Rate of the waiting-time Exponential. Also the speciation weight
            in the Bernoulli probability ``la / (la + mu)``.
        mu: Extinction weight in the Bernoulli probability ``la / (la + mu)``.

    Returns:
        A Python bool, as described above.
    """
    # NOTE(review): in a constant-rate birth-death process the time to the
    # next event is usually Exponential(la + mu). Only `la` is used here;
    # confirm this is intended.
    waitingTime = pyro.sample(f"{prefix}/waitingTime", dist.Exponential(la))
    if waitingTime > time:
        # No event happens within the remaining time.
        return False

    # NOTE(review): as far as I know, Pyro's HMC/NUTS can marginalize
    # discrete sites only through *parallel* enumeration; confirm in the
    # Pyro enumeration docs. This site also drives Python control flow, and
    # the recursion changes how many sample sites exist per run. So changing
    # `infer` alone will not make this model usable with HMC.
    isSpeciation = pyro.sample(
        f"{prefix}/isSpeciation",
        dist.Bernoulli(la / (la + mu)),
        infer={'enumerate': 'sequential'},
    )
    if isSpeciation:
        remaining = time - waitingTime
        # Simulate both daughters before combining the results, so both
        # subtrees always contribute their sample sites. Writing
        # `gosExtince(x) and gosExtince(y)` would skip `y` whenever x is False.
        x = gosExtince(f"{prefix}/x", remaining, la, mu)
        y = gosExtince(f"{prefix}/y", remaining, la, mu)
        return x and y
    return True

def model(time):
    """Pyro model: sample two rates and constrain one simulated lineage.

    Samples ``la`` (site "lamda") and ``mu`` (site "mu") from Gamma(1, 1)
    priors. It then runs :func:`gosExtince` over ``time`` and adds an "obs"
    log-factor that keeps only executions in which ``gosExtince`` returned a
    truthy value (presumably: the lineage went extinct). All other
    executions receive a log-weight of -inf.

    Args:
        time: Length of the interval the lineage is simulated over.
    """
    la = pyro.sample("lamda", dist.Gamma(1, 1))
    mu = pyro.sample("mu", dist.Gamma(1, 1))
    goes_extinct = gosExtince(".", time, la, mu)
    # pyro.factor is documented to take a tensor log-factor. Build both
    # branches as shape-(1,) tensors so the "obs" site has the same type and
    # shape in every execution. The constant log-weight of 1 on the accepted
    # branch only shifts the unnormalized density and does not change the
    # posterior.
    if goes_extinct:
        log_factor = torch.ones(1)
    else:
        log_factor = torch.full((1,), -torch.inf)
    pyro.factor("obs", log_factor)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', choices=["HMC", "MAPPL"],  required=True)
    parser.add_argument('--time', type=float, required=True)
    parser.add_argument('--progress_bar', action='store_true')
    parser.add_argument('--num_chains', type=int, required=True)
    parser.add_argument('--warmup_steps', type=int, required=True)
    parser.add_argument('--num_samples', type=int, required=True)
    args = parser.parse_args()

    if args.config == "HMC":
        nuts_kernel = NUTS(model)

    mcmc = MCMC(
        disable_progbar=not args.progress_bar

if __name__ == '__main__':