I am trying to implement a CRBD model, which contains both continuous and discrete random variables, described in this paper and apply HMC to it.
But I got a runtime error that says
ValueError: Continuous inference cannot handle discrete sample site './isSpeciation'. Consider enumerating that variable as documented in https://pyro.ai/examples/enumeration.html . If you are already enumerating, take care to hide this site when constructing an autoguide, e.g. guide = AutoNormal(poutine.block(model, hide=['./isSpeciation'])).
Is this I use Pyro’s HMC in the wrong way, or it’s Pyro’s HMC not compatible with {'enumerate': 'sequential'}
?
What should I do to apply HMC together with {'enumerate': 'sequential'}
to this CRBD model?
Thanks for helping.
Code:
import argparse
import pyro.distributions as dist
import pyro
import torch
from pyro.infer import MCMC, NUTS
import sys
sys. setrecursionlimit(32767)
def gosExtince(prefix, time, la, mu):
waitingTime = pyro.sample(f"{prefix}/waitingTime", dist.Exponential(la))
if waitingTime > time:
b_waitingTime = False
else:
isSpeciation = pyro.sample(f"{prefix}/isSpeciation", dist.Bernoulli(la / (la + mu)), infer={'enumerate': 'sequential'})
# ValueError: Continuous inference cannot handle discrete sample site './isSpeciation'. Consider enumerating that variable as documented in https://pyro.ai/examples/enumeration.html . If you are already enumerating, take care to hide this site when constructing an autoguide, e.g. guide = AutoNormal(poutine.block(model, hide=['./isSpeciation'])).
if isSpeciation: # https://pyro.ai/examples/enumeration.html
x = gosExtince(f"{prefix}/x", time - waitingTime, la, mu)
y = gosExtince(f"{prefix}/y", time - waitingTime, la, mu)
b_isSpeciation = x and y
else:
b_isSpeciation = True
b_waitingTime = b_isSpeciation
return b_waitingTime
def model(time):
la = pyro.sample("lamda", dist.Gamma(1, 1))
mu = pyro.sample("mu", dist.Gamma(1, 1))
obs = gosExtince(".", time, la, mu)
pyro.factor("obs", torch.ones(1) if obs else -torch.inf)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--config', choices=["HMC", "MAPPL"], required=True)
parser.add_argument('--time', type=float, required=True)
parser.add_argument('--progress_bar', action='store_true')
parser.add_argument('--num_chains', type=int, required=True)
parser.add_argument('--warmup_steps', type=int, required=True)
parser.add_argument('--num_samples', type=int, required=True)
args = parser.parse_args()
print(args)
if args.config == "HMC":
nuts_kernel = NUTS(model)
mcmc = MCMC(
nuts_kernel,
warmup_steps=args.warmup_steps,
num_samples=args.num_samples,
num_chains=args.num_chains,
disable_progbar=not args.progress_bar
)
mcmc.run(
args.time
)
mcmc.print_summary()
if __name__ == '__main__':
main()