Time Series model dimensions over 32

I am trying to build a time series model with some initial parameters and a latent categorical variable that depends on the outcome of the previous sample. I quickly run into a dimension error when I try to perform inference with MCMC on the model's latent initial parameters. The error I am getting is:

ValueError: number of dimensions must be within [0, 32]

I was wondering if anyone had a suggestion for a better way to build the model/guide to avoid this dimensionality issue. My model is currently set up as follows:

def model(data):
    """Time-series model from the original post.

    For each trace ``i`` and step ``t`` it draws a latent categorical
    ``sample_{t}_{i}`` whose (unnormalized) probabilities are built from the
    current ``s1_start``/``s2_start``/``s3_start`` values and ``r1``. It then
    observes ``data[i, t, 0..2]`` as Normals centred on those values plus a
    one-hot increment at index ``sample``. The observed values become the
    starting values for the next step.

    Args:
        data: indexed as ``data[i, t, k]`` with ``k`` in 0..2; presumably of
            shape (num_traces, num_steps, 3) -- TODO confirm.

    NOTE(review): issues raised in the reply further down this thread:
      * Both loops are plain ``range`` loops. The reply suggests wrapping the
        time loop in ``pyro.markov``, so that enumerating the discrete
        ``sample_*`` sites does not need a new tensor dimension per site
        (presumably what triggers the 32-dimension ValueError).
      * ``update`` is a NumPy array indexed by a sampled value; the reply
        suggests using PyTorch tensors instead.
    """
    
    ## Latent model parameters. Interested in inferring rate
    ## Rate parameter
    # Learnable positive scalar; in this model it is only used as the mean of
    # the prior on r1.
    rate = pyro.param("rate", torch.tensor(.5),
                     constraint=pyro.distributions.constraints.positive)

    r1 = pyro.sample("r1", pyro.distributions.Normal(rate, .01))

    ## starting points
    s1_start = pyro.sample("s1_start", pyro.distributions.Normal(10, .01))
    s2_start = pyro.sample("s2_start", pyro.distributions.Normal(10, .01))
    s3_start = pyro.sample("s3_start", pyro.distributions.Normal(10, .01))

    ## Loop over each observed trace
    for i in range(data.shape[0]):
        ## Each step in trace
        for t in range(data.shape[1]):

            # latent categorical sample
            # NOTE(review): torch.tensor([...]) copies these values into a new
            # tensor; as far as I know that drops the autograd link back to r1
            # and the starting values -- confirm.
            sample = pyro.sample("sample_{0}_{1}".format(str(t), str(i)), 
                                     pyro.distributions.Categorical(
                                         torch.tensor([s1_start * r1, 
                                                       s2_start * .1, 
                                                       s3_start * .9])))

            ## Observed trace 
            # One-hot increment for the sampled category (a NumPy array; see
            # the NOTE in the docstring).
            update = np.zeros(3)
            update[sample] = 1
            s1_temp = pyro.sample("s1_{0}_{1}".format(str(t), str(i)), 
                               pyro.distributions.Normal(s1_start + update[0], .01), 
                               obs = data[i, t, 0])
            s2_temp = pyro.sample("s2_{0}_{1}".format(str(t), str(i)), 
                               pyro.distributions.Normal(s2_start + update[1], .01), 
                               obs = data[i, t, 1])
            s3_temp = pyro.sample("s3_{0}_{1}".format(str(t), str(i)), 
                               pyro.distributions.Normal(s3_start + update[2], .01), 
                               obs = data[i, t, 2])

            # pyro.sample with obs= returns the observed value, so the next
            # step starts from the data rather than from a latent state.
            s1_start = s1_temp
            s2_start = s2_temp
            s3_start = s3_temp

I have also tried to define a guide and use variational inference, but the rate parameter does not update when going that route. My idea for the guide was as follows:

def guide(data):
    """Variational guide from the original post.

    It re-declares the same ``"rate"`` param used by ``model``. It samples
    ``r1`` and the three starting points from the same distributions the model
    uses as their priors. It samples every ``sample_{t}_{i}`` site from a
    Categorical built from those fixed starting points. Unlike ``model``, it
    never moves the starting values to the observed data (only ``data.shape``
    is read). It has no learnable parameters besides ``"rate"``.

    NOTE(review): because the guide mirrors the model's priors this closely,
    the ELBO has little to fit, which may be why ``rate`` does not update --
    worth confirming. ``torch.tensor([...])`` below also copies values into a
    fresh tensor. As far as I know, this cuts the gradient path from the
    categorical probabilities back to ``r1``/``rate``.
    """
    
    ## Latent model parameters. Interested in inferring rate
    #with pyro.plate("latent_rates"):
        
    ## Rate parameter
    # Same param name as in `model`, so model and guide share this value.
    rate = pyro.param("rate", torch.tensor(.5),
                     constraint=pyro.distributions.constraints.positive)

    r1 = pyro.sample("r1", pyro.distributions.Normal(rate, .01))

    ## starting points
    s1_start = pyro.sample("s1_start", pyro.distributions.Normal(10, .01))
    s2_start = pyro.sample("s2_start", pyro.distributions.Normal(10, .01))
    s3_start = pyro.sample("s3_start", pyro.distributions.Normal(10, .01))

    for i in range(data.shape[0]):
        
        for t in range(data.shape[1]):

            # One categorical per (trace, step); its probabilities stay fixed
            # across steps because s*_start is never updated here.
            pyro.sample("sample_{0}_{1}".format(str(t), str(i)), 
                                     pyro.distributions.Categorical(
                                         torch.tensor([s1_start * r1, 
                                                       s2_start * .1, 
                                                       s3_start * .9])))

Any ideas on what I am doing incorrectly would be much appreciated!

As described in the enumeration tutorial, you probably need to wrap your loop iterators with pyro.markov:

for t in pyro.markov(range(data.shape[1]))

You should also replace any Numpy arrays in your model with PyTorch tensors (e.g. update).

Your model will likely be quite slow unless you vectorize over datapoints and observation dimensions using pyro.plate and Distribution.to_event as described in the tensor shapes tutorial:

...
# One 3-vector of starting points, sampled as a single event.
s_start = pyro.sample("s_start", Normal(10, 0.01).expand([3]).to_event(1))

# Vectorize over sequences with a plate; wrapping the time loop in
# pyro.markov lets enumeration reuse tensor dimensions across time steps.
with pyro.plate("data", data.shape[0]):
    for t in pyro.markov(range(data.shape[1])):
        sample = pyro.sample(f"sample_{t}", ...)
        ...
        # Normal needs an explicit scale; 0.01 matches the rest of the thread.
        s_temp = pyro.sample(f"s_{t}", Normal(s_start + update, 0.01).to_event(1), obs=data[:, t])
        ...

You can see several different examples of related models in the hidden Markov model example.

Thanks for the help! I looked into those tutorials and came up with a model that generally seems to work, but I am still having some issues with the guides. To try to get them working, I repurposed model_0 from the HMM example and used AutoDelta as my guide. When I do this, I get the error:
RuntimeError: upper bound and larger bound inconsistent with step sign

I am kind of at a loss as to why this error is showing up since I followed the HMM code closely. My model also looks very similar to the Time Series example in the enumeration tutorial.

def model_0(sequences, batch_size=None):
    """First version of ``model_0``, adapted from the HMM example.

    Args:
        sequences: tensor unpacked as (num_sequences, max_length, data_dim).
        batch_size: optional subsample size passed to the "sequences" plate.

    Samples a vector of ``rates``. Then, for each sequence ``i`` and each step
    ``t`` from 1 to ``len(sequence) - 1``, it draws an enumerated categorical
    ``x`` and observes ``sequence[t]``.

    NOTE(review):
      * The "starting_params" plate has no size argument. The reply to the
        later runnable example identifies this plate as the source of the
        RuntimeError and as unnecessary given ``.to_event(1)``.
      * ``s_start`` is created once, outside the sequence loop, so the state
        carries over from one sequence to the next instead of restarting.
      * ``update`` is added to ``s_start`` and then added again inside the
        Normal location. The observation is therefore centred on
        (pre-update state) + 2 * ``update`` -- probably unintended.
      * ``x = 0`` is overwritten before it is ever read.
      * ``update[x] = 1`` assumes a single integer ``x``. With parallel
        enumeration, ``x`` holds several values at once -- confirm shapes.
    """
    num_sequences, max_length, data_dim = sequences.shape
    with pyro.plate("starting_params"):
        # Sample rates with a prior mean
        rates = pyro.sample(
            "rates",
            pyro.distributions.Normal(0.0 + torch.tensor([.4, .1, .9]), .01).to_event(1),
        )
    
    # Set species values
    s_start = torch.tensor([10., 10., 10.])
    
    ## Loop over sequences
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences), batch_size):
        sequence = sequences[i, :]
        x = 0
        ## Loop over obs
        for t in pyro.markov(range(1, len(sequence))):
            
            x = pyro.sample(
                "x_{}_{}".format(i, t),
                pyro.distributions.Categorical(rates),
                infer={"enumerate": "parallel"},
            )
            update = torch.zeros(3)
            update[x] = 1
            s_start = s_start + update
            
            # NOTE(review): this plate (size data_dim, dim=-1) and .to_event(1)
            # on a length-3 location both seem to claim the last dimension;
            # each coordinate may be counted data_dim times -- confirm.
            with tones_plate:
                pyro.sample(
                    "y_{}_{}".format(i, t),
                    pyro.distributions.Normal(s_start + update, .01).to_event(1),
                    obs=sequence[t],
                )

The guide and SVI code is as follows:

# AutoDelta (point-estimate) guide over the model with only the "rates" site
# exposed; poutine.block hides every other site from the autoguide.
hmm_guide = pyro.infer.autoguide.AutoDelta(pyro.poutine.block(model_0, expose=["rates"]))
# Enumeration-aware ELBO (needed for the infer={"enumerate": "parallel"} x
# sites), reserving one dimension for plates.
elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1)
# Single loss evaluation to check that the model runs.
elbo.loss(model_0, hmm_guide, data)

Sorry, I’m not sure what’s going on. Please provide a complete stack trace and ideally a small, runnable example that reproduces the error, otherwise it’s hard for us to be helpful with debugging.

That makes sense. I've included a small runnable example below that reproduces the error. I am pretty sure the issue involves the update tensor. I've checked the dimensions, and they seem to match before the error. Thanks again for the help!

import torch
import pyro

## Input Data
## 5 sequences x 5 time steps x 3 values. Every sequence starts at
## [10, 10, 10]; at each later step one of the three values appears to be
## roughly 1 higher than at the previous step (plus small noise).
data = torch.tensor([[[10.0000, 10.0000, 10.0000],
                     [10.0095, 10.0109, 11.0034],
                     [11.0152, 10.0125, 10.9989],
                     [12.0170, 10.0256, 11.0039],
                     [12.0264, 10.0190, 12.0039]],

                    [[10.0000, 10.0000, 10.0000],
                     [10.9931, 10.0088, 10.0140],
                     [11.9897,  9.9931, 10.0095],
                     [12.0093, 10.0031, 11.0181],
                     [12.0030, 11.0080, 11.0244]],

                    [[10.0000, 10.0000, 10.0000],
                     [10.9937,  9.9889,  9.9821],
                     [10.9853, 10.9871,  9.9885],
                     [11.9836, 10.9890,  9.9896],
                     [11.9791, 10.9968, 10.9834]],

                    [[10.0000, 10.0000, 10.0000],
                     [11.0000, 10.0078, 10.0083],
                     [10.9925, 10.0019, 10.9896],
                     [10.9927, 10.0225, 11.9941],
                     [11.0191, 10.0299, 12.9845]],

                    [[10.0000, 10.0000, 10.0000],
                     [11.0067, 10.0266, 10.0160],
                     [12.0080, 10.0323, 10.0226],
                     [12.0002, 11.0434, 10.0300],
                     [11.9954, 11.0502, 11.0401]]])

## Define Model
def model_0(sequences, batch_size=None):
    """Runnable version of ``model_0`` posted to reproduce the RuntimeError.

    Args:
        sequences: tensor unpacked as (num_sequences, max_length, data_dim).
        batch_size: optional subsample size for the "sequences" plate.

    For each sequence (iterated sequentially through the "sequences" plate),
    the state ``s_start`` restarts at [10, 10, 10]. For every ``t`` in
    1..max_length-1, a categorical ``x`` chooses which coordinate is
    incremented by one. The incremented state is observed against
    ``data[i, t]`` and then carried forward.

    NOTE(review):
      * The "starting_params" plate has no size argument. The reply below
        identifies it as the cause of the RuntimeError and as redundant with
        ``.to_event(1)``.
      * The observation reads the module-level ``data`` rather than the
        ``sequences`` argument. This only works when they are the same tensor.
      * ``t`` comes from ``torch.arange``, so it is a 0-d tensor. The site
        names are built from its tensor repr rather than from a plain int.
      * ``update[x] = 1`` assumes a single integer ``x``. With parallel
        enumeration, ``x`` holds several values at once -- confirm shapes.
    """
    num_sequences, max_length, data_dim = sequences.shape
    with pyro.plate("starting_params"):
        # Sample rates with a prior mean
        rates = pyro.sample(
            "rates",
            pyro.distributions.Normal(torch.tensor([.4, .1, .9]), .01).to_event(1),
        )

    ## Loop over sequences
    sample_plate = pyro.plate("sample", data_dim, dim=-1)
    for i in pyro.plate("sequences", num_sequences, batch_size):
        s_start = torch.tensor([10., 10., 10.])
        
        for t in pyro.markov(torch.arange(1, max_length)):

            x = pyro.sample(
                        "x_{}_{}".format(t, i),
                        pyro.distributions.Categorical(rates),
                        infer={"enumerate": "parallel"},
                    )
            
            ## Update tensor
            update = torch.zeros(3)
            update[x] = 1
            with sample_plate:
                
                ## Observed sample
                pyro.sample(
                    "y_{}_{}".format(t, i),
                    pyro.distributions.Normal(s_start + update, .01).to_event(1),
                    obs=data[i, t]
                )
            
            ## True update
            # NOTE(review): the reply suggests pyro.deterministic(..., event_dim=1)
            # here, so the autoguide never treats this as a latent site.
            s_start = pyro.sample("d_{}_{}".format(t, i),pyro.distributions.Delta(s_start + update))

## Inference
## AutoDelta (point-estimate) guide with only the "rates" site exposed;
## poutine.block hides every other site from the autoguide.
hmm_guide = pyro.infer.autoguide.AutoDelta(pyro.poutine.block(model_0, expose=["rates"]))
pyro.clear_param_store()
## Enumeration-aware ELBO for the parallel-enumerated x sites; one
## dimension is reserved for plates.
elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=1)
## Single loss evaluation; presumably where the reported error is raised.
elbo.loss(model_0, hmm_guide, data)

@devonjkohler the cause of the error is the "starting_params" plate, which is both incorrect (because it is missing a size argument) and unnecessary because of the subsequent .to_event applied to rates. You can just delete that line.

To prevent your autoguide from controlling s_start, you should also use pyro.deterministic for s_start rather than pyro.sample. You will also need an event_dim=1 argument:

s_start = pyro.deterministic("d_{}_{}".format(t, i), s_start + update, event_dim=1)

These changes are sufficient to get your example running.

@eb8680_2 Thank you again for your help! I was able to get it to compile with your suggestions and refactored the code to include vectorization.

I wanted to post the code here (in case it can help someone else down the line), and I also wanted to get your input on some learning problems I am now encountering. Essentially, the loss does not change over SVI steps. The rates parameter is updating, but with small datasets (like the one in the example below) it converges to the prior, while with large datasets it converges to values that do not make sense. In both cases the loss does not change. I generated the data with rate values [.7, .3, .9] and set the priors on the rates in the model to [.6, .3, .9]. I was curious whether you thought the issue might be in my model or in how I am trying to do inference. I suspect the model, because I tried following some of the SVI tips and tricks and nothing made a difference.

import torch
import pyro

# 5 sequences x 9 time steps x 3 values; per the post, generated with rate
# values [.7, .3, .9]. Relative to the [10, 10, 10] starting point used by
# model_1, each step appears to raise one of the three values by roughly 1
# (the first row is already one step in).
data = torch.tensor([[[10.9915,  9.9980, 10.0029],
         [10.9792, 10.0190, 11.0013],
         [11.9798, 10.0281, 10.9844],
         [11.9864, 10.0178, 11.9636],
         [11.9791, 10.0111, 12.9479],
         [12.9784,  9.9971, 12.9429],
         [12.9844,  9.9968, 13.9576],
         [13.9950,  9.9862, 13.9662],
         [13.9851,  9.9806, 14.9800]],

        [[10.9924, 10.0120,  9.9946],
         [10.9949, 11.0098,  9.9916],
         [12.0087, 11.0148, 10.0069],
         [12.0109, 11.0124, 11.0090],
         [12.0299, 12.0056, 11.0012],
         [12.0322, 12.9978, 11.0013],
         [12.0273, 13.9908, 11.0110],
         [12.0199, 13.9760, 12.0274],
         [13.0021, 13.9667, 12.0194]],

        [[10.0110,  9.9978, 11.0101],
         [11.0041,  9.9873, 11.0155],
         [10.9992,  9.9824, 12.0085],
         [12.0218,  9.9913, 11.9981],
         [12.0341,  9.9903, 12.9941],
         [13.0320,  9.9898, 12.9641],
         [13.0294,  9.9919, 13.9693],
         [14.0283,  9.9874, 13.9610],
         [14.0334,  9.9912, 14.9647]],

        [[10.9901, 10.0050,  9.9998],
         [11.9909, 10.0130, 10.0088],
         [13.0000, 10.0105,  9.9982],
         [14.0220, 10.0119,  9.9914],
         [14.0181, 10.0044, 10.9915],
         [15.0204,  9.9897, 10.9864],
         [15.0125,  9.9888, 11.9793],
         [15.0039,  9.9820, 12.9838],
         [16.0207,  9.9937, 12.9904]],

        [[10.9960, 10.0221, 10.0036],
         [11.9814, 10.0106, 10.0011],
         [11.9768, 10.9957,  9.9973],
         [11.9831, 10.9929, 10.9996],
         [11.9816, 10.9941, 11.9998],
         [11.9728, 11.9895, 11.9986],
         [11.9643, 12.9964, 12.0039],
         [12.9711, 13.0104, 11.9892],
         [12.9613, 13.0284, 13.0069]]])


def model_1(data, batch_size=None):
    """Vectorized version of the model (third iteration in this thread).

    Args:
        data: tensor unpacked as (num_sequences, max_length, data_dim); the
            code below hard-codes 3 values per step, so it assumes
            data_dim == 3.
        batch_size: optional subsample size for the "sequences" plate.

    Samples ``rates`` once, then walks the time steps inside a "sequences"
    plate at dim=-2:
      * ``x_t`` picks which of the three coordinates is incremented.
      * ``y_t`` observes ``data[batch, t]`` as Normal(state + one-hot(x_t),
        .01) inside a "sample" plate at dim=-1.
      * The incremented state is carried to the next step via
        ``pyro.deterministic``.

    NOTE(review), possible reasons the loss does not move:
      * ``Categorical`` normalizes its ``probs``, so only the ratios of
        ``rates`` reach the likelihood; their overall scale is fixed by the
        prior alone.
      * The prior scale .001 is very tight. Under a point-estimate guide, the
        prior term likely dominates unless there is a lot of data.
      * ``s_start`` accumulates the updates of every earlier step, so ``y_t``
        depends on all earlier ``x`` values. ``pyro.markov`` (default history
        of 1) declares a dependence on the previous step only. Modelling the
        increments ``data[:, t] - data[:, t - 1]`` would let each ``y_t``
        depend on its own ``x_t`` alone.
      * See the note on the ``reshape`` below.
    """
    
    num_sequences, max_length, data_dim = data.shape
    
    ## Rate priors
    rates = pyro.sample(
        "rates",
        pyro.distributions.Normal(torch.tensor([.6, .3, .9]), .001).to_event(1),
    )
    
    ## Starting points
    s_start = torch.tensor([10., 10., 10.])
    
    sample_plate = pyro.plate("sample", data_dim, dim=-1)
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:

        for t in pyro.markov(range(max_length)):

            x = pyro.sample(
                    "x_{}".format(t),
                    pyro.distributions.Categorical(rates),
                    infer={"enumerate": "parallel"},
                )

            ## Define update matrix
            ## One-hot of x: x's trailing dim (presumably size 1, since x is
            ## sampled outside the dim=-1 "sample" plate) is replaced by a dim
            ## of size 3.
            ## NOTE(review): torch.nn.functional.one_hot(x.squeeze(-1).long(), 3).float()
            ## should give the same tensor for the three ranks handled below -- confirm.
            update = torch.zeros(tuple(torch.cat((torch.tensor(x.shape[:-1]), torch.tensor([3])))))
            
            ## Update first run without enumeration
            if len(x.shape) == 2:
                update[torch.arange(len(update)), torch.flatten(x).long()] = 1.
            ## Update second run with enumeration. Cycles between 4 and 5 dimension tensor
            elif len(x.shape) == 4:
                update[torch.arange(len(update)), 0, 0,torch.flatten(x).long()] = 1.
            else:
                update[torch.arange(len(update)), 0, 0,0,torch.flatten(x).long()] = 1.
            
            ## Ensure `s_start` shape matches `update` matrix dimension.
            ## Enumeration run cycles between 4 and 5 dimensions, 
            ## `s_start` changes dimensions when combined.
            ## NOTE(review): reshape only reinterprets the existing layout. When
            ## the previous x's enumeration dim sits at a different position from
            ## the current x's, this moves the previous step's enumerated values
            ## onto the current x's dim. The state then appears to be scored only
            ## along paths where consecutive x values are equal, which is likely a
            ## large part of why learning stalls.
            if s_start.shape != torch.Size([3]):
                s_start = s_start.reshape(update.shape)

            with sample_plate:
                sample = pyro.sample(
                    "y_{}".format(t),
                    pyro.distributions.Normal(s_start + update, .01),
                    obs=data[batch, t]
                )
            
            ## Update s_start for next step
            ## NOTE(review): the earlier reply suggested event_dim=1 for this
            ## deterministic site.
            s_start = pyro.deterministic("d_{}".format(t), s_start + update)
            

## Point-estimate (AutoDelta) guide for "rates" only; all other sites are
## hidden from the autoguide by poutine.block.
hmm_guide = pyro.infer.autoguide.AutoDelta(pyro.poutine.block(model_1, expose=["rates"]))
pyro.clear_param_store()
## NOTE(review): model_1 only declares plates at dim=-2 and dim=-1, so
## max_plate_nesting=2 would suffice. 3 reserves one extra dimension and pushes
## the enumeration dims one position further left. The len(x.shape) == 2/4/5
## branches in model_1 are written for this value, so changing it would
## require changing them too.
elbo = pyro.infer.TraceEnum_ELBO(max_plate_nesting=3)
elbo.loss(model_1, hmm_guide, data) ## Ensure model compiles

## Infer rates
## Adam with learning rate .001 for 1000 SVI steps.
optim = pyro.optim.Adam({'lr': .001})
svi = pyro.infer.SVI(model_1, hmm_guide, optim, elbo)
for step in range(1000):
    loss = svi.step(data)
    ## Every 100 steps, print the loss and every parameter in the param store.
    if step %100 == 0:
        print(f"Loss: {loss}")
        for key, value in pyro.get_param_store().items():    
            print(f"{key}:\n{value}\n")