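Regarding the “Dirichlet process mixture model” tutorial:
This section of the tutorial no longer appears to be reproducible. Training fails on the very first iteration (0/1500) with a `ValueError` raised while computing the log-probability of the Poisson-distributed `obs` site, because the observed values are outside that distribution's support: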
0%| | 0/1500 [00:00<?, ?it/s]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
~usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_struct.py in compute_log_prob(self, site_filter)
215 try:
--> 216 log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
217 except ValueError:
~usr/local/lib/python3.6/dist-packages/torch/distributions/poisson.py in log_prob(self, value)
60 if self._validate_args:
---> 61 self._validate_sample(value)
62 rate, value = broadcast_all(self.rate, value)
~usr/local/lib/python3.6/dist-packages/torch/distributions/distribution.py in _validate_sample(self, value)
252 if not self.support.check(value).all():
--> 253 raise ValueError('The value argument must be within the support')
254
ValueError: The value argument must be within the support
During handling of the above exception, another exception occurred:
ValueError Traceback (most recent call last)
<ipython-input-9-4e79a15ea6fc> in <module>
32 losses = []
33
---> 34 train(n_iter)
35
36 samples = torch.arange(0, 300).type(torch.float)
<ipython-input-6-c5d9a36e9211> in train(num_iterations)
7 pyro.clear_param_store()
8 for j in tqdm(range(num_iterations)):
----> 9 loss = svi.step(data)
10 losses.append(loss)
11
~usr/local/lib/python3.6/dist-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
126 # get loss and compute gradients
127 with poutine.trace(param_only=True) as param_capture:
--> 128 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
129
130 params = set(site["value"].unconstrained()
~usr/local/lib/python3.6/dist-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
124 loss = 0.0
125 # grab a trace from the generator
--> 126 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
127 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace)
128 loss += loss_particle / self.num_particles
~usr/local/lib/python3.6/dist-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, args, kwargs)
168 else:
169 for i in range(self.num_particles):
--> 170 yield self._get_trace(model, guide, args, kwargs)
~usr/local/lib/python3.6/dist-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, args, kwargs)
51 """
52 model_trace, guide_trace = get_importance_trace(
---> 53 "flat", self.max_plate_nesting, model, guide, args, kwargs)
54 if is_validation_enabled():
55 check_if_enumerated(guide_trace)
~usr/local/lib/python3.6/dist-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
53 model_trace = prune_subsample_sites(model_trace)
54
---> 55 model_trace.compute_log_prob()
56 guide_trace.compute_score_parts()
57 if is_validation_enabled():
~usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_struct.py in compute_log_prob(self, site_filter)
219 shapes = self.format_shapes(last_site=site["name"])
220 raise ValueError("Error while computing log_prob at site '{}':\n{}\n{}"
--> 221 .format(name, exc_value, shapes)).with_traceback(traceback)
222 site["unscaled_log_prob"] = log_p
223 log_p = scale_and_mask(log_p, site["scale"], site["mask"])
~usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_struct.py in compute_log_prob(self, site_filter)
214 if "log_prob" not in site:
215 try:
--> 216 log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
217 except ValueError:
218 _, exc_value, traceback = sys.exc_info()
~usr/local/lib/python3.6/dist-packages/torch/distributions/poisson.py in log_prob(self, value)
59 def log_prob(self, value):
60 if self._validate_args:
---> 61 self._validate_sample(value)
62 rate, value = broadcast_all(self.rate, value)
63 return (rate.log() * value) - rate - (value + 1).lgamma()
~usr/local/lib/python3.6/dist-packages/torch/distributions/distribution.py in _validate_sample(self, value)
251
252 if not self.support.check(value).all():
--> 253 raise ValueError('The value argument must be within the support')
254
255 def _get_checked_instance(self, cls, _instance=None):
ValueError: Error while computing log_prob at site 'obs':
The value argument must be within the support
Trace Shapes:
Param Sites:
Sample Sites:
beta dist 19 |
value 19 |
log_prob 19 |
lambda dist 20 |
value 20 |
log_prob 20 |
z dist 320 |
value 320 |
log_prob 320 |
obs dist 320 |
value 320 |