Question about jax.lax.scan

I’m trying to draw 10 sample paths (one per item), each of length 50, from a Markov chain with 2 states.

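```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

with numpyro.plate("plate_state", 2, dim=-1):
    # One Dirichlet draw per current state, so transition_prob is (2, 2):
    # row s holds P(next state | current state s).
    transition_prior = jnp.ones(2)
    transition_prob = numpyro.sample("t_p", dist.Dirichlet(jnp.broadcast_to(transition_prior, (1, 2))))

    plate_items = numpyro.plate("plate_items", 10, dim=-2)
    plate_times = numpyro.plate("time", 50, dim=-3)

    # Copy the transition matrix for every (time, item) pair: (50, 10, 2, 2).
    transition_prob_reshaped = transition_prob.reshape((1, 1, 2, 2))
    transition_prob_tiled = jnp.tile(transition_prob_reshaped, (50, 10, 1, 1))

    with plate_items, plate_times:
        # One next-state draw per (time, item, current state).
        category_samples = numpyro.sample("c_s", dist.Categorical(transition_prob_tiled))
```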

This yields `category_samples` with shape (50, 10, 2).
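For comparison, the same layout can be reproduced outside NumPyro with plain JAX. This is only a sketch: the uniform transition matrix stands in for the Dirichlet draw `t_p`, and the PRNG seed is arbitrary. Entry `[t, i, s]` of the result is the state item `i` would move to at time `t` if it were currently in state `s`.

```python
import jax
import jax.numpy as jnp

# Illustrative 2-state transition matrix (assumption: uniform rows stand in
# for the Dirichlet sample "t_p"); row s = P(next state | current state s).
transition_prob = jnp.full((2, 2), 0.5)

# Same tiling as in the model: (1, 1, 2, 2) -> (50, 10, 2, 2).
transition_prob_tiled = jnp.tile(transition_prob.reshape((1, 1, 2, 2)), (50, 10, 1, 1))

# jax.random.categorical takes logits and samples along the last axis,
# so that axis is dropped from the result: shape (50, 10, 2).
key = jax.random.PRNGKey(0)
category_samples = jax.random.categorical(key, jnp.log(transition_prob_tiled), axis=-1)

print(category_samples.shape)  # (50, 10, 2)
```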


```python
start_prob = jnp.repeat(1.0 / 2, 2)
category0 = numpyro.sample("category_0", dist.Categorical(start_prob))

def state_transition(category, t):
    # print("Input ", type(category), category.shape)
    # Flatten the (10, 2) slice at time t and pick entry [item, category]
    # for every item via the flat index item * 2 + category.
    categories_indexes = category + 2 * jnp.arange(10).reshape((10))
    category_samples_at_t = category_samples[t].reshape(20)
    category = category_samples_at_t[categories_indexes]
    # print("Output ", type(category), category.shape)
    return category, category

T = jnp.arange(1, 50)
category, categories = jax.lax.scan(state_transition, init=category0, xs=T)
```
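To see what this scan is meant to compute when `category_samples` really has shape (50, 10, 2), the same step can be run on concrete arrays in plain JAX. This is a sketch under stated assumptions: the random `category_samples` and `category0` below are stand-ins for the values NumPyro would sample, and `category0` is given one state per item, shape (10,), because `jax.lax.scan` requires the carry to keep the same shape and dtype on every step. Each step then returns a (10,) vector, so stacking the 49 steps gives `categories` with shape (49, 10).

```python
import jax
import jax.numpy as jnp

n_items, n_states, n_steps = 10, 2, 50

# Stand-in for the sampled "c_s" site (assumption): uniform transitions.
key_samples, key_init = jax.random.split(jax.random.PRNGKey(0))
uniform_logits = jnp.zeros((n_steps, n_items, n_states, n_states))
category_samples = jax.random.categorical(key_samples, uniform_logits, axis=-1)  # (50, 10, 2)

# Stand-in initial state, one per item (assumption). Same dtype as the
# samples, so the scan carry type is identical on input and output.
category0 = jax.random.categorical(key_init, jnp.zeros((n_items, n_states)), axis=-1)  # (10,)

def state_transition(category, t):
    # Flat index into the (10, 2) slice: item i in state s -> i * 2 + s.
    categories_indexes = category + n_states * jnp.arange(n_items)
    category_samples_at_t = category_samples[t].reshape(n_items * n_states)  # (20,)
    new_category = category_samples_at_t[categories_indexes]  # (10,)
    return new_category, new_category

T = jnp.arange(1, n_steps)
category, categories = jax.lax.scan(state_transition, category0, T)
print(categories.shape)  # (49, 10)
```

The flat-index lookup picks the same entries as `category_samples[t][jnp.arange(10), category]`, which avoids the reshape altogether.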

When I run MCMC for inference, I get the output below. The two Input/Output lines only appear when I uncomment the print statements; otherwise I get just the error.


```
Input  <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> (10, 1)
Output  <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> (10, 1)

FilteredStackTrace: TypeError: reshape total size must be unchanged, got new_sizes (20,) for shape (1, 1, 1, 1).

The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
```
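The arithmetic behind the message is simple: a (10, 2) slice has 20 elements and can be flattened to (20,), while an array of shape (1, 1, 1, 1) has a single element, so the reshape is rejected. A minimal plain-JAX reproduction, using zero arrays as illustrative stand-ins for the slice:

```python
import jax.numpy as jnp

# A (10, 2) slice has 20 elements, so flattening it to (20,) is allowed.
expected_slice = jnp.zeros((10, 2))
print(expected_slice.reshape(20).shape)  # (20,)

# The slice in the traceback has shape (1, 1, 1, 1): only 1 element,
# so reshaping it to (20,) fails JAX's size check with a TypeError.
observed_slice = jnp.zeros((1, 1, 1, 1))
try:
    observed_slice.reshape(20)
    reshape_failed = False
except TypeError as err:
    reshape_failed = True
    print("TypeError:", err)
```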

What is causing `category_samples[t]` to have shape (1, 1, 1, 1) instead of the expected (10, 2)? I’d appreciate any help.

Thank you.

PS: I originally posted this as an issue on the GitHub repo and removed it after realizing that questions belong here. Sorry about that.