Modelling multivariate switching dynamics - advice on getting started?

Hi,

I am pretty new to Pyro and probabilistic modelling, but I have lots of experience with PyTorch and numerical modelling/simulations, so I'm really after some advice on how best to get started with this problem and on which of the many Pyro models are best suited to it.

I have a lot of time series data showing the oscillatory dynamics of some shape-space eigenvalues of an animal during locomotion. Typically we see long periods (~30-60 secs) of straight-line movement dominated by eigenmodes 1 and 2 followed by short reorientation periods (~2-10 seconds) where the smaller eigenmodes dominate.

I want to build a model of these dynamics, both to generate sample trajectories and to determine how many distinct behavioural states are observed.

I had thought to use a switching Kalman filter to capture both the dynamics of the different behaviours and the periodic switching between them, but I can't see any reference to switching Kalman filters in the documentation. Is this described under a different name, or does it simply arise from combining a Kalman filter with some Markov model?

Generally speaking, am I in the right place/is pyro the right tool for this job?

Many thanks in advance.

I'd recommend starting by looking at some examples, in particular this, this, and this, to see if you can fork and modify them for your particular application.

Unfortunately, discrete latent variable models can be tricky and are probably not the gentlest entry point to the Pyro framework.

Hi @tom0, sure, Pyro might be a good tool for your job.

My understanding is that switching dynamical systems usually rely on moment-matching approximations, whereas Pyro mostly relies on variational approximations. So let me propose a way you might model one of these switching systems in Pyro: what if you model the discrete states as a (learnable) HMM and model the continuous state variationally (with learnable measurement noise and learnable mode-dependent dynamics)? If you do that, I think you could use Pyro's exact inference for HMMs and use backprop (pyro.infer.SVI) to learn the transition and measurement parameters. Here's a simple univariate nearly-constant-velocity model.

import torch

import pyro
import pyro.distributions as dist


def model(data, num_states=2):
    assert data.dim() == 1
    T = len(data)
    S = num_states

    # First define parameters of a discrete hidden Markov model with
    # continuous observations. These observations will be the hidden
    # continuous state (here velocity) of your model.
    init_logits = pyro.sample("init_logits",
                              dist.Normal(torch.zeros(S), 1).to_event(1))
    trans_logits = pyro.sample("trans_logits",
                               dist.Normal(torch.zeros(S, S), 1).to_event(2))
    state_loc = pyro.sample("state_loc",
                            dist.Normal(torch.zeros(S), 5).to_event(1))
    state_scale = pyro.sample("state_scale",
                              dist.LogNormal(torch.zeros(S), 5).to_event(1))
    state_dist = dist.Normal(state_loc, state_scale)  # one emission per mode
    hmm = dist.DiscreteHMM(init_logits, trans_logits, state_dist)
    state = pyro.sample("state", hmm)

    # Next we'll add a model section relating that hidden state to the observed state.
    position = state.cumsum(0)  # Apply dynamics. This could be more general.
    obs_noise_scale = pyro.sample("obs_noise_scale", dist.LogNormal(0, 5))
    with pyro.plate("time", T):
        pyro.sample("position", dist.Normal(position, obs_noise_scale), obs=data)

    # Finally return stuff for prediction.
    return hmm, state
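One small practical note (my addition, not part of the answer above): the model asserts data.dim() == 1, so to start you would fit a single eigenmode's time series as a 1-D float tensor, something like the following sketch (the file name is purely hypothetical):

import numpy as np

# Hypothetical preprocessing: load one eigenmode's time series and standardise it.
eigenmode_1 = np.load("eigenmode_1.npy")  # hypothetical file name
data = torch.as_tensor(eigenmode_1, dtype=torch.float32)
data = (data - data.mean()) / data.std()  # zero mean, unit variance for easier fitting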

Then you can fit the parameters using SVI:

from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal, init_to_sample
from pyro.optim import ClippedAdam

guide = AutoNormal(model, init_loc_fn=init_to_sample)
svi = SVI(model, guide, ClippedAdam({"lr": 0.01}), Trace_ELBO())
for step in range(2000):
    svi.step(data)
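If you want to keep an eye on convergence, you could replace the loop above with something like this sketch (the print interval is arbitrary):

# Track the ELBO loss to check that training has converged.
losses = []
for step in range(2000):
    loss = svi.step(data)
    losses.append(loss)
    if step % 200 == 0:
        print(f"step {step}  loss = {loss:.2f}")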

To infer the hidden discrete state you can condition the model on the fitted parameters and use the returned hmm (note that DiscreteHMM.filter returns the posterior over the final hidden state):

from pyro import poutine

median = guide.median()
hmm, state = poutine.condition(model, median)(data)
hidden_dist = hmm.filter(state)  # posterior over the final discrete state
print(hidden_dist.probs)
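Since one of your goals is to work out how many behavioural states the animal actually uses, it can also help to look at the fitted parameters themselves. Here is a sketch (my suggestion, using the guide medians from above): the softmax of the transition logits gives row-normalised switching probabilities, and the per-state emission parameters show how distinct the modes are.

# Inspect the learned switching behaviour: row i, column j is the probability
# of moving from state i to state j; each row sums to 1.
trans_probs = median["trans_logits"].softmax(-1)
print(trans_probs)
print(median["state_loc"], median["state_scale"])  # per-state emission parameters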

To generate synthetic data you'll need to write a different but equivalent model that manually samples from that hmm distribution. You can follow the examples/hmm.py tutorial @martinjankowiak pointed to. Basically, you can write another model that is sequential:

def model_2(T):
    ...
    x = pyro.sample("init", dist.Categorical(logits=init_logits))
    ys = []
    for t in range(T):
        x = pyro.sample(f"x_{t}", dist.Categorical(logits=trans_logits[x]))
        y = pyro.sample(f"y_{t}", dist.Normal(state_loc[x], state_scale[x]))
        ys.append(y)
    state = torch.stack(ys)
    ...
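For example, here is a sketch of how you might generate one synthetic trajectory with the fitted parameters. It assumes (my assumption, since that part is elided above) that model_2 samples the same parameter sites as model (init_logits, trans_logits, state_loc, state_scale) inside its "..." section, so that conditioning on the guide medians fills them in:

# Hypothetical usage: condition model_2 on the learned parameters, run it once,
# and read the sampled y_t values out of the trace.
with torch.no_grad():
    conditioned = poutine.condition(model_2, data=guide.median())
    trace = poutine.trace(conditioned).get_trace(T=500)
    ys = torch.stack([trace.nodes[f"y_{t}"]["value"] for t in range(500)])
    synthetic_position = ys.cumsum(0)  # same integration step as in the first model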

@fritzo @martinjankowiak Thank you both for your replies. I will need some time to digest all of the information, but it's great to know there is a helpful community here.