Mixture model of incident cases of infectious disease (problem with broadcasting)

–Data generation/simulation
Suppose that we observe the number of incident cases of an infectious agent over 33 timepoints (a 1x33 vector). In addition, we observe this infectious agent over 9 seasons. We then have a dataset Y of dimensions 9x33, where each row corresponds to a season and each column to the number of incident cases observed at that timepoint.

We hypothesize that the above dataset depends on the parameters sigma, gamma, N, and lambda. All parameters are fixed except lambda. Let lambs = [2, 1./2]. Then for each season, lamb is assigned the value 2 (i.e. lambs[0]) with probability 0.6 and the value 1./2 (i.e. lambs[1]) with probability 0.4.

Incident cases can be computed from the parameter set theta = (lamb, sigma, gamma, N), and the observed incident cases are then subject to noise.

Below is the code to generate this synthetic data.

        import numpy as np
        import jax.numpy as jnp
        import scipy.integrate
        
        #--SIMULATION OF DATA MATRIX Y
        #--Fixed population size (N)
        N=1000

        #--33 weeks of observation
        timepoints = 33

        #--suppose we observe 9 seasons
        seasons = 9
        
        #-model parameters

        #--parameters that control dynamics
        sigma = 1./5
        gamma = 1./10
        lambs = [2, 1./2]

        #--parameters that control surveillance
        catchment = 1./20

        #--initial conditions
        e0 = 0.01
        i0 = 0.01
        r0 = 0.00
        c0 = i0
        s0 = 1. - (e0+i0+r0+c0)

        init = np.array([s0,e0,i0,r0,c0])

        #--probability of choosing lambs[0] versus lambs[1]
        probs = np.array([0.6,0.4])

        assignments = []
        for n,season in enumerate(range(seasons)):
            #--assign lambs[0] or lambs[1] to season "season"
            assign = np.random.choice([0,1],p=probs)
            lamb = lambs[assign]

            #--form vector of parameters
            theta = np.array([lamb, sigma, gamma, N])
            
            #--integrate model
            #--ode specification
            def seir(y,t,theta):
                s,e,i,r, c = y
                lamb,sigma,gamma,N = theta

                ds = (1/N)*(-1.*lamb*s*(N*i))
                de = (1/N)*(lamb*s*(N*i) - sigma*(N*e))
                di = (1/N)*(sigma*(N*e)  - gamma*(N*i))
                dr = (1/N)*(gamma*(N*i))

                dc = (1/N)*(sigma*(N*e)) 
                return np.array([ds, de, di, dr, dc])

            states = scipy.integrate.odeint( seir, y0 = init, t = np.arange(timepoints), args = (theta,)  )

            #--compute incident cases from cumulative incident cases
            incident_cases = np.append( 0, np.diff(states[:,-1]))

            #--add noise
            noisy_cases = np.random.poisson(incident_cases*N)
            noisy_cases = noisy_cases.reshape(1,-1)

            #--assignments
            assignments.append(assign)
            
            #--append to matrix
            if n==0:
                Y = noisy_cases
            else:
               Y = np.vstack([Y,noisy_cases])
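
A quick sanity check of the simulated data (not part of the original script; it just confirms the shapes described above):

        print(Y.shape)       # (9, 33): seasons x timepoints
        print(assignments)   # component (0 or 1) drawn for each season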

I would like to fit the assumed model to the above data.
The code to fit my model in NumPyro is below.


        import jax
        import jax.numpy as jnp
        from jax.experimental.ode import odeint
        from jax.random import PRNGKey

        import numpyro
        import numpyro.distributions as dist
        from numpyro.infer import MCMC,NUTS

        def model(Y,N,K=2,forecast=0):
            seasons,timepoints = Y.shape

            #--model diffeq
            def seir(y,t,theta):
                s,e,i,r, c = y
                lamb,sigma,gamma,N = theta

                ds = (1/N)*(-1.*lamb*s*(N*i))
                de = (1/N)*(lamb*s*(N*i) - sigma*(N*e))
                di = (1/N)*(sigma*(N*e)  - gamma*(N*i))
                dr = (1/N)*(gamma*(N*i))

                dc = (1/N)*(sigma*(N*e)) 
                return jnp.stack([ds,de,di,dr, dc])

            sigma     = numpyro.sample("sigma"    , dist.Beta(.5,.5))
            gamma     = numpyro.sample("gamma"    , dist.Uniform(1./10,1./2))
            catchment = numpyro.sample("catchment", dist.Beta(1,100))

            mix_weights = numpyro.sample("weight" , dist.Dirichlet(jnp.array([0.5,0.5])) )
            with numpyro.plate("components",K,dim=-1):
                mix_centers    = numpyro.sample("mix_centers", dist.Gamma(2,1) )

            #--loglikelihood
            sigma2 = numpyro.sample("sigma2", dist.HalfCauchy(1.) )

            #--initial conditions
            e0 = numpyro.sample("e0", dist.Beta(1,100))
            i0 = numpyro.sample("i0", dist.Beta(1,100))
            r0 = numpyro.deterministic("r0",0.*e0)
            s0 = numpyro.deterministic("s0",(1.-(e0+i0+r0)))
            c0 = numpyro.deterministic("c0",i0)

            init  = jnp.array([s0,e0,i0,r0, c0])

            mask = ~jnp.isnan(Y) #--in case there are missing values
            with numpyro.plate("seasons", seasons, dim=-1):
                assignment = numpyro.sample("assignment", dist.Categorical(mix_weights))
                lambs = mix_centers[assignment]

                thetas = jnp.hstack([lambs.reshape(-1,1)
                                , jnp.repeat(sigma,seasons).reshape(-1,1)
                                , jnp.repeat(gamma,seasons).reshape(-1,1)
                                , jnp.repeat(jnp.array([N]),seasons).reshape(-1,1)]   )
                
                states = jax.vmap( lambda theta: odeint( seir
                                                         , init
                                                         , jnp.arange(0.,timepoints)
                                                         , theta
                                                         , rtol=1e-6
                                                         , atol=1e-5
                                                         , mxstep=1000) )(thetas)

                inc_cases = jnp.clip( jnp.append( jnp.repeat(i0,seasons).reshape(-1,1), jnp.diff(states[:,:,-1]),1 ),0,jnp.inf)
                inc_cases = numpyro.deterministic("inc_cases",N*inc_cases)

                inc_cases = inc_cases
                obs_cases = inc_cases*catchment

                with numpyro.handlers.mask(mask=mask):
                    numpyro.sample("ll", dist.Normal( obs_cases, sigma2), obs = Y )

        mcmc = MCMC(
            NUTS(model, dense_mass=False)
                      , num_warmup  = 8000
                      , num_samples = 1000
                      , num_chains  = 1
                      , thinning    = 5
        )

        with numpyro.handlers.seed(rng_seed=1):
            trace = numpyro.handlers.trace(model).get_trace(Y=Y,N=N)
        print(numpyro.util.format_shapes(trace))

        mcmc.run(PRNGKey(20200320)                       #--seed
                 , Y = Y
                 , N = N
                 , forecast = 4
        )
        mcmc.print_summary()

        samples = mcmc.get_samples()

The error message that I cannot solve is

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
File /usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py:148, in broadcast_shapes(*shapes)
    147 try:
--> 148   return _broadcast_shapes_cached(*shapes)
    149 except:

File /usr/local/lib/python3.10/site-packages/jax/_src/util.py:263, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
    262 else:
--> 263   return cached(config._trace_context(), *args, **kwargs)

File /usr/local/lib/python3.10/site-packages/jax/_src/util.py:256, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
    254 @functools.lru_cache(max_size)
    255 def cached(_, *args, **kwargs):
--> 256   return f(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py:154, in _broadcast_shapes_cached(*shapes)
    152 @cache()
    153 def _broadcast_shapes_cached(*shapes: tuple[int, ...]) -> tuple[int, ...]:
--> 154   return _broadcast_shapes_uncached(*shapes)

File /usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py:170, in _broadcast_shapes_uncached(*shapes)
    169 if result_shape is None:
--> 170   raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
    171 return result_shape

ValueError: Incompatible shapes for broadcasting: shapes=[(9,), (9, 33)]

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[9], line 2
      1 with numpyro.handlers.seed(rng_seed=1):
----> 2     trace = numpyro.handlers.trace(model).get_trace(Y=Y,N=N)
      3 print(numpyro.util.format_shapes(trace))
      5 mcmc.run(PRNGKey(20200320)                       #--seed
      6          , Y = Y
      7          , N = N
      8          , forecast = 4
      9 )

File /usr/local/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     """
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     """
--> 171     self(*args, **kwargs)
    172     return self.trace

File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

Cell In[8], line 70, in model(Y, N, K, forecast)
     67 obs_cases = inc_cases*catchment
     69 with numpyro.handlers.mask(mask=mask):
---> 70     numpyro.sample("ll", dist.Normal( obs_cases, sigma2), obs = Y )

File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
    207 initial_msg = {
    208     "type": "sample",
    209     "name": name,
   (...)
    218     "infer": {} if infer is None else infer,
    219 }
    221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
    223 return msg["value"]

File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
     45 pointer = 0
     46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47     handler.process_message(msg)
     48     # When a Messenger sets the "stop" field of a message,
     49     # it prevents any Messengers above it on the stack from being applied.
     50     if msg.get("stop"):

File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:546, in plate.process_message(self, msg)
    544 overlap_idx = max(len(expected_shape) - len(dist_batch_shape), 0)
    545 trailing_shape = expected_shape[overlap_idx:]
--> 546 broadcast_shape = lax.broadcast_shapes(
    547     trailing_shape, tuple(dist_batch_shape)
    548 )
    549 batch_shape = expected_shape[:overlap_idx] + broadcast_shape
    550 msg["fn"] = msg["fn"].expand(batch_shape)

File /usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py:150, in broadcast_shapes(*shapes)
    148   return _broadcast_shapes_cached(*shapes)
    149 except:
--> 150   return _broadcast_shapes_uncached(*shapes)

File /usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py:170, in _broadcast_shapes_uncached(*shapes)
    168 result_shape = _try_broadcast_shapes(shape_list)
    169 if result_shape is None:
--> 170   raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
    171 return result_shape

ValueError: Incompatible shapes for broadcasting: shapes=[(9,), (9, 33)]

I think you need one more plate for timepoints.

Your observation Y has two dimensions (with sizes 9 and 33). With numpyro.plate("seasons", seasons, dim=-1) you specify that there is a conditionally independent batch at dim=-1. Note that plate dims are counted from the right, so you have to change it to dim=-2.

Next you need another plate for timepoints (assuming it is also conditionally independent) at dim=-1:

with numpyro.plate("timepoints", timepoints, dim=-1):
    with numpyro.handlers.mask(mask=mask):
          numpyro.sample("ll", dist.Normal( obs_cases, sigma2), obs = Y )

Finally, make sure that the shapes of mask, Y, and dist.Normal(obs_cases, sigma2) can be broadcast against each other.
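
For example, a minimal shape check along these lines (purely illustrative; it assumes you are inside the model where mask, Y, and obs_cases are already defined, and all three should broadcast to (seasons, timepoints) = (9, 33)):

        import jax.numpy as jnp

        #--raises if the three shapes cannot be broadcast together
        print(jnp.broadcast_shapes(jnp.shape(mask), jnp.shape(Y), jnp.shape(obs_cases)))  # expect (9, 33)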

More in-depth tutorial on shapes in Pyro: Tensor shapes in Pyro — Pyro Tutorials 1.8.6 documentation

Thanks for the help @ordabayev. I think we’re closer to solving this. However, I’m still a bit confused about how the dimension of lambda is treated. The updated code as you recommended is below.

The error message I receive reads “Cannot concatenate arrays with shapes that differ in dimensions other than the one being concatenated: concatenating along dimension 1 for shapes (2, 1), (9, 1), (9, 1), (9, 1).”

The entire error message (below) refers to how I specify my vector of parameters (theta).
What is interesting is that when I print the shape of lambda, the sizes that print are (9,1), then (9,1), and finally (2,1,1). I don’t understand what’s happening here.

Error message:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[24], line 346
    343     trace = numpyro.handlers.trace(model).get_trace(Y=Y,N=N)
    344 print(numpyro.util.format_shapes(trace))
--> 346 mcmc.run(PRNGKey(20200320)                       #--seed
    347          , Y = Y
    348          , N = N
    349          , forecast = 4
    350 )
    351 mcmc.print_summary()
    353 samples = mcmc.get_samples()

File /usr/local/lib/python3.10/site-packages/numpyro/infer/mcmc.py:628, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    626 map_args = (rng_key, init_state, init_params)
    627 if self.num_chains == 1:
--> 628     states_flat, last_state = partial_map_fn(map_args)
    629     states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    630 else:

File /usr/local/lib/python3.10/site-packages/numpyro/infer/mcmc.py:410, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    408 # Check if _sample_fn is None, then we need to initialize the sampler.
    409 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 410     new_init_state = self.sampler.init(
    411         rng_key,
    412         self.num_warmup,
    413         init_params,
    414         model_args=args,
    415         model_kwargs=kwargs,
    416     )
    417     init_state = new_init_state if init_state is None else init_state
    418 sample_fn, postprocess_fn = self._get_cached_fns()

File /usr/local/lib/python3.10/site-packages/numpyro/infer/hmc.py:713, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    708 # vectorized
    709 else:
    710     rng_key, rng_key_init_model = jnp.swapaxes(
    711         vmap(random.split)(rng_key), 0, 1
    712     )
--> 713 init_params = self._init_state(
    714     rng_key_init_model, model_args, model_kwargs, init_params
    715 )
    716 if self._potential_fn and init_params is None:
    717     raise ValueError(
    718         "Valid value of `init_params` must be provided with" " `potential_fn`."
    719     )

File /usr/local/lib/python3.10/site-packages/numpyro/infer/hmc.py:657, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
    650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    651     if self._model is not None:
    652         (
    653             new_init_params,
    654             potential_fn,
    655             postprocess_fn,
    656             model_trace,
--> 657         ) = initialize_model(
    658             rng_key,
    659             self._model,
    660             dynamic_args=True,
    661             init_strategy=self._init_strategy,
    662             model_args=model_args,
    663             model_kwargs=model_kwargs,
    664             forward_mode_differentiation=self._forward_mode_differentiation,
    665         )
    666         if init_params is None:
    667             init_params = new_init_params

File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:700, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    698     init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    699 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 700 (init_params, pe, grad), is_valid = find_valid_initial_params(
    701     rng_key,
    702     substitute(
    703         model,
    704         data={
    705             k: site["value"]
    706             for k, site in model_trace.items()
    707             if site["type"] in ["plate"]
    708         },
    709     ),
    710     init_strategy=init_strategy,
    711     enum=has_enumerate_support,
    712     model_args=model_args,
    713     model_kwargs=model_kwargs,
    714     prototype_params=prototype_params,
    715     forward_mode_differentiation=forward_mode_differentiation,
    716     validate_grad=validate_grad,
    717 )
    719 if not_jax_tracer(is_valid):
    720     if device_get(~jnp.all(is_valid)):

File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:437, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
    435 # Handle possible vectorization
    436 if rng_key.ndim == 1:
--> 437     (init_params, pe, z_grad), is_valid = _find_valid_params(
    438         rng_key, exit_early=True
    439     )
    440 else:
    441     (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)

File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:423, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
    419 init_state = (0, rng_key, (prototype_params, 0.0, prototype_grads), False)
    420 if exit_early and not_jax_tracer(rng_key):
    421     # Early return if valid params found. This is only helpful for single chain,
    422     # where we can avoid compiling body_fn in while_loop.
--> 423     _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
    424     if not_jax_tracer(is_valid):
    425         if device_get(is_valid):

File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:407, in find_valid_initial_params.<locals>.body_fn(state)
    405     z_grad = jacfwd(potential_fn)(params)
    406 else:
--> 407     pe, z_grad = value_and_grad(potential_fn)(params)
    408 z_grad_flat = ravel_pytree(z_grad)[0]
    409 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

    [... skipping hidden 8 frame]

File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:289, in potential_energy(model, model_args, model_kwargs, params, enum)
    285 substituted_model = substitute(
    286     model, substitute_fn=partial(_unconstrain_reparam, params)
    287 )
    288 # no param is needed for log_density computation because we already substitute
--> 289 log_joint, model_trace = log_density_(
    290     substituted_model, model_args, model_kwargs, {}
    291 )
    292 return -log_joint

File /usr/local/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py:318, in log_density(model, model_args, model_kwargs, params)
    297 def log_density(model, model_args, model_kwargs, params):
    298     """
    299     Similar to :func:`numpyro.infer.util.log_density` but works for models
    300     with discrete latent variables. Internally, this uses :mod:`funsor`
   (...)
    316     :return: log of joint density and a corresponding model trace
    317     """
--> 318     result, model_trace, _ = _enum_log_density(
    319         model,
    320         model_args,
    321         model_kwargs,
    322         params,
    323         funsor.ops.logaddexp,
    324         funsor.ops.add,
    325     )
    326     return result.data, model_trace

File /usr/local/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py:201, in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
    199 model = substitute(model, data=params)
    200 with plate_to_enum_plate():
--> 201     model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    202 log_factors = []
    203 time_to_factors = defaultdict(list)  # log prob factors

File /usr/local/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     """
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     """
--> 171     self(*args, **kwargs)
    172     return self.trace

File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

    [... skipping similar frames: Messenger.__call__ at line 105 (4 times)]

File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

Cell In[24], line 308, in model(Y, N, K, forecast)
    305 print("N")
    306 print(jnp.repeat(jnp.array([N]),seasons).reshape(-1,1).shape)
--> 308 thetas = jnp.hstack([lambs.reshape(-1,1)
    309                 , jnp.repeat(sigma,seasons).reshape(-1,1)
    310                 , jnp.repeat(gamma,seasons).reshape(-1,1)
    311                 , jnp.repeat(jnp.array([N]),seasons).reshape(-1,1)]   )
    313 states = jax.vmap( lambda theta: odeint( seir
    314                                          , init
    315                                          , jnp.arange(0.,timepoints)
   (...)
    318                                          , atol=1e-5
    319                                          , mxstep=1000) )(thetas)
    321 inc_cases = jnp.clip( jnp.append( jnp.repeat(i0,seasons).reshape(-1,1), jnp.diff(states[:,:,-1]),1 ),0,jnp.inf)

File /usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:1879, in hstack(tup, dtype)
   1877   arrs = [atleast_1d(m) for m in tup]
   1878   arr0_ndim = arrs[0].ndim
-> 1879 return concatenate(arrs, axis=0 if arr0_ndim == 1 else 1, dtype=dtype)

File /usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:1854, in concatenate(arrays, axis, dtype)
   1852 k = 16
   1853 while len(arrays_out) > 1:
-> 1854   arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
   1855                 for i in range(0, len(arrays_out), k)]
   1856 return arrays_out[0]

File /usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:1854, in <listcomp>(.0)
   1852 k = 16
   1853 while len(arrays_out) > 1:
-> 1854   arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
   1855                 for i in range(0, len(arrays_out), k)]
   1856 return arrays_out[0]

    [... skipping hidden 30 frame]

File /usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py:3098, in _concatenate_shape_rule(*operands, **kwargs)
   3094   msg = ("Cannot concatenate arrays with shapes that differ in dimensions "
   3095          "other than the one being concatenated: concatenating along "
   3096          "dimension {} for shapes {}.")
   3097   shapes = [operand.shape for operand in operands]
-> 3098   raise TypeError(msg.format(dimension, ", ".join(map(str, shapes))))
   3100 concat_size = sum(o.shape[dimension] for o in operands)
   3101 ex_shape = operands[0].shape

TypeError: Cannot concatenate arrays with shapes that differ in dimensions other than the one being concatenated: concatenating along dimension 1 for shapes (2, 1), (9, 1), (9, 1), (9, 1).
–code

        import jax
        import jax.numpy as jnp
        from jax.experimental.ode import odeint
        from jax.random import PRNGKey

        import numpyro
        import numpyro.distributions as dist
        from numpyro.infer import MCMC,NUTS

        def model(Y,N,K=2,forecast=0):
            seasons,timepoints = Y.shape

            #--model diffeq
            def seir(y,t,theta):
                s,e,i,r, c = y
                lamb,sigma,gamma,N = theta

                ds = (1/N)*(-1.*lamb*s*(N*i))
                de = (1/N)*(lamb*s*(N*i) - sigma*(N*e))
                di = (1/N)*(sigma*(N*e)  - gamma*(N*i))
                dr = (1/N)*(gamma*(N*i))

                dc = (1/N)*(sigma*(N*e)) 
                return jnp.stack([ds,de,di,dr, dc])

            sigma     = numpyro.sample("sigma"    , dist.Beta(.5,.5))
            gamma     = numpyro.sample("gamma"    , dist.Uniform(1./10,1./2))
            catchment = numpyro.sample("catchment", dist.Beta(1,100))

            mix_weights = numpyro.sample("weight" , dist.Dirichlet(jnp.array([0.5,0.5])) )
            with numpyro.plate("components",K,dim=-1):
                mix_centers    = numpyro.sample("mix_centers", dist.Gamma(2,1) )

            #--loglikelihood
            sigma2 = numpyro.sample("sigma2", dist.HalfCauchy(1.) )

            #--initial conditions
            e0 = numpyro.sample("e0", dist.Beta(1,100))
            i0 = numpyro.sample("i0", dist.Beta(1,100))
            r0 = numpyro.deterministic("r0",0.*e0)
            s0 = numpyro.deterministic("s0",(1.-(e0+i0+r0)))
            c0 = numpyro.deterministic("c0",i0)

            init  = jnp.array([s0,e0,i0,r0, c0])

            mask = ~jnp.isnan(Y) #--in case there are missing values
            with numpyro.plate("seasons", seasons, dim=-2):
                assignment = numpyro.sample("assignment", dist.Categorical(mix_weights))
                lambs = numpyro.deterministic("lambs", mix_centers[assignment])

                print("Lambda")
                print(lambs.shape)

                print("sigma")
                print(jnp.repeat(sigma,seasons).reshape(-1,1).shape)

                print("gamma")
                print(jnp.repeat(gamma,seasons).reshape(-1,1).shape)

                print("N")
                print(jnp.repeat(jnp.array([N]),seasons).reshape(-1,1).shape)
                
                thetas = jnp.hstack([lambs.reshape(-1,1)
                                , jnp.repeat(sigma,seasons).reshape(-1,1)
                                , jnp.repeat(gamma,seasons).reshape(-1,1)
                                , jnp.repeat(jnp.array([N]),seasons).reshape(-1,1)]   )
                
                states = jax.vmap( lambda theta: odeint( seir
                                                         , init
                                                         , jnp.arange(0.,timepoints)
                                                         , theta
                                                         , rtol=1e-6
                                                         , atol=1e-5
                                                         , mxstep=1000) )(thetas)

                inc_cases = jnp.clip( jnp.append( jnp.repeat(i0,seasons).reshape(-1,1), jnp.diff(states[:,:,-1]),1 ),0,jnp.inf)
                inc_cases = numpyro.deterministic("inc_cases",N*inc_cases)

                inc_cases = inc_cases
                obs_cases = inc_cases*catchment

                print(obs_cases.shape)
                print(Y.shape)
                
                with numpyro.plate("timepoints", timepoints, dim=-1):
                    with numpyro.handlers.mask(mask=mask):
                        numpyro.sample("ll", dist.Normal( obs_cases, sigma2), obs = Y )

        mcmc = MCMC(
            NUTS(model, dense_mass=False)
                      , num_warmup  = 8000
                      , num_samples = 1000
                      , num_chains  = 1
                      , thinning    = 5
        )

        with numpyro.handlers.seed(rng_seed=1):
            trace = numpyro.handlers.trace(model).get_trace(Y=Y,N=N)
        print(numpyro.util.format_shapes(trace))

        mcmc.run(PRNGKey(20200320)                       #--seed
                 , Y = Y
                 , N = N
                 , forecast = 4
        )
        mcmc.print_summary()
        samples = mcmc.get_samples()

I should also add the output from “format_shapes”:
   Trace Shapes:
    Param Sites:
   Sample Sites:
       sigma dist        |
            value        |
       gamma dist        |
            value        |
   catchment dist        |
            value        |
      weight dist        | 2
            value        | 2
  components plate     2 |
 mix_centers dist      2 |
            value      2 |
      sigma2 dist        |
            value        |
          e0 dist        |
            value        |
          i0 dist        |
            value        |
     seasons plate     9 |
  assignment dist    9 1 |
            value    9 1 |
  timepoints plate    33 |
          ll dist   9 33 |
            value   9 33 |

When I print the shapes, mask is 9x33, Y is 9x33, and obs_cases is 9x33.

The (2,1,1) is due to the enumeration. Have you looked at the enumeration tutorial: Inference with Discrete Latent Variables — Pyro Tutorials 1.8.6 documentation?

Dims -1 and -2 are reserved for two plate dimensions. Dim -3 is used for enumerating Lambda values.

The initial (9,1) shapes that you see are due to an internal run of the guide and model by the guess_max_plate_nesting function.

You need to make sure that your code for thetas works both for enumerated (2,1,1) and non-enumerated (9,1) cases.
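
As a minimal sketch of what this requirement looks like in practice (illustrative values only; it builds a theta array by broadcasting against lambs rather than hstack-ing fixed (seasons, 1) columns, and it only demonstrates the shapes, without wiring thetas into the odeint/vmap call):

        import jax.numpy as jnp

        for lambs in (jnp.ones((9, 1)), jnp.ones((2, 1, 1))):   # non-enumerated vs enumerated
            sigma, gamma, N = 0.2, 0.1, 1000.0                   # placeholder scalars
            #--broadcasting the scalars against lambs keeps any extra batch dims on the left
            parts = jnp.broadcast_arrays(lambs, jnp.asarray(sigma),
                                         jnp.asarray(gamma), jnp.asarray(N))
            thetas = jnp.stack(parts, axis=-1)                   # last axis holds (lamb, sigma, gamma, N)
            print(thetas.shape)                                  # (9, 1, 4) and (2, 1, 1, 4)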

Let me know if you have more questions.

Thanks @ordabayev,

I reviewed the enumeration tutorial again.

I would have expected the enumerated case to be a lambda of size (2,9,1) and not (2,1,1).

What am I missing here?

The shape is (2,1,1) because there are only two values {0,1} for assignment, and it is the same two values for all seasons, so there is no need to repeat them for each season. (Down the road, in the ELBO computation algorithm, the enumerated values (2,1,1) and the plate dim (9,1) will get broadcast against each other before the marginalization step, but that is delayed as late as possible for efficiency.)
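
A quick way to see what that eventual broadcast produces (illustrative shapes only):

        import jax.numpy as jnp

        #--enumerated assignment values (2,1,1) broadcast against the seasons plate (9,1)
        print(jnp.broadcast_shapes((2, 1, 1), (9, 1)))  # (2, 9, 1)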

This is as close as I can get to a solution; however, how NumPyro handles enumeration still doesn’t make sense to me.

In the example below I
- rewrote the code so that 2 trajectories are created (one per cluster),
- factored out all the code that does not involve the mixture aspect of this model, and
- printed shapes.

Below is the code and error

–code

        #--ATTEMPT TO FIT MODEL

        import jax
        import jax.numpy as jnp
        from jax.experimental.ode import odeint
        from jax.random import PRNGKey

        import numpyro
        import numpyro.distributions as dist
        from numpyro.infer import MCMC,NUTS
        from numpyro.ops.indexing import Vindex  #--Vindex is used below (in older NumPyro releases: numpyro.contrib.indexing)

        def model(Y,N,K=2,forecast=0):
            seasons,timepoints = Y.shape

            #--model diffeq
            def seir(y,t,lamb,sigma,gamma,N):
                s,e,i,r, c = y

                lamb = lamb.reshape(1,)
                
                ds = (1/N)*(-1.*lamb*s*(N*i))
                de = (1/N)*(lamb*s*(N*i) - sigma*(N*e))
                di = (1/N)*(sigma*(N*e)  - gamma*(N*i))
                dr = (1/N)*(gamma*(N*i))

                dc = (1/N)*(sigma*(N*e))
                
                return jnp.stack([ds,de,di,dr, dc])

            #--fixed model parameters for all seasons
            sigma     = numpyro.sample("sigma"    , dist.Beta(.5,.5).expand([1,]))
            gamma     = numpyro.sample("gamma"    , dist.Uniform(1./10,1./2).expand([1,]))
            catchment = numpyro.sample("catchment", dist.Beta(1,100).expand([1,]))

            #--initial conditions
            e0 = numpyro.sample("e0", dist.Beta(1,100))
            i0 = numpyro.sample("i0", dist.Beta(1,100))
            r0 = numpyro.deterministic("r0",0.*e0)
            s0 = numpyro.deterministic("s0",(1.-(e0+i0+r0)))
            c0 = numpyro.deterministic("c0",i0)

            init  = jnp.array([s0,e0,i0,r0, c0])

            #--observational error
            sigma2 = numpyro.sample("sigma2", dist.HalfCauchy(1.) )

            #--Handle Missingness
            mask = ~jnp.isnan(Y) 
            
            #-----------------------------------------------------------------------------
            #--Cluster code
            mix_weights = numpyro.sample("weight" , dist.Dirichlet(jnp.array([0.5,0.5])) )
            with numpyro.plate("components",K,dim=-3):
                mix_centers    = numpyro.sample("mix_centers", dist.Gamma(2,1) )

            seir_with_params = lambda y,t,center: seir(y,t,center, sigma=sigma, gamma= gamma, N = jnp.array([N])  )
                
            states = jax.vmap( lambda center: odeint(seir_with_params
                                                            , init
                                                            , jnp.arange(0.,timepoints)
                                                            , center
                                                            , rtol=1e-6
                                                            , atol=1e-5
                                                            , mxstep=1000) )(mix_centers)

            inc_cases = jnp.clip( jnp.insert(jnp.diff(states[...,-1]),0,i0,axis=1),0,jnp.inf)
            inc_cases = numpyro.deterministic("inc_cases",N*inc_cases)
            
            obs_cases = (inc_cases*catchment)

            print(f"obs_cases = {obs_cases.shape}")

            with numpyro.plate("seasons", seasons, dim=-2) as S:
                assignment = numpyro.sample("assignment", dist.Categorical(mix_weights))
                assignment = assignment.reshape(-1,)
                
                cases = Vindex(obs_cases)[...,assignment,:]

                with numpyro.plate("timepoints", timepoints, dim=-1):
                    with numpyro.handlers.mask(mask=mask):
                        print(f"cases shape = {cases.shape}")
                        print(f"Y shape = {Y.shape}")
                        numpyro.sample("ll", dist.Normal( cases, sigma2), obs = Y )
        #END MODEL-----------------------------------------------------------------------------------------------
                    
        mcmc = MCMC(
            NUTS(model, dense_mass=False)
                      , num_warmup  = 8000
                      , num_samples = 1000
                      , num_chains  = 1
                      , thinning    = 5
        )

        with numpyro.handlers.seed(rng_seed=1):
            trace = numpyro.handlers.trace(model).get_trace(Y=Y,N=N)
        print(numpyro.util.format_shapes(trace))

        mcmc.run(PRNGKey(20200320)                       #--seed
                 , Y = Y
                 , N = N
                 , forecast = 4
        )
        mcmc.print_summary()

        samples = mcmc.get_samples()

–error

obs_cases = (2, 33)
cases shape = (9, 33)
Y shape = (9, 33)
   Trace Shapes:           
    Param Sites:           
   Sample Sites:           
      sigma dist      1 |  
           value      1 |  
      gamma dist      1 |  
           value      1 |  
  catchment dist      1 |  
           value      1 |  
         e0 dist        |  
           value        |  
         i0 dist        |  
           value        |  
     sigma2 dist        |  
           value        |  
     weight dist        | 2
           value        | 2
components plate      2 |  
mix_centers dist 2 1  1 |  
           value 2 1  1 |  
   seasons plate      9 |  
 assignment dist   9  1 |  
           value   9  1 |  
timepoints plate     33 |  
         ll dist   9 33 |  
           value   9 33 |  
obs_cases = (2, 33)
cases shape = (9, 33)
Y shape = (9, 33)
<ipython-input-351-127a83fa2241>:341: FutureWarning: Some algorithms will automatically enumerate the discrete latent site assignment of your model. In the future, enumerated sites need to be marked with `infer={'enumerate': 'parallel'}`.
  mcmc.run(PRNGKey(20200320)                       #--seed
obs_cases = (2, 33)
cases shape = (2, 33)
Y shape = (9, 33)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
    [... skipping hidden 1 frame]

File /usr/local/lib/python3.10/site-packages/jax/_src/util.py:263, in cache.<locals>.wrap.<locals>.wrapper(*args, **kwargs)
    262 else:
--> 263   return cached(config._trace_context(), *args, **kwargs)

File /usr/local/lib/python3.10/site-packages/jax/_src/util.py:256, in cache.<locals>.wrap.<locals>.cached(_, *args, **kwargs)
    254 @functools.lru_cache(max_size)
    255 def cached(_, *args, **kwargs):
--> 256   return f(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py:154, in _broadcast_shapes_cached(*shapes)
    152 @cache()
    153 def _broadcast_shapes_cached(*shapes: tuple[int, ...]) -> tuple[int, ...]:
--> 154   return _broadcast_shapes_uncached(*shapes)

File /usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py:170, in _broadcast_shapes_uncached(*shapes)
    169 if result_shape is None:
--> 170   raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
    171 return result_shape

ValueError: Incompatible shapes for broadcasting: shapes=[(9, 33), (2, 33)]

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
Cell In[351], line 341
    338     trace = numpyro.handlers.trace(model).get_trace(Y=Y,N=N)
    339 print(numpyro.util.format_shapes(trace))
--> 341 mcmc.run(PRNGKey(20200320)                       #--seed
    342          , Y = Y
    343          , N = N
    344          , forecast = 4
    345 )
    346 mcmc.print_summary()
    348 samples = mcmc.get_samples()

File /usr/local/lib/python3.10/site-packages/numpyro/infer/mcmc.py:628, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    626 map_args = (rng_key, init_state, init_params)
    627 if self.num_chains == 1:
--> 628     states_flat, last_state = partial_map_fn(map_args)
    629     states = tree_map(lambda x: x[jnp.newaxis, ...], states_flat)
    630 else:

File /usr/local/lib/python3.10/site-packages/numpyro/infer/mcmc.py:410, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    408 # Check if _sample_fn is None, then we need to initialize the sampler.
    409 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 410     new_init_state = self.sampler.init(
    411         rng_key,
    412         self.num_warmup,
    413         init_params,
    414         model_args=args,
    415         model_kwargs=kwargs,
    416     )
    417     init_state = new_init_state if init_state is None else init_state
    418 sample_fn, postprocess_fn = self._get_cached_fns()

File /usr/local/lib/python3.10/site-packages/numpyro/infer/hmc.py:713, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    708 # vectorized
    709 else:
    710     rng_key, rng_key_init_model = jnp.swapaxes(
    711         vmap(random.split)(rng_key), 0, 1
    712     )
--> 713 init_params = self._init_state(
    714     rng_key_init_model, model_args, model_kwargs, init_params
    715 )
    716 if self._potential_fn and init_params is None:
    717     raise ValueError(
    718         "Valid value of `init_params` must be provided with" " `potential_fn`."
    719     )

File /usr/local/lib/python3.10/site-packages/numpyro/infer/hmc.py:657, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params)
    650 def _init_state(self, rng_key, model_args, model_kwargs, init_params):
    651     if self._model is not None:
    652         (
    653             new_init_params,
    654             potential_fn,
    655             postprocess_fn,
    656             model_trace,
--> 657         ) = initialize_model(
    658             rng_key,
    659             self._model,
    660             dynamic_args=True,
    661             init_strategy=self._init_strategy,
    662             model_args=model_args,
    663             model_kwargs=model_kwargs,
    664             forward_mode_differentiation=self._forward_mode_differentiation,
    665         )
    666         if init_params is None:
    667             init_params = new_init_params

File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:700, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    698     init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    699 prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
--> 700 (init_params, pe, grad), is_valid = find_valid_initial_params(
    701     rng_key,
    702     substitute(
    703         model,
    704         data={
    705             k: site["value"]
    706             for k, site in model_trace.items()
    707             if site["type"] in ["plate"]
    708         },
    709     ),
    710     init_strategy=init_strategy,
    711     enum=has_enumerate_support,
    712     model_args=model_args,
    713     model_kwargs=model_kwargs,
    714     prototype_params=prototype_params,
    715     forward_mode_differentiation=forward_mode_differentiation,
    716     validate_grad=validate_grad,
    717 )
    719 if not_jax_tracer(is_valid):
    720     if device_get(~jnp.all(is_valid)):

File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:437, in find_valid_initial_params(rng_key, model, init_strategy, enum, model_args, model_kwargs, prototype_params, forward_mode_differentiation, validate_grad)
    435 # Handle possible vectorization
    436 if rng_key.ndim == 1:
--> 437     (init_params, pe, z_grad), is_valid = _find_valid_params(
    438         rng_key, exit_early=True
    439     )
    440 else:
    441     (init_params, pe, z_grad), is_valid = lax.map(_find_valid_params, rng_key)

File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:423, in find_valid_initial_params.<locals>._find_valid_params(rng_key, exit_early)
    419 init_state = (0, rng_key, (prototype_params, 0.0, prototype_grads), False)
    420 if exit_early and not_jax_tracer(rng_key):
    421     # Early return if valid params found. This is only helpful for single chain,
    422     # where we can avoid compiling body_fn in while_loop.
--> 423     _, _, (init_params, pe, z_grad), is_valid = init_state = body_fn(init_state)
    424     if not_jax_tracer(is_valid):
    425         if device_get(is_valid):

File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:407, in find_valid_initial_params.<locals>.body_fn(state)
    405     z_grad = jacfwd(potential_fn)(params)
    406 else:
--> 407     pe, z_grad = value_and_grad(potential_fn)(params)
    408 z_grad_flat = ravel_pytree(z_grad)[0]
    409 is_valid = jnp.isfinite(pe) & jnp.all(jnp.isfinite(z_grad_flat))

    [... skipping hidden 8 frame]

File /usr/local/lib/python3.10/site-packages/numpyro/infer/util.py:289, in potential_energy(model, model_args, model_kwargs, params, enum)
    285 substituted_model = substitute(
    286     model, substitute_fn=partial(_unconstrain_reparam, params)
    287 )
    288 # no param is needed for log_density computation because we already substitute
--> 289 log_joint, model_trace = log_density_(
    290     substituted_model, model_args, model_kwargs, {}
    291 )
    292 return -log_joint

File /usr/local/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py:318, in log_density(model, model_args, model_kwargs, params)
    297 def log_density(model, model_args, model_kwargs, params):
    298     """
    299     Similar to :func:`numpyro.infer.util.log_density` but works for models
    300     with discrete latent variables. Internally, this uses :mod:`funsor`
   (...)
    316     :return: log of joint density and a corresponding model trace
    317     """
--> 318     result, model_trace, _ = _enum_log_density(
    319         model,
    320         model_args,
    321         model_kwargs,
    322         params,
    323         funsor.ops.logaddexp,
    324         funsor.ops.add,
    325     )
    326     return result.data, model_trace

File /usr/local/lib/python3.10/site-packages/numpyro/contrib/funsor/infer_util.py:201, in _enum_log_density(model, model_args, model_kwargs, params, sum_op, prod_op)
    199 model = substitute(model, data=params)
    200 with plate_to_enum_plate():
--> 201     model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    202 log_factors = []
    203 time_to_factors = defaultdict(list)  # log prob factors

File /usr/local/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs)
    163 def get_trace(self, *args, **kwargs):
    164     """
    165     Run the wrapped callable and return the recorded trace.
    166 
   (...)
    169     :return: `OrderedDict` containing the execution trace.
    170     """
--> 171     self(*args, **kwargs)
    172     return self.trace

File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

    [... skipping similar frames: Messenger.__call__ at line 105 (4 times)]

File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.__call__(self, *args, **kwargs)
    103     return self
    104 with self:
--> 105     return self.fn(*args, **kwargs)

Cell In[351], line 326, in model(Y, N, K, forecast)
    324 print(f"cases shape = {cases.shape}")
    325 print(f"Y shape = {Y.shape}")
--> 326 numpyro.sample("ll", dist.Normal( cases, sigma2), obs = Y )

File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:222, in sample(name, fn, obs, rng_key, sample_shape, infer, obs_mask)
    207 initial_msg = {
    208     "type": "sample",
    209     "name": name,
   (...)
    218     "infer": {} if infer is None else infer,
    219 }
    221 # ...and use apply_stack to send it to the Messengers
--> 222 msg = apply_stack(initial_msg)
    223 return msg["value"]

File /usr/local/lib/python3.10/site-packages/numpyro/primitives.py:47, in apply_stack(msg)
     45 pointer = 0
     46 for pointer, handler in enumerate(reversed(_PYRO_STACK)):
---> 47     handler.process_message(msg)
     48     # When a Messenger sets the "stop" field of a message,
     49     # it prevents any Messengers above it on the stack from being applied.
     50     if msg.get("stop"):

File /usr/local/lib/python3.10/site-packages/numpyro/handlers.py:525, in mask.process_message(self, msg)
    520         msg["mask"] = (
    521             self.mask if msg["mask"] is None else (self.mask & msg["mask"])
    522         )
    523     return
--> 525 msg["fn"] = msg["fn"].mask(self.mask)

File /usr/local/lib/python3.10/site-packages/numpyro/distributions/distribution.py:406, in Distribution.mask(self, mask)
    404 if mask is True:
    405     return self
--> 406 return MaskedDistribution(self, mask)

File /usr/local/lib/python3.10/site-packages/numpyro/distributions/distribution.py:99, in DistributionMeta.__call__(cls, *args, **kwargs)
     97     if result is not None:
     98         return result
---> 99 return super().__call__(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/numpyro/distributions/distribution.py:842, in MaskedDistribution.__init__(self, base_dist, mask)
    840     self._mask = mask
    841 else:
--> 842     batch_shape = lax.broadcast_shapes(
    843         jnp.shape(mask), tuple(base_dist.batch_shape)
    844     )
    845     if mask.shape != batch_shape:
    846         mask = jnp.broadcast_to(mask, batch_shape)

    [... skipping hidden 1 frame]

File /usr/local/lib/python3.10/site-packages/jax/_src/lax/lax.py:170, in _broadcast_shapes_uncached(*shapes)
    168 result_shape = _try_broadcast_shapes(shape_list)
    169 if result_shape is None:
--> 170   raise ValueError(f"Incompatible shapes for broadcasting: shapes={list(shapes)}")
    171 return result_shape

ValueError: Incompatible shapes for broadcasting: shapes=[(9, 33), (2, 33)]

I’m not familiar with jax’s odeint and it is not very clear what the constraints on the seir function are here (“func: function to evaluate the time derivative of the solution “y” at time t as func(y, t, *args), producing the same shape/structure as y0.” from the odeint docs), but I know that for NumPyro to work with enumeration you need to make sure that your code is broadcastable from the left, i.e. it should work with thetas that have both shapes (9,1) and (2,1,1).

Imagine you want to calculate the marginalized distribution of y_i: \prod_i p(y_i) = \sum_x \prod_i p(x_i)\, p(y_i \mid x_i), where i \in \{1, 2, 3\}, and let's say p(y_i \mid x_i) is a normal distribution \mathcal{N}(y_i \mid \mu = x_i, 1) and x_i \in \{0, 1\} (binary).

Let’s also say y = torch.tensor([0.5, 1., -0.3]), so y.shape = (3,). To marginalize x we will enumerate it and use dim=-2 because dim=-1 is already taken by our plate.

x = torch.tensor([[0], [1]]) so that x.shape = (2,1). And we have p(x_i): probs_x = torch.tensor([[0.3], [0.7]]).
We can calculate p(y_i|x_i) as follows:

        dist_y = dist.Normal(x, 1)  # batch_shape = (2,1)
        prob_y_x = dist_y.log_prob(y).exp()  # shape = (2,3) -> broadcasted enumerated x values and our plate

        # final calculation
        prob_y = (probs_x * prob_y_x).sum(-2)  # marginalize x here
        result = prob_y.prod()  # contract the plate

This is roughly how enumeration is used by NumPyro to compute log densities.
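
As a small numeric check of that identity (a sketch using jax.numpy instead of torch, with the same toy numbers as above; the names are illustrative):

        import itertools
        import jax.numpy as jnp
        from jax.scipy.stats import norm

        y = jnp.array([0.5, 1.0, -0.3])
        probs_x = jnp.array([0.3, 0.7])     # p(x_i = 0), p(x_i = 1)
        centers = jnp.array([0.0, 1.0])     # mu = x_i

        #--enumerated computation: x along dim -2, the "i" plate along dim -1
        prob_y_x = norm.pdf(y, loc=centers[:, None])        # shape (2, 3)
        prob_y = (probs_x[:, None] * prob_y_x).sum(-2)      # marginalize x -> shape (3,)
        enumerated = prob_y.prod()                          # contract the plate

        #--brute force: sum over all 2**3 joint assignments of (x_1, x_2, x_3)
        brute = sum(
            jnp.prod(probs_x[jnp.array(xs)] * norm.pdf(y, loc=centers[jnp.array(xs)]))
            for xs in itertools.product([0, 1], repeat=3)
        )
        print(enumerated, brute)  # the two quantities agree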