An issue with the shape of a continuous latent that depends on an enumerated discrete latent inside scan

Hello numpyro experts! I am trying to implement a switching linear dynamical system which would work with NUTS, but I have issues figuring out the shapes of continuous latents that depend on discrete latents. Here is a simplified model derived from the HMM example:

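import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import param, plate
from numpyro.contrib.control_flow import scan
from numpyro.contrib.indexing import Vindex


def model_1(sequences, hidden_dim, **kwargs):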
    length, batch_dim, data_dim = sequences.shape

    probs_z = numpyro.sample(
        "probs_z", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)
    )

    loc_y = param('py', jnp.zeros((hidden_dim, data_dim)))
    scale_y = param('py', jnp.ones((hidden_dim, data_dim)), constraint=dist.constraints.softplus_positive)

    def transition_fn(carry, y):
        z_prev, t = carry
        with numpyro.plate("batch", batch_dim, dim=-2):
            z = numpyro.sample(
                "z",
                dist.Categorical(probs_z[z_prev]),
                infer={"enumerate": "parallel"},
            )
            
            with plate('data_y', data_dim, dim=-1):
                loc = Vindex(loc_y)[z.squeeze(-1)]
                scale = Vindex(scale_y)[z.squeeze(-1)]
                x = numpyro.sample("x", dist.Normal(loc, scale))
                y = numpyro.sample("y", dist.Normal(loc, scale), obs=y)

        return (z, t + 1), None

    z_init = jnp.zeros((batch_dim, 1), dtype=jnp.int32)
    scan(transition_fn, (z_init, 0), sequences)

When the continuous variable is observed, like y here, everything works fine. However, when I add a latent such as x, I cannot figure out what shapes loc and scale should have to avoid an "incompatible shapes for broadcasting" error. For example, this error:

Incompatible shapes for broadcasting: shapes=[(16, 4, 5), (200, 4, 5)]

occurs only when the latent x is added to the model. I tried expanding loc and scale along dimension -3, but that did not help. I also tried moving the first axis to position -2, but that produced a different error about an attempted nested scan.

Any thoughts on why this is not working?

I guess some batch dimensions are messed up. Could you print the shapes of all the variables involved: z, probs_z[z_prev], z_prev, loc, scale, y?

Hi @fehiepsi, thank you for looking into this. This is what I get from format_shapes:

Trace Shapes:                 
 Param Sites:                 
           py       16 5      
Sample Sites:                 
 probs_z dist          | 16 16
        value          | 16 16
       z dist 200 4  1 |      
        value 200 4  1 |      
       x dist 200 4  5 |      
        value 200 4  5 |      
       y dist 200 4  5 |      
        value 200 4  5 |      

I could not infer any specific issue from this. In particular, I do not know why enumeration breaks the x variable but works fine with the y variable, which is linked to obs.

Those are the shapes without enumeration. You can print the shapes with enumeration by adding print statements to your model.

OK, no problem. When I run the model like this

from jax import random
from numpyro.infer import MCMC, NUTS

rng_key = random.PRNGKey(2)
kernel = NUTS(model_1)
mcmc = MCMC(
    kernel,
    num_warmup=800,
    num_samples=200,
    num_chains=1,
    chain_method='vectorized',
    progress_bar=True,
)
mcmc.run(rng_key, ys, hidden_dim=16)

I get the following output

z (4, 1)
loc (4, 5)
scale (4, 5)
x (4, 5)
y (4, 5)
z (16, 1, 1)
loc (16, 1, 5)
scale (16, 1, 5)

before the shape error occurs. If I comment out the line with the x variable, I get the following shapes (and no errors):

z (4, 1)
loc (4, 5)
scale (4, 5)
y (4, 5)
z (16, 1, 1)
loc (16, 1, 5)
scale (16, 1, 5)
y (4, 5)
z (16, 1, 1, 1)
loc (16, 1, 1, 5)
scale (16, 1, 1, 5)
y (4, 5)

It looks like this model is not supported. There is a relevant discussion here: "Modelling mutivariate switching dynamics - advice on getting started?"

Thanks, I will explore the alternatives. So far, I could get DiscreteHMCGibbs working, but it is quite slow.

Do you have more details on which aspect specifically is not supported? My understanding is that one can build quite complex dependencies between discrete variables, given all the HMM examples with enumeration. So I guess that it is the interactions between discrete and continuous latents that are not supported with enumeration.

On second thought, I guess we can work around it as follows:

with numpyro.plate("batch", batch_dim, dim=-2):
    with numpyro.plate('data_y', data_dim, dim=-1):
        x = numpyro.sample("x", dist.Normal(0, 1).expand([length]).to_event(1).mask(False))
x = jnp.moveaxis(x, -1, 0)

def transition_fn(carry, xy):
    x, y = xy
    ...
    x = numpyro.sample("x", dist.Normal(loc, scale), obs=x)
    y = numpyro.sample("y", dist.Normal(loc, scale), obs=y)
    ...

I guess we can modify the scan implementation to support your original model. Could you open a GitHub issue to track it?

Thanks for the suggestion @fehiepsi. I am not sure if that would be useful, and I do not want to waste everyone's time. What I am actually trying to implement is an rSLDS defined as

\prod_t p(y_t|x_t) p(x_t|z_t, x_{t-1}) p(z_t|x_{t-1})

where y are observations, x are continuous latent states, and z are categorical latents. So I was simply trying to understand why I am getting that error inside the scan context when I enumerate over z.

I have also tried using a mixture distribution inside scan (for p(x_t|z_t, x_{t-1}) p(z_t|x_{t-1})), but that does not seem to be supported either, as I end up with a different error.

Do you think it is worthwhile to open a request on GitHub for these features? I could also try fixing this myself, but I would need some guidance on where to start.

I think the issue is that we need to promote the shapes of the continuous latent variables properly to match the enumerated dimension (which cascades between -3 and -4 during scan). Currently, that logic is handled for obs but not for the condition/substitute handlers (see numpyro/numpyro/contrib/control_flow/scan.py at f478772b6abee06b7bfd38c11b3f832e01e089f5 in pyro-ppl/numpyro on GitHub).
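
To make the shape clash concrete, here is a minimal pure-Python sketch of the NumPy/JAX broadcasting rule applied to the shapes from the error message earlier in the thread. The helper `broadcast_shape` is hypothetical (written only for this illustration, not part of NumPyro), and the promoted shape (16, 1, 1, 5) is just an assumed example of moving the enumerated dimension one position further left. It does not reproduce what scan does internally.

```python
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Broadcast `shapes` under NumPy/JAX rules, or raise ValueError."""
    result = []
    # Align shapes on the right; missing leading dims count as size 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"Incompatible shapes for broadcasting: {shapes}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# Shapes from the error message: the enumeration dim (16) and the
# time dim (200) both sit at position -3, so they clash.
try:
    broadcast_shape((16, 4, 5), (200, 4, 5))
    clash = False
except ValueError:
    clash = True

# With the enumeration dim one step further left, broadcasting succeeds.
promoted = broadcast_shape((16, 1, 1, 5), (200, 4, 5))
print(clash, promoted)  # True (16, 200, 4, 5)
```

The point is only that the enumerated dimension has to sit to the left of every batch dimension of the value it is combined with, including the time dimension of length 200.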

Because we just need to compute the log density (not sample from the model) for HMC, I guess the trick in my last comment will work.

Perfect, thanks for the pointers. I will look into this and then open a GitHub issue if I get stuck. Cheers!