I am using the following code for model and finding the log_density:
def model(county_idx = county, log_radon = log_radon, floor = floor_measure, J=J, N=N):
sigma_y = numpyro.sample("sigma_y", dist.LeftTruncatedDistribution(dist.Normal(0.0, 1.0)))
sigma_beta = numpyro.sample("sigma_beta", dist.LeftTruncatedDistribution(dist.Normal(0.0, 1.0)))
sigma_alpha = numpyro.sample("sigma_alpha", dist.LeftTruncatedDistribution(dist.Normal(0.0, 1.0)))
mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0.0, 10))
mu_beta = numpyro.sample("mu_beta", dist.Normal(0.0, 10))
with numpyro.plate("J", J):
alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha))
beta = numpyro.sample("beta", dist.Normal(mu_beta, sigma_beta))
mu = alpha[county_idx] + beta[county_idx] * floor
with numpyro.plate("N", N):
numpyro.sample("obs", dist.Normal(mu, sigma_y), obs=log_radon)
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
mcmc.run(rng_key)
from numpyro.contrib.funsor.infer_util import log_density
ld_scan, trace_scan = log_density(model = model, params = mcmc.get_samples(), model_args = (county, log_radon, floor_measure, J, N), model_kwargs = {})
I am getting the following error upon running the last cell containing the log_density code:
File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py:274, in log_density(model, model_args, model_kwargs, params)
253 def log_density(model, model_args, model_kwargs, params):
254 """
255 Similar to :func:`numpyro.infer.util.log_density` but works for models
256 with discrete latent variables. Internally, this uses :mod:`funsor`
(...)
272 :return: log of joint density and a corresponding model trace
273 """
--> 274 result, model_trace, _ = _enum_log_density(
275 model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
276 )
277 return result.data, model_trace
File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py:159, in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
157 model = substitute(model, data=params)
158 with plate_to_enum_plate():
--> 159 model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
160 log_factors = []
161 time_to_factors = defaultdict(list) # log prob factors
File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
163 def get_trace(self, *args, **kwargs):
164 """
165 Run the wrapped callable and return the recorded trace.
166
(...)
169 :return: `OrderedDict` containing the execution trace.
170 """
--> 171 self(*args, **kwargs)
172 return self.trace
File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
103 return self
104 with self:
--> 105 return self.fn(*args, **kwargs)
/Users/madhav/Desktop/Aalto/parameteric-reparameterisation/parametric-reparametrizations/quarto/notebooks/python/Radon.ipynb Cell 14 line 9
7 mu_beta = numpyro.sample("mu_beta", dist.Normal(0.0, 10))
8 #
----> 9 with numpyro.plate("J", J):
10 alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha))
11 beta = numpyro.sample("beta", dist.Normal(mu_beta, sigma_beta))
File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/contrib/funsor/enum_messenger.py:506, in plate.__enter__(self)
500 super().__enter__() # do this first to take care of globals recycling
501 name_to_dim = (
502 OrderedDict([(self.name, self.dim)])
503 if self.dim is not None
504 else OrderedDict()
505 )
--> 506 indices = to_data(
507 self._indices, name_to_dim=name_to_dim, dim_type=DimType.VISIBLE
508 )
509 # extract the dimension allocated by to_data to match plate's current behavior
510 self.dim, self.indices = -len(indices.shape), indices.squeeze()
File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/contrib/funsor/enum_messenger.py:715, in to_data(x, name_to_dim, dim_type)
702 name_to_dim = OrderedDict() if name_to_dim is None else name_to_dim
704 initial_msg = {
705 "type": "to_data",
706 "fn": lambda x, name_to_dim, dim_type: funsor.to_data(
(...)
712 "mask": None,
713 }
--> 715 msg = apply_stack(initial_msg)
716 return msg["value"]
File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/primitives.py:59, in apply_stack(msg)
55 # A Messenger that sets msg["stop"] == True also prevents application
56 # of postprocess_message by Messengers above it on the stack
57 # via the pointer variable from the process_message loop
58 for handler in _PYRO_STACK[-pointer - 1 :]:
---> 59 handler.postprocess_message(msg)
60 return msg
File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/contrib/funsor/enum_messenger.py:534, in plate.postprocess_message(self, msg)
532 def postprocess_message(self, msg):
533 if msg["type"] in ["to_funsor", "to_data"]:
--> 534 return super().postprocess_message(msg)
535 # NB: copied literally from original plate messenger, with self._indices is replaced
536 # by self.indices
537 if msg["type"] in ("subsample", "param") and self.dim is not None:
File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/contrib/funsor/enum_messenger.py:417, in GlobalNamedMessenger.postprocess_message(self, msg)
415 self._pyro_post_to_funsor(msg)
416 elif msg["type"] == "to_data":
--> 417 self._pyro_post_to_data(msg)
File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/contrib/funsor/enum_messenger.py:430, in GlobalNamedMessenger._pyro_post_to_data(self, msg)
427 if msg["kwargs"]["dim_type"] in (DimType.GLOBAL, DimType.VISIBLE):
428 for name in msg["args"][0].inputs:
429 self._saved_globals += (
--> 430 (name, _DIM_STACK.global_frame.name_to_dim[name]),
431 )
KeyError: 'J'
The block of code using mcmc.run is working fine but I am unable to figure out the error in the code of log_density. Please help me resolve the error.