Enumerating a factorial HMM with scan: errors in MCMC and SVI

Hi everyone! Thanks for the great effort; I am excited to have TraceEnum_ELBO in numpyro now.

I am trying to implement a factorial HMM with two discrete latent chains, leveraging scan.

Below is the model, and here is the notebook with a minimal example showing the full error traces.

# NT, NS, NZ are constants defined elsewhere (see the notebook)
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.contrib.funsor import config_enumerate


@config_enumerate
def model_scan(prior, obs=None, n=NT):
    nt = n if obs is None else len(obs)

    with numpyro.plate("s_space", NS, dim=-1):
        mu = numpyro.sample("mu", dist.Normal(prior["mu_loc"], prior["mu_scale"]))

    with numpyro.plate("z_space", NZ, dim=-1):
        sigma = numpyro.sample("sigma", dist.Exponential(prior["sigma_loc"]))

    trans_s = numpyro.sample(
        "trans_s", dist.Dirichlet(prior["trans_s"] * prior["conc"]).to_event(1)
    )
    trans_z = numpyro.sample(
        "trans_z", dist.Dirichlet(prior["trans_z"] * prior["conc"]).to_event(1)
    )
    s_m1 = numpyro.sample(
        "s_-1",
        dist.Categorical(prior["init_s"]),
    )
    z_m1 = numpyro.sample(
        "z_-1",
        dist.Categorical(prior["init_z"]),
    )
    init_state = (s_m1, z_m1)

    def transition(state, y_obs):
        s_prev, z_prev = state
        y = numpyro.sample("y", dist.Normal(mu[s_prev], sigma[z_prev]), obs=y_obs)
        s = numpyro.sample("s", dist.Categorical(trans_s[s_prev]))
        z = numpyro.sample("z", dist.Categorical(trans_z[z_prev]))
        return (s, z), (s, z, y)

    _, (s, z, y) = scan(transition, init_state, obs, length=nt)
    return (s, z, y)

Unfortunately, this naive attempt seems to trip up funsor and results in a TypeError: reshape total size must be unchanged, got new_sizes (49, 3) for shape (49, 3, 2, 1, 1, 1, 1), regardless of whether I try NUTS or TraceEnum_ELBO.
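For reference, this is roughly how I invoke inference (a sketch of what is in the notebook; prior, y_data, the optimizer settings, and the sample counts here are placeholders):

from jax import random
from numpyro import handlers
from numpyro.infer import MCMC, NUTS, SVI, TraceEnum_ELBO
from numpyro.infer.autoguide import AutoNormal
from numpyro.optim import Adam

# MCMC: with funsor installed, NUTS is supposed to enumerate the discrete chains automatically
mcmc = MCMC(NUTS(model_scan), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), prior, obs=y_data)

# SVI: guide over the continuous sites only; the discrete sites stay in the model for enumeration
guide = AutoNormal(handlers.block(model_scan, hide=["s_-1", "z_-1", "s", "z"]))
svi = SVI(model_scan, guide, Adam(1e-2), TraceEnum_ELBO())
svi_result = svi.run(random.PRNGKey(1), 2000, prior, obs=y_data)

Both calls fail with the same reshape TypeError.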

It looks like I may be missing something simple about indexing with enumerated variables, presumably around y = numpyro.sample("y", dist.Normal(mu[s_prev], sigma[z_prev]), obs=y_obs), but unfortunately I have not been able to get much further by tracing the shapes. I am also not sure whether the model actually violates any of the restrictions on parallel enumeration.
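In case indexing really is the culprit, here is the Vindex form of the transition that I would expect to be equivalent (only a sketch, meant to replace the closure inside model_scan; for 1-D mu and sigma it should reduce to plain fancy indexing, so I mention it mainly to show the indexing pattern I have in mind):

from numpyro.ops.indexing import Vindex

    def transition(state, y_obs):
        s_prev, z_prev = state
        # Vindex broadcasts over the extra enumeration dims carried by s_prev / z_prev;
        # with 1-D mu and sigma this coincides with mu[s_prev] and sigma[z_prev]
        y = numpyro.sample(
            "y", dist.Normal(Vindex(mu)[s_prev], Vindex(sigma)[z_prev]), obs=y_obs
        )
        s = numpyro.sample("s", dist.Categorical(trans_s[s_prev]))
        z = numpyro.sample("z", dist.Categorical(trans_z[z_prev]))
        return (s, z), (s, z, y)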

I would appreciate any hints and tips.

I think I have traced the issue a bit further.
First, I propose modifying an existing test, test_scan_enum_two_latents in test/contrib/test_funsor.py, into:

def test_scan_enum_two_latents_indexing_different_vars():
    num_steps = 11
    data = random.normal(random.PRNGKey(0), (num_steps,))
    probs_x = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    probs_w = jnp.array([[0.7, 0.3], [0.6, 0.4]])
    locs = jnp.array([-1.0, 1.0])
    sigmas = jnp.array([0.1, 1.0])

    def model(data):
        x = w = 0
        for i, y in markov(enumerate(data)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs_x[x]))
            w = numpyro.sample(f"w_{i}", dist.Categorical(probs_w[w]))
            numpyro.sample(f"y_{i}", dist.Normal(locs[w], sigmas[x]), obs=y)

    def fun_model(data):
        def transition_fn(carry, y):
            x, w = carry
            x = numpyro.sample("x", dist.Categorical(probs_x[x]))
            w = numpyro.sample("w", dist.Categorical(probs_w[w]))
            numpyro.sample("y", dist.Normal(locs[w], sigmas[x]), obs=y)
            # also test if scan's `ys` are recorded correctly
            return (x, w), x

        scan(transition_fn, (0, 0), data)

    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data,), {}, {})[
        0
    ]
    expected_log_joint = log_density(enum(config_enumerate(model)), (data,), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)

The only change is replacing numpyro.sample("y", dist.Normal(locs[w, x], 1), obs=y) with numpyro.sample("y", dist.Normal(locs[w], sigmas[x]), obs=y).
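To reproduce, running just this test from the numpyro repo root should be enough (assuming a dev install with funsor):

pytest -x test/contrib/test_funsor.py -k test_scan_enum_two_latents_indexing_different_vars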

This leads to the same issue I stumbled upon earlier. To the best of my understanding, funsor loses a dimension in dim_to_name while applying numpyro.contrib.funsor.enum_messenger.trace inside scan (numpyro/contrib/control_flow/scan.py:L464):

  /data/soft/numpyro/test/contrib/test_funsor.py(466)test_scan_enum_two_latents_indexing_different_vars()
    465 
--> 466     actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data,), {}, {})[
    467         0

  /data/soft/numpyro/numpyro/contrib/funsor/infer_util.py(318)log_density()
    317     """
--> 318     result, model_trace, _ = _enum_log_density(
    319         model,

  /data/soft/numpyro/numpyro/contrib/funsor/infer_util.py(201)_enum_log_density()
    200     with plate_to_enum_plate():
--> 201         model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
    202     log_factors = []

If I pepper trace.postprocess_message() in numpyro/contrib/funsor/enum_messenger.py with a few print statements, like this:

    def postprocess_message(self, msg):
        if msg["type"] == "sample":
            total_batch_shape = lax.broadcast_shapes(
                tuple(msg["fn"].batch_shape),
                jnp.shape(msg["value"])[: jnp.ndim(msg["value"]) - msg["fn"].event_dim],
            )
            if dim_to_name := msg["infer"].get("dim_to_name"):
                print(f"\n{msg['name']=}\t{msg['value'].shape=}")
                dim_to_name_sorted = OrderedDict(
                    [(k, dim_to_name[k]) for k in sorted(dim_to_name)]
                )
                print(f"Started with {dim_to_name_sorted}")

            msg["infer"]["dim_to_name"] = NamedMessenger._get_dim_to_name(
                total_batch_shape
            )

            if dim_to_name:
                dim_to_name_sorted = OrderedDict(
                    [(k, dim_to_name[k]) for k in sorted(msg["infer"]["dim_to_name"])]
                )
                print(f"Ended up with {dim_to_name_sorted}")
                if len(dim_to_name) > len(msg["infer"]["dim_to_name"]):
                    __import__("ipdb").set_trace()

            msg["infer"]["name_to_dim"] = {
                name: dim for dim, name in msg["infer"]["dim_to_name"].items()
            }
        if msg["type"] in ("sample", "param"):
            super().postprocess_message(msg)

this is what I get:

msg['name']='x'	msg['value'].shape=(10, 2, 1, 1)
Started with OrderedDict([(-4, '_time_x'), (-3, 'x'), (-1, '_PREV_x')])
Ended up with OrderedDict([(-4, '_time_x'), (-3, 'x'), (-1, '_PREV_x')])

msg['name']='w'	msg['value'].shape=(10, 2, 1, 1, 1)
Started with OrderedDict([(-5, '_time_x'), (-4, 'w'), (-2, '_PREV_w')])
Ended up with OrderedDict([(-5, '_time_x'), (-4, 'w'), (-2, '_PREV_w')])

msg['name']='y'	msg['value'].shape=(10, 1, 1, 1, 1)
Started with OrderedDict([(-5, '_time_x'), (-4, 'w'), (-3, 'x')])
Ended up with OrderedDict([(-5, '_time_x'), (-4, 'w')])

While I don’t know whether this is a bug or a feature, changing

msg["infer"]["dim_to_name"] = NamedMessenger._get_dim_to_name(
    total_batch_shape
)

to

msg["infer"]["dim_to_name"] = NamedMessenger._get_dim_to_name(
    total_batch_shape, dim_to_name=msg["infer"].get("dim_to_name")
)

seems to fix the issue and results in well-functioning MCMC inference.
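For what it is worth, this is roughly how I checked the patch locally (a sketch; prior and y_data are the same placeholders as above):

from jax import random
from numpyro.infer import MCMC, NUTS

# the modified test passes, i.e. the scanned and unrolled log-densities agree
test_scan_enum_two_latents_indexing_different_vars()

# and MCMC on the factorial HMM runs without the reshape TypeError
mcmc = MCMC(NUTS(model_scan), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), prior, obs=y_data)
mcmc.print_summary()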

@fehiepsi, since it looks like you added the original test, would you mind having a look? If it helps / makes sense, I can open an issue / PR.