Inference fails when using scan over observed variable

While experimenting with scan for state space models, I stumbled upon something odd: inference seems to fail when I scan over the observed variable, but succeeds if I scan over an index variable and use it to index the observed variable.

Please see the example of a linear regression model implemented with scan below. I would like to scan over both xs and ys to avoid indexing them in f, but this does not work. If someone could point out why my approach fails and how to do it properly, I would be very grateful. Thanks!
I’m using numpyro version 0.9.2.

import jax
import jax.numpy as np
import numpyro as npo
from numpyro.infer import MCMC, NUTS
from numpyro import distributions as D
from numpyro.contrib.control_flow import scan
rng_key = jax.random.PRNGKey(1234)


def sample(model, *args):
    mcmc = MCMC(
        NUTS(model),
        num_warmup=1000,
        num_samples=1000,
        num_chains=1,
        progress_bar=True,
    )
    mcmc.run(*args)
    return mcmc
    

def scan_linear_works(ys, num_forecast_steps=0):
    a = npo.sample("a", D.Normal(0, 1))
    b = npo.sample("b", D.Normal(0, 0.1))
    sigma = npo.sample("sigma", D.HalfNormal(0.1))
    xs = np.arange(ys.shape[0])
    
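    def f(carry, x):
        # x is the current element of xs; the observation is looked up
        # with the carried counter i, so ys[i] has shape (1,)
        i = carry
        with npo.plate("time", ys.shape[0], dim=-1):
            mu = npo.deterministic("mu", a + b * x)
            npo.sample("y", D.Normal(mu, sigma), obs=ys[i])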

        return i+1, None
    
    scan(f, np.array([0]), xs)

def scan_linear_broken(ys, num_forecast_steps=0):
    a = npo.sample("a", D.Normal(0, 1))
    b = npo.sample("b", D.Normal(0, 0.1))
    sigma = npo.sample("sigma", D.HalfNormal(0.1))
    xs = np.arange(ys.shape[0])
    
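    def f(state, y):
        # y is the current element of ys, i.e. a scalar with shape ();
        # the covariate is looked up with the carried counter i instead
        i = state
        with npo.plate("time", ys.shape[0], dim=-1):
            mu = npo.deterministic("mu", a + b * xs[i])
            npo.sample("y", D.Normal(mu, sigma), obs=y)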

        return i+1, None
    
    scan(f, np.array([0]), ys)

true_a = -0.75
true_b = 0.5
xx = np.arange(20)
yy = true_a + true_b*xx + jax.random.normal(rng_key, xx.shape)

# recovers true parameters
sample(scan_linear_works, rng_key, yy).print_summary()

# fails to recover true parameters
sample(scan_linear_broken, rng_key, yy).print_summary()

I think this is a bug: the observation y inside the transition function f does not carry the corresponding time dimension. To make sure the scalar observation y is scanned correctly under the time plate, you will need

npo.sample("y", D.Normal(mu, sigma), obs=y[None])

Could you report this issue on GitHub? We can try to do this promotion automatically.
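To make the shape difference concrete, here is a minimal sketch using plain jax.lax.scan (no NumPyro involved) with made-up data. It compares, per step, the scanned element y, the indexed element ys[i] with the carried i = array([0]), and the promoted y[None]. It only illustrates shapes; it does not reproduce NumPyro's scan internals.

```python
import jax
import jax.numpy as jnp

ys = jnp.arange(5.0)  # stand-in observations (hypothetical data)

def step(i, y):
    # y: the slice of ys that scan hands to this step -> shape ()
    # ys[i]: indexing with the carried i = jnp.array([0]) -> shape (1,)
    # y[None]: the suggested promotion of the scalar -> shape (1,)
    return i + 1, (y, ys[i], y[None])

_, (scanned, indexed, promoted) = jax.lax.scan(step, jnp.array([0]), ys)

# scan stacks per-step outputs along a new leading axis of length 5
print(scanned.shape, indexed.shape, promoted.shape)  # (5,) (5, 1) (5, 1)
```

So the working model effectively observed a (1,)-shaped value at each step, while the scanned y was a bare scalar; y[None] restores the singleton dimension.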

By the way, I'm not sure you need the time plate inside the scan. Because mu, sigma and y are scalars (or have singleton dimensions due to using i=np.array([0]) rather than i=np.array(0)), adding a plate there only creates identical copies of the same value. I think you need something like

    def f(state, y):
        i = state
        mu = npo.deterministic("mu", a + b * xs[i])
        npo.sample("y", D.Normal(mu, sigma), obs=y)
        return i+1, None
    
    scan(f, np.array(0), ys)

This should be fixed in https://github.com/pyro-ppl/numpyro/pull/1444. But please check that the time plate inside scan is what you intend: with your code, the log_prob at the y site in both scan_linear_works and scan_linear_broken will have shape (ys.shape[0], ys.shape[0]).

Thank you for your response! I have not been able to pick this project back up in a while, sorry about that. The PR you merged indeed did the trick (inference works for both models in 0.10.0), and you are absolutely correct that the code works without the plate. Since this is a 1D problem it does not belong there. Thanks again!