I think I've gotten a bit further in tracing down the issue.
First, I propose modifying the existing test test_scan_enum_two_latents from test/contrib/test_funsor.py into:
def test_scan_enum_two_latents_indexing_different_vars():
    num_steps = 11
    data = random.normal(random.PRNGKey(0), (num_steps,))
    probs_x = jnp.array([[0.8, 0.2], [0.1, 0.9]])
    probs_w = jnp.array([[0.7, 0.3], [0.6, 0.4]])
    locs = jnp.array([-1.0, 1.0])
    sigmas = jnp.array([0.1, 1.0])

    def model(data):
        x = w = 0
        for i, y in markov(enumerate(data)):
            x = numpyro.sample(f"x_{i}", dist.Categorical(probs_x[x]))
            w = numpyro.sample(f"w_{i}", dist.Categorical(probs_w[w]))
            numpyro.sample(f"y_{i}", dist.Normal(locs[w], sigmas[x]), obs=y)

    def fun_model(data):
        def transition_fn(carry, y):
            x, w = carry
            x = numpyro.sample("x", dist.Categorical(probs_x[x]))
            w = numpyro.sample("w", dist.Categorical(probs_w[w]))
            numpyro.sample("y", dist.Normal(locs[w], sigmas[x]), obs=y)
            # also test that scan's `ys` are recorded correctly
            return (x, w), x

        scan(transition_fn, (0, 0), data)

    actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data,), {}, {})[
        0
    ]
    expected_log_joint = log_density(enum(config_enumerate(model)), (data,), {}, {})[0]
    assert_allclose(actual_log_joint, expected_log_joint)
That is, changing numpyro.sample("y", dist.Normal(locs[w, x], 1), obs=y) to numpyro.sample("y", dist.Normal(locs[w], sigmas[x]), obs=y), so that the observation's parameters are indexed by two different enumerated latents: locs by w and sigmas by x.
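In case it helps to reproduce this outside the test module, here is a sketch of the imports the test body above relies on (they mirror what test/contrib/test_funsor.py already imports; nothing new is needed):

from numpy.testing import assert_allclose

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.contrib.funsor import config_enumerate, enum, log_density, markov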
This leads to the same issue I stumbled upon earlier. To the best of my understanding, funsor loses a dimension in dim_to_name while numpyro.contrib.funsor.enum_messenger.trace is applied inside scan (numpyro/contrib/control_flow/scan.py, around L464). The relevant part of the call stack:
/data/soft/numpyro/test/contrib/test_funsor.py(466)test_scan_enum_two_latents_indexing_different_vars()
465
--> 466 actual_log_joint = log_density(enum(config_enumerate(fun_model)), (data,), {}, {})[
467 0
/data/soft/numpyro/numpyro/contrib/funsor/infer_util.py(318)log_density()
317 """
--> 318 result, model_trace, _ = _enum_log_density(
319 model,
/data/soft/numpyro/numpyro/contrib/funsor/infer_util.py(201)_enum_log_density()
200 with plate_to_enum_plate():
--> 201 model_trace = packed_trace(model).get_trace(*model_args, **model_kwargs)
202 log_factors = []
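To look at this without going through the whole of log_density, one can replay just the packed-trace step from _enum_log_density shown above. This is only a sketch: it assumes fun_model and data from the proposed test are in scope, and that there are no params to substitute (as in the test, where params is {}).

from numpyro.contrib.funsor import (
    config_enumerate,
    enum,
    plate_to_enum_plate,
    trace as packed_trace,
)

# mirrors _enum_log_density: build the packed trace of the enumerated model
with plate_to_enum_plate():
    tr = packed_trace(enum(config_enumerate(fun_model))).get_trace(data)

# the "y" site depends on both enumerated latents, yet its dim_to_name
# (see the debug output below) ends up without an entry for 'x'
print(tr["y"]["infer"]["dim_to_name"])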
If I pepper trace.postprocess_message() in numpyro/contrib/funsor/enum_messenger.py with a few print statements, like so:
def postprocess_message(self, msg):
    if msg["type"] == "sample":
        total_batch_shape = lax.broadcast_shapes(
            tuple(msg["fn"].batch_shape),
            jnp.shape(msg["value"])[: jnp.ndim(msg["value"]) - msg["fn"].event_dim],
        )
        if dim_to_name := msg["infer"].get("dim_to_name"):
            print(f"\n{msg['name']=}\t{msg['value'].shape=}")
            dim_to_name_sorted = OrderedDict(
                [(k, dim_to_name[k]) for k in sorted(dim_to_name)]
            )
            print(f"Started with {dim_to_name_sorted}")
        msg["infer"]["dim_to_name"] = NamedMessenger._get_dim_to_name(
            total_batch_shape
        )
        if dim_to_name:
            dim_to_name_sorted = OrderedDict(
                [
                    (k, msg["infer"]["dim_to_name"][k])
                    for k in sorted(msg["infer"]["dim_to_name"])
                ]
            )
            print(f"Ended up with {dim_to_name_sorted}")
            if len(dim_to_name) > len(msg["infer"]["dim_to_name"]):
                __import__("ipdb").set_trace()
        msg["infer"]["name_to_dim"] = {
            name: dim for dim, name in msg["infer"]["dim_to_name"].items()
        }
    if msg["type"] in ("sample", "param"):
        super().postprocess_message(msg)
This is what I get:
msg['name']='x' msg['value'].shape=(10, 2, 1, 1)
Started with OrderedDict([(-4, '_time_x'), (-3, 'x'), (-1, '_PREV_x')])
Ended up with OrderedDict([(-4, '_time_x'), (-3, 'x'), (-1, '_PREV_x')])
msg['name']='w' msg['value'].shape=(10, 2, 1, 1, 1)
Started with OrderedDict([(-5, '_time_x'), (-4, 'w'), (-2, '_PREV_w')])
Ended up with OrderedDict([(-5, '_time_x'), (-4, 'w'), (-2, '_PREV_w')])
msg['name']='y' msg['value'].shape=(10, 1, 1, 1, 1)
Started with OrderedDict([(-5, '_time_x'), (-4, 'w'), (-3, 'x')])
Ended up with OrderedDict([(-5, '_time_x'), (-4, 'w')])
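So on the y site, the entry (-3, 'x') is dropped from dim_to_name, while '_time_x' and 'w' survive. The same observation can be made programmatically against the packed trace (tr as obtained in the sketch above; this check is my own addition, not part of the original debugging session):

y_names = set(tr["y"]["infer"]["dim_to_name"].values())
assert "_time_x" in y_names and "w" in y_names
assert "x" not in y_names  # the name for x's enumeration dim has been lost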
While I don’t know whether this is a bug or a feature, a small change from

msg["infer"]["dim_to_name"] = NamedMessenger._get_dim_to_name(
    total_batch_shape
)

to

msg["infer"]["dim_to_name"] = NamedMessenger._get_dim_to_name(
    total_batch_shape, dim_to_name=msg["infer"].get("dim_to_name")
)

seems to fix the issue and results in well-functioning MCMC inference.
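For clarity, this is what trace.postprocess_message() would look like with that single change applied and the debug prints removed (a sketch assembled from the code pasted above; the only modification is the extra dim_to_name keyword argument):

def postprocess_message(self, msg):
    if msg["type"] == "sample":
        total_batch_shape = lax.broadcast_shapes(
            tuple(msg["fn"].batch_shape),
            jnp.shape(msg["value"])[: jnp.ndim(msg["value"]) - msg["fn"].event_dim],
        )
        # proposed change: also hand over the existing mapping, so a name that
        # was already attached to this site (e.g. 'x' on the y site above) is
        # not dropped when the mapping is rebuilt from total_batch_shape
        msg["infer"]["dim_to_name"] = NamedMessenger._get_dim_to_name(
            total_batch_shape, dim_to_name=msg["infer"].get("dim_to_name")
        )
        msg["infer"]["name_to_dim"] = {
            name: dim for dim, name in msg["infer"]["dim_to_name"].items()
        }
    if msg["type"] in ("sample", "param"):
        super().postprocess_message(msg)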
@fehiepsi, since it looks like you added the original test, would you mind having a look? If it helps / makes sense, I can open an issue / PR.