Hi, I am new to Pyro.
I want to accumulate Poisson random numbers and make an inference based on observations with a normal distribution.
However, I got the error `KeyError: <class 'torch.distributions.constraints._IntegerGreaterThan'>`.
Here is my code
import pyro
import torch
from pyro.distributions import Uniform, Poisson, Normal
from pyro.infer.mcmc import MCMC
from pyro.infer.mcmc.nuts import NUTS
def example(niters, n0):
    """Generative model for a cumulative count driven by Poisson increments.

    A single rate factor ``r0`` is drawn from ``Uniform(0, 2)``.  Then, for
    each step ``i`` in ``1 .. niters - 1``:

    * an increment ``new_<i>`` is drawn from ``Poisson(total * r0)``, where
      ``total`` is the running cumulative value before this step;
    * the increment is added to the running total;
    * a noisy reading ``cum_<i>`` is drawn from ``Normal(total, 0.01)``.

    Args:
        niters: number of time points, including the initial one.
        n0: initial cumulative value (used as the first entry of the result).

    Returns:
        A list of length ``niters`` holding the noise-free cumulative values,
        starting with ``n0``.

    NOTE(review): the ``new_<i>`` sites are Poisson-distributed.  When they
    are left unobserved, presumably these are the sites for which NUTS cannot
    build an unconstraining transform (integer-valued support) -- confirm
    against the Pyro documentation on discrete latent variables.
    """
    growth = pyro.sample("r0", Uniform(0.0, 2.0))

    total = n0
    history = [n0]  # noise-free cumulative values, one per time point

    for step in range(1, niters):
        # Draw this step's increment; its rate scales with the current total.
        increment = pyro.sample("new_%d" % step, Poisson(total * growth))

        # Deliberately out-of-place: an in-place add on a tensor would also
        # modify n0 and the entries already stored in `history`.
        total = total + increment

        # Noisy reading of the cumulative value.  The sampled value itself is
        # not used further; the site exists so it can be conditioned on.
        pyro.sample("cum_%d" % step, Normal(total, 0.01))

        history.append(total)

    return history
# Observed data: cumulative values at each time point.
niters = 2
cumms = torch.tensor([80., 100])
n0 = cumms[0]  # the first cumulative value seeds the model

# Observations keyed by sample-site name, for every step after the first.
# The error disappears if the key prefix "cum_%d" is changed to "new_%d".
data = {"cum_%d" % i: cumms[i] for i in range(1, niters)}

# Condition the model on the observations and sample it with NUTS.
conditioned_example = pyro.condition(example, data=data)
hmc_kernel = NUTS(conditioned_example, step_size=0.1)
posterior = MCMC(hmc_kernel, num_samples=1000, warmup_steps=50)
posterior.run(niters, n0)
If I change `cum_%d` to `new_%d` in the observation data, the error disappears.
And it seems that I can run the function `example` many times without getting the error.
Here is the complete report on the error:
Warmup: 0%| | 0/1050 [00:00, ?it/s]Traceback (most recent call last):
File "/home/mfkasim/anaconda3/lib/python3.7/site-packages/torch/distributions/constraint_registry.py", line 140, in __call__
factory = self._registry[type(constraint)]
KeyError: <class 'torch.distributions.constraints._IntegerGreaterThan'>
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "test2.py", line 36, in <module>
posterior.run(niters, n0)
File "/home/mfkasim/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 11, in _context_wrap
return fn(*args, **kwargs)
File "/home/mfkasim/anaconda3/lib/python3.7/site-packages/pyro/infer/mcmc/api.py", line 357, in run
for x, chain_id in self.sampler.run(*args, **kwargs):
File "/home/mfkasim/anaconda3/lib/python3.7/site-packages/pyro/infer/mcmc/api.py", line 168, in run
*args, **kwargs):
File "/home/mfkasim/anaconda3/lib/python3.7/site-packages/pyro/infer/mcmc/api.py", line 110, in _gen_samples
kernel.setup(warmup_steps, *args, **kwargs)
File "/home/mfkasim/anaconda3/lib/python3.7/site-packages/pyro/infer/mcmc/hmc.py", line 266, in setup
self._initialize_model_properties(args, kwargs)
File "/home/mfkasim/anaconda3/lib/python3.7/site-packages/pyro/infer/mcmc/hmc.py", line 239, in _initialize_model_properties
skip_jit_warnings=self._ignore_jit_warnings,
File "/home/mfkasim/anaconda3/lib/python3.7/site-packages/pyro/infer/mcmc/util.py", line 387, in initialize_model
transforms[name] = biject_to(node["fn"].support).inv
File "/home/mfkasim/anaconda3/lib/python3.7/site-packages/torch/distributions/constraint_registry.py", line 143, in __call__
'Cannot transform {} constraints'.format(type(constraint).__name__))
NotImplementedError: Cannot transform _IntegerGreaterThan constraints