Hello NumPyro experts! I am trying to implement a switching linear dynamical system that works with NUTS, but I'm having trouble figuring out the shapes of continuous latents that depend on discrete latents. Here is a simplified model derived from the HMM example:
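```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.contrib.indexing import Vindex


def model_1(sequences, hidden_dim, **kwargs):
    length, batch_dim, data_dim = sequences.shape
    # Sticky transition matrix: one Dirichlet-distributed row per discrete state.
    probs_z = numpyro.sample(
        "probs_z", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)
    )
    # Per-state emission parameters, each of shape (hidden_dim, data_dim).
    # Each param needs its own name (both were previously called 'py').
    loc_y = numpyro.param("loc_y", jnp.zeros((hidden_dim, data_dim)))
    scale_y = numpyro.param(
        "scale_y",
        jnp.ones((hidden_dim, data_dim)),
        constraint=dist.constraints.softplus_positive,
    )

    def transition_fn(carry, y):
        z_prev, t = carry
        with numpyro.plate("batch", batch_dim, dim=-2):
            # Discrete switching state, enumerated in parallel.
            z = numpyro.sample(
                "z",
                dist.Categorical(probs_z[z_prev]),
                infer={"enumerate": "parallel"},
            )
            with numpyro.plate("data_y", data_dim, dim=-1):
                loc = Vindex(loc_y)[z.squeeze(-1)]
                scale = Vindex(scale_y)[z.squeeze(-1)]
                # Continuous latent: adding this site triggers the shape error.
                x = numpyro.sample("x", dist.Normal(loc, scale))
                # Observed emission: works fine on its own.
                y = numpyro.sample("y", dist.Normal(loc, scale), obs=y)
        return (z, t + 1), None

    z_init = jnp.zeros((batch_dim, 1), dtype=jnp.int32)
    scan(transition_fn, (z_init, 0), sequences)
```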
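To see where an extra leading axis on `loc` and `scale` can come from, here is a small sketch that replays the indexing in `transition_fn` with plain NumPy (whose broadcasting rules `jax.numpy` follows). It is a simplification: it assumes the enumerated `z` arrives as an `arange` of states with shape `(hidden_dim, 1, 1)`, and the sizes 16, 4 and 5 are hypothetical. The exact enumeration dim NumPyro allocates depends on the inference setup and on the time step inside `scan`.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
hidden_dim, batch_dim, data_dim = 16, 4, 5

loc_y = np.zeros((hidden_dim, data_dim))

# Simplified stand-in for a parallel-enumerated Categorical value:
# all states laid out along an axis to the left of the two plate dims.
z = np.arange(hidden_dim).reshape(hidden_dim, 1, 1)

# Same indexing as in model_1: drop the trailing axis, then select rows.
loc = loc_y[z.squeeze(-1)]
print(loc.shape)  # (16, 1, 5)

# Inside plate("batch", dim=-2) and plate("data_y", dim=-1), the site's
# batch shape is loc's shape broadcast against (batch_dim, data_dim).
site_shape = np.broadcast_shapes(loc.shape, (batch_dim, data_dim))
print(site_shape)  # (16, 4, 5)
```

In this simplified picture, the `x` site's distribution carries the enumerated axis as its leftmost batch dim, so its batch shape is `(hidden_dim, batch_dim, data_dim)` rather than `(batch_dim, data_dim)`.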
When the continuous variable is observed, like `y` here, everything works fine. However, when I add a continuous latent such as `x`, I cannot figure out what shapes `loc` and `scale` should have to avoid an "Incompatible shapes for broadcasting" error. For example:
```
Incompatible shapes for broadcasting: shapes=[(16, 4, 5), (200, 4, 5)]
```
This error occurs only when the latent `x` is added to the model. I tried expanding `loc` and `scale` along dim -3, but that did not help. I also tried moving the first axis to position -2, but that produced a different error about an attempted nested scan.
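The reported error is a plain broadcasting failure: shapes are compared from the right, and each pair of sizes must be equal or 1. The sketch below checks the two shapes from the error message with NumPy's `broadcast_shapes` (JAX applies the same rule, though its error text differs). The helper `can_broadcast` and the per-step shape `(4, 5)` are illustrative assumptions; the post does not say what the leading 200 is, though it may be the sequence length.

```python
import numpy as np


def can_broadcast(*shapes):
    """Hypothetical helper: True if the shapes broadcast under NumPy/JAX rules."""
    try:
        np.broadcast_shapes(*shapes)
        return True
    except ValueError:
        return False


enumerated = (16, 4, 5)  # first shape in the error message
reported = (200, 4, 5)   # second shape in the error message
per_step = (4, 5)        # assumed single-step (batch_dim, data_dim) value

print(can_broadcast(enumerated, per_step))  # True: missing leading dim is added
print(can_broadcast(enumerated, reported))  # False: 16 vs 200, neither is 1
print(can_broadcast((1, 4, 5), reported))   # True: a size-1 axis stretches
```

The trailing dims `(4, 5)` agree, so the conflict is entirely in the leading axis, where neither size is 1.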
Any thoughts on why this is not working?