Linear Dynamic System Modeling, Run Speed


I am trying to model a linear single-degree-of-freedom (SDOF) system in Pyro, using variational inference (VI) for the global variables xi and wn and for the local variables x and v (the components of the state vector z). I wrote the model below, and it works, but each iteration is really slow compared to plain PyTorch. I have tried many options to speed it up, such as enumerating from the end, switching from Trace_ELBO to JitTrace_ELBO (to run with a different backend), and using both an autoguide and a custom guide. Some of these helped a bit, but the problem remains. I wonder whether my model is written in a way that is inefficient for VI. (I have also read the HMM documentation, but I could not work out what I should change in my code.) I would be very grateful for any suggestions on changes to my code or to the inference settings that would reduce the run time.

import torch
import pyro
import pyro.distributions as dist

# fx(z, xi, wn, delta, u) is the (possibly nonlinear) transition function, defined elsewhere.
def model(Qnoise, Rnoise, delta, respAcc, inpAcc):
    T = respAcc.shape[0]
    # Parameter distributions
    xi = pyro.sample('xi', dist.Uniform(0.0, 1.0))  # two-sided constraint: 0 < xi < 1
    wn = pyro.sample('wn', dist.LogNormal(-1.0, 0.4))
    # Observation weights for a = -wn^2 * x - 2*xi*wn * v, built once outside the loop.
    # torch.stack keeps the gradient path to xi and wn; torch.tensor([...]) would cut it.
    obs_w = torch.stack([-torch.square(wn), -2 * xi * wn])
    Q = torch.diag(Qnoise)
    # Prior distribution of the initial state z_0 = [x_0, v_0]
    z_t = pyro.sample('z_0', dist.MultivariateNormal(torch.zeros(2), Q))
    pyro.sample('a_0', dist.Normal(torch.dot(z_t, obs_w), Rnoise), obs=respAcc[0])
    for t in range(1, T):
        ## Transition ##
        z_t = pyro.sample(f"z_{t}", dist.MultivariateNormal(fx(z_t, xi, wn, delta, inpAcc[t-1]), Q))
        ## Observation ##
        pyro.sample(f"a_{t}", dist.Normal(torch.dot(z_t, obs_w), Rnoise), obs=respAcc[t])

For inference, I use pyro.infer.SVI with the pyro.optim.Adam optimizer and AutoNormal as the guide.

Is fx linear? If so, you can use GaussianHMM instead, which will be much faster.

Thanks for the response.
I actually do not want to restrict fx to a linear form…

If you want speed, I suggest using NumPyro. JAX can compile away a lot of the overhead that PyTorch cannot, so it tends to be much faster for problems of this sort.

I had a similar goal (a Gaussian univariate state with an unconventional multivariate measurement error). I found that using lax.scan with MCMC in NumPyro was by far the fastest approach, and it recovered values close to the true parameters in a simulation study. Much of the speedup actually came from sampling the noise as a vector and passing it to the scan function, rather than adding the noise inside the transition function (although I may have been doing things wrong; I am new to (Num)Pyro).

Details and code are here: