Scan, enumerate and how to handle broadcasting

Hi,

I am trying to fit an HMM in numpyro where the outcome variable essentially follows an AR(p) process whose coefficients depend on a discrete hidden state, and where some data values are missing. My current approach is very naïve (it has boundary issues at the transition points), but it gives me an opportunity to understand how enumeration and broadcasting interact with a transition function inside scan. The model and some fake data are below.

The model appears to run okay without the part that deals with the missing data. As soon as I include that part, however, I get a broadcasting error:

ValueError: Incompatible shapes for broadcasting: ((3,), (400,))

At this point I am very confused about how enumerated samples are handled when they are used as parameters for other sites, and about how scan influences the shapes of the sites. Any advice on why this is happening, what I am doing wrong, or any other suggestions would be much appreciated.

Thanks in advance.

import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.ops.indexing import Vindex


def model(Y_init=None,P=2,T=None, Y=None):


    # of the three possible states, we assume each is equally likely.
    # we apply an hmm approach, therefore assume that there is a probability
    # of moving between the states and staying in a state
    # the outcome is subject to the state we are in and the P previously observed
    # outcome values.

    hidden_dim = 3

    trend_state_prob = numpyro.sample("trend_state_prob",
                                   dist.Dirichlet(0.9*jnp.eye(hidden_dim) + 0.1).to_event(1))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(0.1))



    # we have a coefficient for each previous time value we use, so an AR(p)
    b = numpyro.sample("b", dist.TruncatedNormal(0.,0.6, low=-1, high=1).expand([hidden_dim, P]).to_event(2))

    def transition(carry, y):
        trend_previous, y_previous = carry

        trend_state = numpyro.sample("trend_state",
                                    dist.Categorical(trend_state_prob[trend_previous]),
                                        infer={"enumerate": "parallel"})

        coeffs = b[trend_state,...]
        # start simple and ignore the constant term
        mu = coeffs @ y_previous

        observed = y
        # commenting the following 5 lines makes the model run without issue
        if Y is not None:
            missing = numpyro.sample("missing", dist.Normal(mu, sigma)) #<-- broadcast error here
            observed = jnp.where(jnp.isnan(y), missing, y)
        else:
            observed = None
        y_next = numpyro.sample("y", dist.Normal(mu, sigma),obs=observed)
        y_previous = jnp.concatenate([y_next[None],y_previous[:-1]])

        return (trend_state, y_previous), None

    x_init = 0 # jnp.zeros((1), dtype=jnp.int32)
    _, _ = scan(transition, (x_init,Y_init), Y, length=T)


# generate fake data, assuming 3 latent factors

coef1 = jnp.asarray([0.5, 0.2])
coef2 = jnp.asarray([-0.75, 0.25])
coef3 = jnp.asarray([-0.35, -0.4])

# we assume we start in stage1, then move to stage 2, and finally stage 3

stage1 = np.random.normal(0,1,size=151)
stage1 = jnp.asarray(np.lib.stride_tricks.sliding_window_view(stage1,2)) @ coef1

stage2 = np.random.normal(0,1,size=151)
stage2 = jnp.asarray(np.lib.stride_tricks.sliding_window_view(stage2,2)) @ coef2

stage3 = np.random.normal(0,1,size=151)
stage3 = jnp.asarray(np.lib.stride_tricks.sliding_window_view(stage3,2)) @ coef3

data = jnp.concatenate([stage1, stage2, stage3])

# we assume up to 20% of the data is missing at random
rng_key = jax.random.PRNGKey(42)
n = data.shape[0]
#comment the next 2 lines to generate data without missing values
selection = jax.random.randint(rng_key,(int(n*0.2),), 0, n)
data = data.at[selection].set(jnp.nan)

P = 2
T = 400

kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=1)
data_dict = dict(Y_init=data[:P], P=P, T=T, Y=data[P:][:T])
mcmc.run(jax.random.PRNGKey(0), **data_dict)
samples = mcmc.get_samples()


Hi @zeneofa, it would be easier to debug if you add some print/assert statements for the shapes in the model. Sometimes, things like b[trend_state, ...] do not behave as expected if trend_state has batch dimensions (to enumerate over the support of dist.Categorical(trend_state_prob[trend_previous])).

Your y_previous might have additional batch dimensions because it depends on y_next, while y_next depends on mu, and mu depends on coeffs and trend_state, which has an enumerated dimension. So be careful with operators like coeffs @ y_previous - you might want to do something like (coeffs @ y_previous[..., None]).squeeze(-1), which preserves batch semantics regardless of the batch dimensions of y_previous.

The concatenate operator at y_previous = jnp.concatenate([y_next[None],y_previous[:-1]]) might also violate batch semantics because it concatenates arrays along the first dimension, so things will go wrong if y_next and y_previous have batch dimensions.

Those are the things I observed. Things will be clearer if you add print/assert statements for the shapes.
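
To make the shape concerns above concrete, here is a minimal sketch that checks shapes with plain NumPy instead of numpyro. Several things in it are assumptions for illustration only: the enumerated trend_state is stood in for by an array of shape (hidden_dim,), whereas the real enumerated dimension that numpyro allocates may sit elsewhere; the helper names ar_mean and shift_lags are hypothetical; and the AR mean is written as an elementwise product summed over the last axis, which is one alternative to the matmul form rather than the only option. It has not been tested inside scan with enumeration, so treat it as a way to reason about shapes, not as a fix for the model.

```python
import numpy as np

hidden_dim, P = 3, 2

# Stand-in for the sampled site "b": one row of AR coefficients per hidden state.
b = np.arange(hidden_dim * P, dtype=float).reshape(hidden_dim, P)

# Assumption: a parallel-enumerated discrete site carries every state value
# along a batch dimension, simplified here to shape (hidden_dim,).
trend_state = np.arange(hidden_dim)
coeffs = b[trend_state, ...]
print("coeffs:", coeffs.shape)  # (3, 2)


def ar_mean(coeffs, y_previous):
    """AR mean: multiply elementwise, then sum over the last (lag) axis.

    Leading batch dimensions of either argument broadcast against each other.
    """
    return (coeffs * y_previous).sum(-1)


def shift_lags(y_next, y_previous):
    """Prepend y_next along the last (lag) axis and drop the oldest lag.

    Both inputs are first broadcast to a common batch shape, so a batched
    y_next can be combined with an unbatched y_previous.
    """
    batch = np.broadcast_shapes(np.shape(y_next), y_previous.shape[:-1])
    y_next = np.broadcast_to(y_next, batch)
    y_previous = np.broadcast_to(y_previous, batch + y_previous.shape[-1:])
    return np.concatenate([y_next[..., None], y_previous[..., :-1]], axis=-1)


# First step: y_previous has no batch dimension yet.
y_previous = np.array([0.5, -1.0])
mu = ar_mean(coeffs, y_previous)
print("mu:", mu.shape)  # (3,)

# Pretend y_next inherited the batch dimension of mu (as a sampled value would).
y_previous_next = shift_lags(mu, y_previous)
print("y_previous after one step:", y_previous_next.shape)  # (3, 2)

# Second step: both operands are batched and the mean keeps shape (3,).
mu_next = ar_mean(coeffs, y_previous_next)
print("mu at next step:", mu_next.shape)  # (3,)
```

Printing shapes like this at each line of the transition function (or asserting them) should show where a (3,) batch dimension first meets an array it cannot broadcast with.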