I’m trying to take 10 (for 10 items) sample paths (of length 50) of a markov chain with 2 states.
with numpyro.plate("plate_state", 2 , dim=-1):
transition_prior = jnp.ones(2)
transition_prob = numpyro.sample("t_p", dist.Dirichlet(jnp.broadcast_to(transition_prior, (1, 2))))
plate_items = numpyro.plate("plate_items", 10 , dim=-2)
plate_times = numpyro.plate("time", 50 , dim=-3)
transition_prob_reshaped = transition_prob.reshape((1,1,2,2))
transition_prob_tiled = jnp.tile(transition_prob_reshaped,(50,10,1,1))
with plate_items,plate_times:
category_samples =numpyro.sample("c_s",dist.Categorical(transition_prob_tiled))
This yield category_samples
of shape (50,10,2)
.
start_prob = jnp.repeat(1.0 / 2,2)
category0 = numpyro.sample("category_0",dist.Categorical(start_prob))
def state_transition(category,t):
#print("Input ",type(category),category.shape)
categories_indexes = category + 2 * jnp.arange(10).reshape((10))
category_samples_at_t = category_samples[t].reshape(20)
category = category_samples_at_t[categories_indexes]
#print("Output ",type(category),category.shape)
return category, category
T = jnp.arange(1,50)
category, categories = jax.lax.scan(state_transition, init=category0,xs =T)
When I try MCMC for inference, I get the result below (if I uncomment the printout lines, otherwise just the error).
Input <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> (10, 1)
Output <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> (10, 1)
FilteredStackTrace: TypeError: reshape total size must be unchanged, got new_sizes (20,) for shape (1, 1, 1, 1).
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
What’s causing the category_samples[t]
to have a shape of (1,1,1,1)
instead of the expected shape of (10,2)
? I’d appreciate the help.
Thank you.
PS: I posted this as an issue on the github repo and removed after I realized I have to put the question here. Sorry about that.