 # Size of states changing with state space model and scan

I am attempting to model time-series data from a physical system with a number of operating states; in each operating state, the evolution of the latent states follows a different structure. We are taking a Markov switching approach, based loosely on the HMM enumeration tutorial (although we have a continuous latent state, in contrast to the categorical `y` in the tutorial).

We seem to have an issue with dimensions being added to the state in the scan `carry` on each draw/iteration, as we see errors of this type:
`TypeError: reshape total size must be unchanged, got new_sizes (2,) for shape (2, 2).`

As far as we can tell, each iteration increases the number (and size?) of the dimensions of the elements we carry in the scan.
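A minimal pure-JAX sketch of the shape growth we believe is happening (the variable names here are ours, mirroring the toy model below, not anything from NumPyro itself): under parallel enumeration, a sampled discrete `x` picks up an extra batch dimension on the left, and plain integer indexing then propagates that dimension into the carry.

```python
import jax.numpy as jnp

# Scalar carry, as on the first scan step:
latent_prev = jnp.asarray(1.0)
latent_ = jnp.array([latent_prev, 0.5 * latent_prev])  # shape (2,)

# Under parallel enumeration, x is not a scalar: it carries an
# extra batch dimension on the left, e.g. shape (2, 1).
x = jnp.arange(2).reshape(2, 1)

# Integer indexing takes the shape of the index array:
latent = latent_[x]
print(latent.shape)  # (2, 1) -- the carry is no longer a scalar

# On the next step, stacking the candidate updates adds yet
# another leading dimension, so the carry keeps growing:
latent_next = jnp.array([latent, 0.5 * latent])
print(latent_next.shape)  # (2, 2, 1)
```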

Any suggestions for how to handle this error?

## A toy example which generates the error above

We have a system in which the value of y does not change in state 0, and in state 1 the value of y is half of the previous value. Our state switches from 0 to 1 after 40 observations:

```python
# pip install numpyro funsor
import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.diagnostics import summary
from numpyro import distributions as dist
from numpyro.contrib.control_flow import scan, cond
from jax import random
import jax.numpy as jnp
import numpy as np

# create data:
y = np.ones(50)
for i in range(40, len(y)):
    y[i] = y[i - 1] * 0.5
```

Define the model:

```python
def model(dim_s,
          dim_states,
          y,
          meas_err,
          ):

    with numpyro.plate("x_plate", 1):
        probs_x = numpyro.sample(
            "probs_x", dist.Dirichlet(0.9 * jnp.eye(dim_states) + 0.1)
        )

    def transition(state, t):
        x_prev, latent_prev = state

        # draw the state:
        x = numpyro.sample(
            "x",
            dist.Categorical(probs_x[x_prev]),
            infer={"enumerate": "parallel"},
        )
        # squeeze to try to remove unnecessary dimensions?
        x = x.squeeze()

        # calculate evolution of latent state in each state:
        latent_ = jnp.array([latent_prev, 0.5 * latent_prev])

        # based on state x, select update to latent state:
        latent = latent_[x]

        # observation equation:
        y_obs = numpyro.sample("y_obs", dist.Normal(latent, meas_err), obs=y[t])

        return (x, latent), (x, latent)

    # uninformative initial priors:
    y_init = numpyro.sample("y_init", dist.Normal(0, 5))
    x_init = numpyro.sample("x_init", dist.Categorical(jnp.array([0.5, 0.5])),
                            infer={"enumerate": "parallel"})

    _, states = scan(transition, (x_init, y_init), np.arange(dim_s))
    return
```

Define model inputs and run:

```python
prior = {"meas_err": 1,
         "dim_s": len(y),
         "dim_states": 2,
         "y": jnp.asarray(y)}

ns = 1000
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=ns, num_samples=ns)
mcmc.run(random.PRNGKey(0), **prior)
```

A few thoughts:

• You might want to move the `init` logic into the `transition` function (using e.g. `jnp.where(t == 0, ...)`).
• Avoid `x.squeeze()`, which conflicts with the "broadcasting" world that enumeration relies on.
• If index broadcasting is tricky, you might want to use `Vindex` as in the HMM example.
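To illustrate the first two points, here is a pure-JAX sketch of a broadcast-friendly transition body (the function and argument names are our own, and this is untested against the full enumerated model): the init is folded in with `jnp.where(t == 0, ...)`, and the per-regime update is selected by broadcasting rather than integer indexing, so an enumeration dimension on `x` passes through without `squeeze`.

```python
import jax.numpy as jnp

def transition_body(x, latent_prev, t, y_init):
    # fold the init logic into the body instead of a separate branch:
    latent_prev = jnp.where(t == 0, y_init, latent_prev)

    # select the per-regime update by broadcasting instead of
    # integer indexing (latent stays constant in state 0, halves
    # in state 1), so an enumeration dim on x broadcasts through:
    latent = jnp.where(x == 0, latent_prev, 0.5 * latent_prev)
    return latent

# x with a leading enumeration dim of size 2:
x = jnp.arange(2).reshape(2, 1)
latent = transition_body(x, jnp.asarray(1.0), t=1, y_init=jnp.asarray(0.0))
print(latent.shape)  # (2, 1): the enumeration dim broadcasts through cleanly
```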