Hi @tom0, sure, Pyro might be a good tool for this job.
My understanding is that inference in switching dynamical systems usually relies on moment-matching approximations, whereas Pyro mostly relies on variational approximations. So let me propose a way you might model one of these switching systems in Pyro: what if you model the discrete states as a (learnable) HMM and model the continuous state variationally (with learnable measurement noise and learnable mode-dependent dynamics)? If you do that, I think you could use Pyro’s exact inference for HMMs and use backprop (pyro.infer.SVI) to learn the transition and measurement parameters. Here’s a simple univariate nearly-constant-velocity model.
import torch

import pyro
import pyro.distributions as dist


def model(data, num_states=2):
    assert data.dim() == 1
    T = len(data)
    S = num_states
    # First define parameters of a discrete hidden Markov model with
    # continuous observations. These observations will be the hidden
    # continuous state (here velocity) of your model.
    init_logits = pyro.sample("init_logits", dist.Normal(torch.zeros(S), 1).to_event(1))
    trans_logits = pyro.sample("trans_logits", dist.Normal(torch.zeros(S, S), 1).to_event(2))
    state_loc = pyro.sample("state_loc", dist.Normal(torch.zeros(S), 5).to_event(1))
    state_scale = pyro.sample("state_scale", dist.LogNormal(torch.zeros(S), 5).to_event(1))
    state_dist = dist.Normal(state_loc, state_scale)
    hmm = dist.DiscreteHMM(init_logits, trans_logits, state_dist, duration=T)
    state = pyro.sample("state", hmm)
    # Next we'll add an observation model relating that hidden state to observed positions.
    position = state.cumsum(0)  # Apply dynamics. This could be more general.
    obs_noise_scale = pyro.sample("obs_noise_scale", dist.LogNormal(0, 5))
    with pyro.plate("time", T):
        pyro.sample("position", dist.Normal(position, obs_noise_scale), obs=data)
    # Finally return stuff for prediction.
    return hmm, state
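As a quick sanity check before training, you can trace the model on fake data and print the shapes of all sample sites (the torch.randn(100) series here is just a stand-in for your data):

trace = pyro.poutine.trace(model).get_trace(torch.randn(100))
print(trace.format_shapes())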
Then you can fit parameters using SVI:
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal, init_to_sample
from pyro.optim import ClippedAdam

guide = AutoNormal(model, init_loc_fn=init_to_sample)
svi = SVI(model, guide, ClippedAdam({"lr": 0.01}), Trace_ELBO())
for step in range(2000):
    svi.step(data)
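By the way, svi.step returns an estimate of the loss (the negative ELBO), so if you want to eyeball convergence you can keep those values around, e.g.

losses = []
for step in range(2000):
    losses.append(svi.step(data))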
To predict the hidden discrete state you can use the values returned by the model (note that DiscreteHMM.filter computes the posterior over the final state given the whole observation sequence):
from pyro import poutine

median = guide.median()
hmm, state = poutine.condition(model, median)(data)
hidden_dist = hmm.filter(state)  # posterior over the final discrete state
print(hidden_dist.probs)
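Here hidden_dist should be an ordinary Categorical over the final mode, so if you just want the most likely one you can do something like:

mode = hidden_dist.probs.argmax(-1)
print(f"most likely final mode: {int(mode)}")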
To generate synthetic data you’ll need to write a different but equivalent model that manually samples from that hmm distribution. You can follow the examples/hmm.py tutorial @martinjankowiak pointed to. Basically you can write another model that is sequential:
def model_2(T):
    ...  # sample init_logits, trans_logits, state_loc, state_scale as in model() above
    x = pyro.sample("init", dist.Categorical(logits=init_logits))
    ys = []
    for t in range(T):
        x = pyro.sample(f"x_{t}", dist.Categorical(logits=trans_logits[x]))
        y = pyro.sample(f"y_{t}", dist.Normal(state_loc[x], state_scale[x]))
        ys.append(y)
    state = torch.stack(ys)
    ...
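If it helps, here’s one way you might flesh that sketch out into a standalone generator. All of the parameter values here (the sticky transition logits, the per-mode velocity means, the noise scales) are made up for illustration:

import torch

import pyro
import pyro.distributions as dist


def generate_data(T=100, seed=0):
    # All parameter values below are made up for illustration.
    pyro.set_rng_seed(seed)
    init_logits = torch.zeros(2)
    trans_logits = torch.tensor([[2.0, -2.0], [-2.0, 2.0]])  # sticky modes
    state_loc = torch.tensor([-1.0, 1.0])   # per-mode velocity mean
    state_scale = torch.tensor([0.1, 0.1])  # per-mode velocity noise
    obs_noise_scale = 0.5

    x = pyro.sample("init", dist.Categorical(logits=init_logits))
    ys = []
    for t in range(T):
        x = pyro.sample(f"x_{t}", dist.Categorical(logits=trans_logits[x]))
        ys.append(pyro.sample(f"y_{t}", dist.Normal(state_loc[x], state_scale[x])))
    velocity = torch.stack(ys)
    position = velocity.cumsum(0)
    return pyro.sample("obs", dist.Normal(position, obs_noise_scale).to_event(1))


data = generate_data()

You can then feed that data straight into the SVI snippet above.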