Thank you!
My ultimate goal is to implement the switching linear dynamical system (SLDS) model from this paper: https://arxiv.org/pdf/1603.06277.pdf to analyze some video data. The paper uses an inference procedure that seems largely incompatible with Pyro, but I was hoping that plain variational inference would work. I haven't worked with temporal data before, so I was using the tutorials as examples.
The model looks like this
where z_n is a discrete Markov chain and the y_i are produced from the x_i by a neural network, just as in a VAE. My original idea was to first use model-side enumeration to marginalize out the z_i, and then use a mean-field family for the x_i, A, B random variables (with a vanilla neural network for the p(x_i|y_i) probabilities in the guide). Finally, I could use the infer_discrete function to do inference on the z_i. However, reading the DMM tutorial, it seemed like it would make sense to use an RNN to properly condition on more than just the immediate y_i, which is why I was asking.
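For intuition, what model-side enumeration computes when it marginalizes out a discrete Markov chain is the forward algorithm. Here is a minimal NumPy sketch (all quantities are hypothetical placeholders: `trans`, `init`, and the per-step emission likelihoods `emit` stand in for what the learned model would produce), checked against brute-force summation over all state paths:

```python
import numpy as np

# Hypothetical sizes: K discrete states, T time steps.
K, T = 3, 5
rng = np.random.default_rng(0)

# Placeholder transition matrix and initial distribution for the chain z_n;
# in the real model these would be learned parameters.
trans = rng.dirichlet(np.ones(K), size=K)  # trans[i, j] = p(z_t = j | z_{t-1} = i)
init = np.full(K, 1.0 / K)

# Stand-in per-step emission likelihoods p(y_t | z_t = k).
emit = rng.uniform(0.1, 1.0, size=(T, K))

# Forward algorithm: alpha[k] = p(y_1..y_t, z_t = k) after step t.
alpha = init * emit[0]
for t in range(1, T):
    alpha = (alpha @ trans) * emit[t]

# p(y_1..y_T) with z_1..z_T fully marginalized out.
marginal_likelihood = alpha.sum()
```

This is O(T K^2) rather than O(K^T), which is why enumeration over the z_i stays tractable even for long sequences.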
I was a bit confused about the truncated backprop because, as far as I remember, when calling the RNN you backpropagate, detach the hidden state, continue calling the RNN, backpropagate again, and so on. But with Pyro, at least as I understand it, svi.step does the backpropagation through the network for us, and I was unsure how to control it in this way.
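To make the pattern I mean concrete, here is a plain PyTorch sketch of truncated BPTT (not Pyro-specific; the sequence, chunk size, and squared-output loss are arbitrary placeholders). The key line is the `detach()` on the hidden state between chunks, which cuts the autograd graph so gradients never flow past a chunk boundary; my question is how to get the equivalent effect on the state carried between svi.step calls:

```python
import torch
import torch.nn as nn

# Minimal truncated-BPTT sketch: backprop within each chunk,
# then detach the hidden state before the next chunk.
rnn = nn.GRU(input_size=4, hidden_size=8, batch_first=True)
opt = torch.optim.Adam(rnn.parameters(), lr=1e-3)

seq = torch.randn(1, 40, 4)   # one long placeholder sequence
h = torch.zeros(1, 1, 8)      # initial hidden state
chunk = 10                    # truncation length

for start in range(0, seq.size(1), chunk):
    opt.zero_grad()
    out, h = rnn(seq[:, start:start + chunk], h)
    loss = out.pow(2).mean()  # placeholder loss
    loss.backward()           # gradients flow only within this chunk
    opt.step()
    h = h.detach()            # cut the graph: no backprop past this point
```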
I am sorry if anything here is wrong; as I said, I am new to temporal data, which is why I am a bit confused.
If I do manage to train the model successfully, I would be more than happy to turn it into a tutorial if others would find it useful.
Thank you again.