I have a model that draws the `total_count` of a Binomial at random and then conditions on an observed sample from that Binomial. However, it raises an error whenever the sampled `total_count` ends up smaller than the observed value. Shouldn't such a branch simply be assigned zero probability (log-probability of -inf) so that inference can continue? The stack trace is below. Thanks,
Ravi
/usr/local/lib/python2.7/dist-packages/pyro/infer/svi.pyc in step(self, *args, **kwargs)
73 # get loss and compute gradients
74 with poutine.trace(param_only=True) as param_capture:
---> 75 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
76
77 params = set(site["value"].unconstrained()
/usr/local/lib/python2.7/dist-packages/pyro/infer/trace_elbo.pyc in loss_and_grads(self, model, guide, *args, **kwargs)
105 elbo = 0.0
106 # grab a trace from the generator
---> 107 for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
108 elbo_particle = 0
109 surrogate_elbo_particle = 0
/usr/local/lib/python2.7/dist-packages/pyro/infer/trace_elbo.pyc in _get_traces(self, model, guide, *args, **kwargs)
66 model_trace = prune_subsample_sites(model_trace)
67
---> 68 model_trace.compute_log_prob()
69 guide_trace.compute_score_parts()
70 if is_validation_enabled():
/usr/local/lib/python2.7/dist-packages/pyro/poutine/trace_struct.pyc in compute_log_prob(self, site_filter)
250 except KeyError:
251 args, kwargs = site["args"], site["kwargs"]
---> 252 site_log_p = site["fn"].log_prob(site["value"], *args, **kwargs)
253 site_log_p = scale_tensor(site_log_p, site["scale"])
254 site["log_prob"] = site_log_p
/usr/local/lib/python2.7/dist-packages/pyro/distributions/binomial.pyc in log_prob(self, value)
101 def log_prob(self, value):
102 if self._validate_args:
---> 103 self._validate_sample(value)
104 log_factorial_n = torch.lgamma(self.total_count + 1)
105 log_factorial_k = torch.lgamma(value + 1)
/usr/local/lib/python2.7/dist-packages/torch/distributions/distribution.pyc in _validate_sample(self, value)
219
220 if not self.support.check(value).all():
---> 221 raise ValueError('The value argument must be within the support')
222
223 def __repr__(self):
ValueError: The value argument must be within the support