Plating for a Mixture Model of Time Series

Like the title says, I am working on a mixture model of time series. I am starting with synthetic data as a proof of concept. Concretely, each data point is a sequence of n observations at times t_1, t_2, ..., t_n (which may vary by data point), where the observation at t_i is normally distributed with mean t_i * slope[k] and scale 1, and k is the cluster ID for that data point.

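In notation (just restating the above, writing z_j for the cluster ID of series j):

    z_j ~ Categorical(pi),    y_{j,i} | z_j ~ Normal(t_{j,i} * slope[z_j], 1)
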
I have been trying to adapt the code in the GMM tutorial for this use case, but am running into a (user) error when trying to sample in the deepest plate context. I have stripped it down as much as possible to create an MWE. Here is the code for generating the synthetic data.

    import torch
    import numpy as np

    import pyro
    import pyro.distributions as dist
    from pyro import poutine
    from pyro.infer.autoguide import AutoDiagonalNormal
    from pyro.optim import Adam
    from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, infer_discrete

    pi = .7
    slopes = [10., 2.]
    scale = 1.

    num_measurements = 50
    synth_data_and_times = []
    synth_lengths = []

    for _ in range(1000):
        # Draw the cluster assignment: cluster 0 with probability pi, else cluster 1.
        choice = np.random.uniform()
        if choice > pi:
            choice = 1
        else:
            choice = 0

        # Every synthetic series uses the same measurement times 0, 1, ..., num_measurements-1.
        times = np.arange(num_measurements)
        synth_data_and_times.append([
            [np.random.normal(loc=times[i] * slopes[choice], scale=scale)
             for i in range(num_measurements)],
            times,
        ])

    # Shape (1000, 2, 50): for each series, index 0 holds the observations and index 1 the times.
    synth_data_and_times = torch.tensor(synth_data_and_times)

And here is my model.

    data = synth_data_and_times[:, 0, :]
    times = synth_data_and_times[:, 1, :]
    K = 2

    @config_enumerate
    def model(data, times):
        # Mixture weights over the K components.
        weights = pyro.sample(
            'weights', dist.Dirichlet(0.5 * torch.ones(K)))

        # Shared observation noise and per-component slopes.
        scale = pyro.sample('scale', dist.LogNormal(0., 2.))
        with pyro.plate('components', K):
            locs = pyro.sample('locs', dist.Normal(5., 2.))

        with pyro.plate('data', len(data)):
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            means = times * locs[assignment].unsqueeze(-1)
            with pyro.plate('times', len(times[0])):
                res = pyro.sample("obs", dist.Normal(means, scale), obs=data)
                
    guide = AutoDiagonalNormal(poutine.block(model, expose=['weights', 'locs', 'scale']))
    optim = pyro.optim.Adam({'lr': 0.1, 'betas': [0.8, 0.99]})
    elbo = TraceEnum_ELBO()

    svi = SVI(model, guide, optim, loss=elbo)
    svi.loss(model, guide, data, times)

The exact error message is "Shape mismatch inside plate('times') at site obs dim -2, 50 vs 1000". It is triggered by the final sampling statement in model, where I pass means as the mean of Pyro's univariate Normal distribution even though means is a matrix. I think I have seen this pattern before, both in Pyro and in Stan - presumably it is just vectorized shorthand for normally distributed observations that are conditionally independent given their different means.

At the time of the final sampling statement, means has shape (1000, 50), data has shape (1000, 50), scale has shape [], and assignment has shape (1000,).
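
As a sanity check on that last point (a standalone snippet, not the model itself), a batched Normal does happily take a matrix of means and treats each entry as an independent observation:

    import torch
    import pyro.distributions as dist

    # A matrix of means gives a batch of independent univariate Normals.
    means_check = torch.randn(1000, 50)
    d = dist.Normal(means_check, 1.)
    print(d.batch_shape)                            # torch.Size([1000, 50])
    print(d.log_prob(torch.zeros(1000, 50)).shape)  # torch.Size([1000, 50])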

Hi @ethansargent, my guess is that your data and times plates are nested in the wrong order. When you don't pass a dim argument to pyro.plate, Pyro assigns each plate a dimension eagerly, starting from the rightmost batch dimension, so your code is equivalent to

    with pyro.plate('data', len(data), dim=-1):
        with pyro.plate('times', len(times[0]), dim=-2):
            ...
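
You can see the allocation order with a tiny standalone example (not your model, just an illustration of how plates grab dims):

    import pyro
    import pyro.distributions as dist

    def toy():
        with pyro.plate('outer', 3):        # entered first  -> gets dim=-1
            with pyro.plate('inner', 4):    # entered second -> gets dim=-2
                x = pyro.sample('x', dist.Normal(0., 1.))
                print(x.shape)              # torch.Size([4, 3])

    toy()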

You might try manually setting each plate's dim arg to force the times plate to sit to the right of the data plate, e.g.

    num_series, num_times = data.shape
    with pyro.plate('data', num_series, dim=-2):
        with pyro.plate('times', num_times, dim=-1):
            ...
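
Once the model runs, you can double-check which dim each plate actually received by printing the trace shapes (a sketch, assuming the model, data, and times defined above):

    from pyro import poutine

    # Run the model once under a trace and print the shape of every sample site.
    trace = poutine.trace(model).get_trace(data, times)
    print(trace.format_shapes())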

This worked - thank you!