I would like to reproduce the example, , fitting a target distribution using HMC.
Here is the code,
import numpy as np
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS, Predictive, SVI, TraceMeanField_ELBO
def model():
mu_obs = torch.tensor([0., 5.])
sig_obs = torch.tensor([[2., 0.], [0., 3.]])
samples = [pyro.sample('samples', dist.MultivariateNormal(mu_obs, sig_obs)) for _ in range(250)]
return samples
nuts_kernel = NUTS(model, jit_compile=True, step_size=1e-5)
MCMC(
nuts_kernel,
num_samples= 200,
warmup_steps= 100,
num_chains= 1,
).run()
But I got the following error message
RuntimeError: Multiple sample sites named 'samples'
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_7832/676237614.py in <module>
4 num_samples= 200,
5 warmup_steps= 100,
----> 6 num_chains= 1,
7 ).run()
------skip some ----
~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\mcmc\hmc.py in setup(self, warmup_steps, *args, **kwargs)
323 self._warmup_steps = warmup_steps
324 if self.model is not None:
--> 325 self._initialize_model_properties(args, kwargs)
326 if self.initial_params:
327 z = {k: v.detach() for k, v in self.initial_params.items()}
~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\mcmc\hmc.py in _initialize_model_properties(self, model_args, model_kwargs)
267 skip_jit_warnings=self._ignore_jit_warnings,
268 init_strategy=self._init_strategy,
--> 269 initial_params=self._initial_params,
270 )
271 self.potential_fn = potential_fn
~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\mcmc\util.py in initialize_model(model, model_args, model_kwargs, transforms, max_plate_nesting, jit_compile, jit_options, skip_jit_warnings, num_chains, init_strategy, initial_params)
425 automatic_transform_enabled = False
426 if max_plate_nesting is None:
--> 427 max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
428 # Wrap model in `poutine.enum` to enumerate over discrete latent sites.
429 # No-op if model does not have any discrete latents.
~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\mcmc\util.py in _guess_max_plate_nesting(model, args, kwargs)
249 """
250 with poutine.block():
--> 251 model_trace = poutine.trace(model).get_trace(*args, **kwargs)
252 sites = [site for site in model_trace.nodes.values() if site["type"] == "sample"]
253
..... skip some
~\AppData\Roaming\Python\Python37\site-packages\pyro\poutine\trace_messenger.py in _pyro_post_sample(self, msg)
133 assert not msg["is_observed"]
134 return
--> 135 self.trace.add_node(msg["name"], **msg.copy())
136
137 def _pyro_post_param(self, msg):
~\AppData\Roaming\Python\Python37\site-packages\pyro\poutine\trace_struct.py in add_node(self, site_name, **kwargs)
114 # Cannot sample after a previous sample statement.
115 raise RuntimeError(
--> 116 "Multiple {} sites named '{}'".format(kwargs["type"], site_name)
117 )
118
RuntimeError: Multiple sample sites named 'samples'
Trace Shapes:
Param Sites:
Sample Sites:
samples dist | 2
value | 2