Hi,
I am trying to apply an HMM model using numpyro where the outcome variables essentially follow an AR(p) process, however, the coefficients are dependent on a discrete hidden state, and some data values are missing. My current approach is very naïve (with boundary issues between transition points), though it provides an opportunity for me to understand how enumeration and broadcasting interact with a transition function through scan. I provide the model and some fake data below.
The model appears to run okay without the part that deals with the missing data, however as soon as I run it with the missing data part, I get a broadcast error.
ValueError: Incompatible shapes for broadcasting: ((3,), (400,))
At this point I am very confused how the enumerated samples are handled when using them as parameters for other sites and how the shapes of the sites are influenced by scan. Any advice on why this is happening, what I am doing wrong here, or any suggestions would be much appreciated.
Thanks in advance.
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.ops.indexing import Vindex
def model(Y_init=None,P=2,T=None, Y=None):
# of the three possible states, we assume each is equally likely.
# we apply an hmm approach, therefore assume that there is a probability
# of moving between the states and staying in a state
# the outcome is subject to the state we are in and the P previously observed
# outcome values.
hidden_dim = 3
trend_state_prob = numpyro.sample("trend_state_prob",
dist.Dirichlet(0.9*jnp.eye(hidden_dim) + 0.1).to_event(1))
sigma = numpyro.sample("sigma", dist.HalfCauchy(0.1))
# we have a coefficient for each previous time value we use, so an AR(p)
b = numpyro.sample("b", dist.TruncatedNormal(0.,0.6, low=-1, high=1).expand([hidden_dim, P]).to_event(2))
def transition(carry, y):
trend_previous, y_previous = carry
trend_state = numpyro.sample("trend_state",
dist.Categorical(trend_state_prob[trend_previous]),
infer={"enumeration":"parallel"})
coeffs = b[trend_state,...]
# start simple and ignore the constant term
mu = coeffs @ y_previous
observed = y
# commenting the following 5 lines makes the model run without issue
if Y is not None:
missing = numpyro.sample("missing", dist.Normal(mu, sigma)) #<-- broadcast error here
observed = jnp.where(jnp.isnan(y), missing, y)
else:
observed = None
y_next = numpyro.sample("y", dist.Normal(mu, sigma),obs=observed)
y_previous = jnp.concatenate([y_next[None],y_previous[:-1]])
return (trend_state, y_previous), None
x_init = 0 # jnp.zeros((1), dtype=jnp.int32)
_, _ = scan(transition, (x_init,Y_init), Y, length=T)
# generate fake data, assuming 3 latent factors
coef1 = jnp.asarray([0.5, 0.2])
coef2 = jnp.asarray([-0.75, 0.25])
coef3 = jnp.asarray([-0.35, -0.4])
# we assume we start in stage1, then move to stage 2, and finally stage 3
stage1 = np.random.normal(0,1,size=151)
stage1 = jnp.asarray(np.lib.stride_tricks.sliding_window_view(stage1,2)) @ coef1
stage2 = np.random.normal(0,1,size=151)
stage2 = jnp.asarray(np.lib.stride_tricks.sliding_window_view(stage2,2)) @ coef2
stage3 = np.random.normal(0,1,size=151)
stage3 = jnp.asarray(np.lib.stride_tricks.sliding_window_view(stage3,2)) @ coef3
data = jnp.concatenate([stage1, stage2, stage3])
# we assume up to 20% of the data is missing at random
rng_key = jax.random.PRNGKey(42)
n = data.shape[0]
#comment the next 2 lines to generate data without missing values
selection = jax.random.randint(rng_key,(int(n*0.2),), 0, n)
data = data.at[selection].set(jnp.nan)
P = 2
T = 400
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=1)
data_dict = dict(Y_init=data[:P], P=P, T=T, Y=data[P:][:T])
mcmc.run(jax.random.PRNGKey(0), **data_dict)
samples = mcmc.get_samples()