Error while using log_density

I am using the following code for model and finding the log_density:

def model(county_idx = county, log_radon = log_radon, floor = floor_measure, J=J, N=N):
    sigma_y = numpyro.sample("sigma_y", dist.LeftTruncatedDistribution(dist.Normal(0.0, 1.0)))
    sigma_beta = numpyro.sample("sigma_beta", dist.LeftTruncatedDistribution(dist.Normal(0.0, 1.0)))
    sigma_alpha = numpyro.sample("sigma_alpha", dist.LeftTruncatedDistribution(dist.Normal(0.0, 1.0)))

    mu_alpha = numpyro.sample("mu_alpha", dist.Normal(0.0, 10))
    mu_beta = numpyro.sample("mu_beta", dist.Normal(0.0, 10))

    with numpyro.plate("J", J):
        alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha))
        beta = numpyro.sample("beta", dist.Normal(mu_beta, sigma_beta))

    mu = alpha[county_idx] + beta[county_idx] * floor
    
    with numpyro.plate("N", N):
        numpyro.sample("obs", dist.Normal(mu, sigma_y), obs=log_radon)    
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
mcmc.run(rng_key)
from numpyro.contrib.funsor.infer_util import log_density
ld_scan, trace_scan = log_density(model = model, params = mcmc.get_samples(), model_args = (county, log_radon, floor_measure, J, N), model_kwargs = {})

I am getting the following error upon running the last cell containing the log_density code:

File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py:274, in log_density(model, model_args, model_kwargs, params)
    253 def log_density(model, model_args, model_kwargs, params):
    254     """
    255     Similar to :func:`numpyro.infer.util.log_density` but works for models
    256     with discrete latent variables. Internally, this uses :mod:`funsor`
   (...)
    272     :return: log of joint density and a corresponding model trace
    273     """
--> 274     result, model_trace, _ = _enum_log_density(
    275         model, model_args, model_kwargs, params, funsor.ops.logaddexp, funsor.ops.add
    276     )
    277     return result.data, model_trace

File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py:159, in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
    157 model = substitute(model, data=params)
    158 with plate_to_enum_plate():
--> 159     model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    160 log_factors = []
    161 time_to_factors = defaultdict(list)  # log prob factors

File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     """
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     """
--> 171     self(*args, **kwargs)
    172     return self.trace

File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

/Users/madhav/Desktop/Aalto/parameteric-reparameterisation/parametric-reparametrizations/quarto/notebooks/python/Radon.ipynb Cell 14 line 9
      7     mu_beta = numpyro.sample("mu_beta", dist.Normal(0.0, 10))
      8 # 
----> 9     with numpyro.plate("J", J):
     10         alpha = numpyro.sample("alpha", dist.Normal(mu_alpha, sigma_alpha))
     11         beta = numpyro.sample("beta", dist.Normal(mu_beta, sigma_beta))

File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/contrib/funsor/enum_messenger.py:506, in plate.__enter__(self)
    500 super().__enter__()  # do this first to take care of globals recycling
    501 name_to_dim = (
    502     OrderedDict([(self.name, self.dim)])
    503     if self.dim is not None
    504     else OrderedDict()
    505 )
--> 506 indices = to_data(
    507     self._indices, name_to_dim=name_to_dim, dim_type=DimType.VISIBLE
    508 )
    509 # extract the dimension allocated by to_data to match plate's current behavior
    510 self.dim, self.indices = -len(indices.shape), indices.squeeze()

File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/contrib/funsor/enum_messenger.py:715, in to_data(x, name_to_dim, dim_type)
    702 name_to_dim = OrderedDict() if name_to_dim is None else name_to_dim
    704 initial_msg = {
    705     "type": "to_data",
    706     "fn": lambda x, name_to_dim, dim_type: funsor.to_data(
   (...)
    712     "mask": None,
    713 }
--> 715 msg = apply_stack(initial_msg)
    716 return msg["value"]

File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/primitives.py:59, in apply_stack(msg)
     55 # A Messenger that sets msg["stop"] == True also prevents application
     56 # of postprocess_message by Messengers above it on the stack
     57 # via the pointer variable from the process_message loop
     58 for handler in _PYRO_STACK[-pointer - 1 :]:
---> 59     handler.postprocess_message(msg)
     60 return msg

File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/contrib/funsor/enum_messenger.py:534, in plate.postprocess_message(self, msg)
    532 def postprocess_message(self, msg):
    533     if msg["type"] in ["to_funsor", "to_data"]:
--> 534         return super().postprocess_message(msg)
    535     # NB: copied literally from original plate messenger, with self._indices is replaced
    536     # by self.indices
    537     if msg["type"] in ("subsample", "param") and self.dim is not None:

File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/contrib/funsor/enum_messenger.py:417, in GlobalNamedMessenger.postprocess_message(self, msg)
    415     self._pyro_post_to_funsor(msg)
    416 elif msg["type"] == "to_data":
--> 417     self._pyro_post_to_data(msg)

File ~/Library/Caches/pypoetry/virtualenvs/parametric-reparametrizations-NozDmb-_-py3.10/lib/python3.10/site-packages/numpyro/contrib/funsor/enum_messenger.py:430, in GlobalNamedMessenger._pyro_post_to_data(self, msg)
    427 if msg["kwargs"]["dim_type"] in (DimType.GLOBAL, DimType.VISIBLE):
    428     for name in msg["args"][0].inputs:
    429         self._saved_globals += (
--> 430             (name, _DIM_STACK.global_frame.name_to_dim[name]),
    431         )

KeyError: 'J'

The block of code using mcmc.run is working fine but I am unable to figure out the error in the code of log_density. Please help me resolve the error.

Also, another question that I have is that should I pass the returned constrained parameters or the unconstrained parameters while computing the log density?

Without context I had to make up some values for your model inputs, but this seems to work fine for me if I use the log density util: numpyro.infer.util.log_density instead of the one you use, though I’ll admit a lack of familiarity with funsors.

Is this what you’re after?