SVI vectorize particles, model parallelization

I am writing a fairly simple dynamical model that has a Markov dependency structure where the state at time t depends only on the state at time t-1. This is my model and guide:

import torch
import pyro
import pyro.distributions as dist
from torch.autograd import Variable

def model(prior, var, data, F):
    prior1 = prior

    for t in range(len(data)):
        state = pyro.sample("state_{}".format(t), dist.MultivariateNormal(prior1,var))
        pyro.sample("measurement_{}".format(t), dist.MultivariateNormal(state, 0.05*torch.eye(4)), obs = data[t])
        prior1 = F@state.T

def guide(prior, var, data, F):

    for t in range(len(data)):
        x_value = pyro.param("x_value_{}".format(t), Variable(prior, requires_grad=True))
        var_value = pyro.param("var_value_{}".format(t), Variable(var, requires_grad=True), constraint=dist.constraints.lower_cholesky)
        state =  pyro.sample("state_{}".format(t), dist.MultivariateNormal(x_value, scale_tril=var_value))
        prior = F@state.T 

I run stochastic variational inference (SVI) on this model and guide as follows:

svi = pyro.infer.SVI(model=model, 
                     guide=guide,
                     optim = optim,
                     loss=pyro.infer.Trace_ELBO(num_particles=2, vectorize_particles=False, retain_graph=True, strict_enumeration_warning=True))

I want to improve my results by increasing the number of particles, but this makes the program slow. Setting vectorize_particles=True should speed things up by evaluating the particles in parallel. However, my model and guide cannot handle this, probably because they use a for-loop instead of a pyro.plate structure.

Is a pyro.plate structure possible for a model in which each time step uses the result from the previous time step (t-1)? And if so, how?
How should I change my model to allow for vectorize_particles=True?

Thanks in advance for your answer.

Hi @angelique, thanks for your patience; many of us Pyro devs have been on vacation :smile:

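I think your easiest option is to make your model vectorizable. You should be able to do this by supporting broadcasting on the left, i.e. by letting every tensor carry extra leading batch dimensions. I think the only issue with your model is the matrix multiply op F @ state.T: once state has a leading particle dimension, say shape (num_particles, 4), that expression returns shape (4, num_particles) instead of (num_particles, 4). I usually support broadcasting by post-multiplying by the transformation (postmodern notation), prior1 = state @ F.T.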

Another, faster option for your linear-Gaussian state space model might be to use Pyro’s built-in GaussianHMM class, which uses some fancy parallel-scan algorithms and should be way faster if you’re running on GPU. Note that GaussianHMM also assumes the post-multiplication form state @ F.T, and that it includes an extra state transition at the beginning rather than the end, so its initial distribution describes the state one step before the first measurement.
