RuntimeError: Multiple sample sites named 'samples' for a code segment

I would like to reproduce the example of fitting a target distribution using HMC.

Here is the code:

import numpy as np
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS, Predictive, SVI, TraceMeanField_ELBO


def model():
    mu_obs = torch.tensor([0., 5.])
    sig_obs = torch.tensor([[2., 0.], [0., 3.]])
    samples = [pyro.sample('samples', dist.MultivariateNormal(mu_obs, sig_obs)) for _ in range(250)]
    return samples

nuts_kernel = NUTS(model, jit_compile=True, step_size=1e-5)
MCMC(
        nuts_kernel,
        num_samples= 200,
        warmup_steps= 100,
        num_chains= 1,
    ).run()

But I got the following error message:

RuntimeError: Multiple sample sites named 'samples'

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_7832/676237614.py in <module>
      4         num_samples= 200,
      5         warmup_steps= 100,
----> 6         num_chains= 1,
      7     ).run()

[... some frames skipped ...]

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\mcmc\hmc.py in setup(self, warmup_steps, *args, **kwargs)
    323         self._warmup_steps = warmup_steps
    324         if self.model is not None:
--> 325             self._initialize_model_properties(args, kwargs)
    326         if self.initial_params:
    327             z = {k: v.detach() for k, v in self.initial_params.items()}

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\mcmc\hmc.py in _initialize_model_properties(self, model_args, model_kwargs)
    267             skip_jit_warnings=self._ignore_jit_warnings,
    268             init_strategy=self._init_strategy,
--> 269             initial_params=self._initial_params,
    270         )
    271         self.potential_fn = potential_fn

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\mcmc\util.py in initialize_model(model, model_args, model_kwargs, transforms, max_plate_nesting, jit_compile, jit_options, skip_jit_warnings, num_chains, init_strategy, initial_params)
    425         automatic_transform_enabled = False
    426     if max_plate_nesting is None:
--> 427         max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
    428     # Wrap model in `poutine.enum` to enumerate over discrete latent sites.
    429     # No-op if model does not have any discrete latents.

~\AppData\Roaming\Python\Python37\site-packages\pyro\infer\mcmc\util.py in _guess_max_plate_nesting(model, args, kwargs)
    249     """
    250     with poutine.block():
--> 251         model_trace = poutine.trace(model).get_trace(*args, **kwargs)
    252     sites = [site for site in model_trace.nodes.values() if site["type"] == "sample"]
    253 

[... some frames skipped ...]


~\AppData\Roaming\Python\Python37\site-packages\pyro\poutine\trace_messenger.py in _pyro_post_sample(self, msg)
    133             assert not msg["is_observed"]
    134             return
--> 135         self.trace.add_node(msg["name"], **msg.copy())
    136 
    137     def _pyro_post_param(self, msg):

~\AppData\Roaming\Python\Python37\site-packages\pyro\poutine\trace_struct.py in add_node(self, site_name, **kwargs)
    114                 # Cannot sample after a previous sample statement.
    115                 raise RuntimeError(
--> 116                     "Multiple {} sites named '{}'".format(kwargs["type"], site_name)
    117                 )
    118 

RuntimeError: Multiple sample sites named 'samples'
Trace Shapes:    
 Param Sites:    
Sample Sites:    
 samples dist | 2
        value | 2

Each pyro.sample site in a Pyro model needs to have a unique name. You could fix your model by suffixing each "samples" site with a counter:

...
samples = [pyro.sample(f'samples_{i}', dist.MultivariateNormal(mu_obs, sig_obs)) for i in range(250)]
...

@eb8680_2, I really appreciate your response; it works. But I have several follow-up questions:

I found that if I set jit_compile=True, the code just keeps running without making any progress; if I set jit_compile=False instead, it runs as expected. Why can't I set jit_compile to True? Thanks.

This is helpful, thanks for sharing.