Hi Everyone,
I’m new to prob. programming and Pyro. I’m working on a mini demo,
- there’re several batches of eggs, the number of eggs in each batch follows a Poisson distribution.
- the probability for each egg hatched successfully is the same (a Bernoulli trial)
- therefore, the number of eggs hatched successfully in each batch follows a Binomial distribution.
I’m trying to infer the params for Poisson and Bernoulli priors. Here’s the code
import pyro
import pyro.distributions as dist
import torch
import torch.distributions.constraints as constraints
from pyro.infer import SVI, Trace_ELBO, EmpiricalMarginal, TraceEnum_ELBO, config_enumerate
from pyro.infer.mcmc import MCMC, NUTS
from pyro.optim import Adam, SGD
import pyro.contrib.autoguide as autoguide
# number of eggs hatched successfully per batch
items = torch.tensor([6, 5, 4, 8, 7, 2, 10, 12], dtype=torch.float)
def model(items):
# same hatch rate for every batch
hatch_rate = pyro.sample('hr', dist.Beta(1, 1))
# eggs laid per batch
with pyro.plate('eggs_laid', len(items)):
eggs = pyro.sample('e', dist.Poisson(10.))
# eggs hatched per batch
with pyro.plate('eggs_hatched', len(items)):
pyro.sample(
'obs',
dist.Binomial(
total_count=eggs,
probs=torch.ones(len(items)) * hatch_rate,
),
obs=items,
)
def guide(items):
# params for survival rate
hr_alpha = pyro.param(
'hr_alpha',
torch.tensor(1.),
constraint=constraints.positive,
)
hr_beta = pyro.param(
'hr_beta',
torch.tensor(1.),
constraint=constraints.positive,
)
hatch_rate = pyro.sample('hr', dist.Beta(hr_alpha, hr_beta))
# params for eggs per week
e_lambda = pyro.param(
'e_lambda', torch.tensor(10.),
constraint=constraints.positive,
)
with pyro.plate('eggs_laid', len(items)):
pyro.sample('e', dist.Poisson(e_lambda))
# training
num_epochs = 100
adam_params = {"lr": 0.05, "betas": (0.90, 0.999)}
optimizer = Adam(adam_params)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
pyro.clear_param_store()
for j in range(num_epochs):
loss = svi.step(items)
if j % 50 == 0:
print("[epoch %04d] loss: %.4f" % (j + 1, loss))
The error is
[epoch 0001] loss: inf
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-67-42a6a4868962> in <module>()
6 pyro.clear_param_store()
7 for j in range(num_epochs):
----> 8 loss = svi.step(items)
9 if j % 50 == 0:
10 print("[epoch %04d] loss: %.4f" % (j + 1, loss))
13 frames
/usr/local/lib/python3.6/dist-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
126 # get loss and compute gradients
127 with poutine.trace(param_only=True) as param_capture:
--> 128 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
129
130 params = set(site["value"].unconstrained()
/usr/local/lib/python3.6/dist-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
124 loss = 0.0
125 # grab a trace from the generator
--> 126 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
127 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace)
128 loss += loss_particle / self.num_particles
/usr/local/lib/python3.6/dist-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, args, kwargs)
168 else:
169 for i in range(self.num_particles):
--> 170 yield self._get_trace(model, guide, args, kwargs)
/usr/local/lib/python3.6/dist-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, args, kwargs)
51 """
52 model_trace, guide_trace = get_importance_trace(
---> 53 "flat", self.max_plate_nesting, model, guide, args, kwargs)
54 if is_validation_enabled():
55 check_if_enumerated(guide_trace)
/usr/local/lib/python3.6/dist-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
42 and the model that is run against it.
43 """
---> 44 guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs)
45 if detach:
46 guide_trace.detach_()
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
185 Calls this poutine and returns its trace instead of the function's return value.
186 """
--> 187 self(*args, **kwargs)
188 return self.msngr.get_trace()
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
169 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
170 exc = exc.with_traceback(traceback)
--> 171 raise exc from None
172 self.msngr.trace.add_node("_RETURN", name="_RETURN", type="return", value=ret)
173 return ret
/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
163 args=args, kwargs=kwargs)
164 try:
--> 165 ret = self.fn(*args, **kwargs)
166 except (ValueError, RuntimeError):
167 exc_type, exc_value, traceback = sys.exc_info()
<ipython-input-66-3a603925cfa8> in guide(items)
36
37 with pyro.plate('eggs_laid', len(items)):
---> 38 pyro.sample('e', dist.Poisson(e_lambda))
/usr/local/lib/python3.6/dist-packages/pyro/primitives.py in sample(name, fn, *args, **kwargs)
111 msg["is_observed"] = True
112 # apply the stack and return its return value
--> 113 apply_stack(msg)
114 return msg["value"]
115
/usr/local/lib/python3.6/dist-packages/pyro/poutine/runtime.py in apply_stack(initial_msg)
196 break
197
--> 198 default_process_message(msg)
199
200 for frame in stack[-pointer:]:
/usr/local/lib/python3.6/dist-packages/pyro/poutine/runtime.py in default_process_message(msg)
157 return msg
158
--> 159 msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])
160
161 # after fn has been called, update msg to prevent it from being called again.
/usr/local/lib/python3.6/dist-packages/pyro/distributions/torch_distribution.py in __call__(self, sample_shape)
43 :rtype: torch.Tensor
44 """
---> 45 return self.rsample(sample_shape) if self.has_rsample else self.sample(sample_shape)
46
47 @property
/usr/local/lib/python3.6/dist-packages/torch/distributions/poisson.py in sample(self, sample_shape)
55 shape = self._extended_shape(sample_shape)
56 with torch.no_grad():
---> 57 return torch.poisson(self.rate.expand(shape))
58
59 def log_prob(self, value):
RuntimeError: invalid Poisson rate, expected rate to be non-negative
Trace Shapes:
Param Sites:
hr_alpha
hr_beta
e_lambda
Sample Sites:
hr dist |
value |
eggs_laid dist |
value 8 |
I’m not sure what went wrong. Why the ‘invalid Poisson rate’ since I’ve set the positive constraint? Any advice?
Thanks!