Hello,
I am trying to model a linear single-degree-of-freedom system in Pyro and run variational inference over the global variables xi and wn and the local variables x and v (collected in the z vector). I wrote the model below, and it works, but iterations are really slow compared to plain PyTorch. I have tried several options to speed up the run, such as enumerating from the end, switching Trace_ELBO to JitTrace_ELBO to use the JIT backend, and using both an autoguide and a custom guide. Some of these helped a little, but the problem remains. I wonder whether my model is simply not written in a way that is efficient for VI. (I have also read the HMM examples, but I could not work out what change I should make to my code.) I would be really grateful for any recommendations on changes to my code or the inference settings that would reduce the run time.
def model(Qnoise, Rnoise, delta, respAcc, inpAcc):
    T = respAcc.shape[0]
    # Global parameter distributions
    xi = pyro.sample('xi', dist.Uniform(0.0, 1.0))  # two-sided constraint
    wn = pyro.sample('wn', dist.LogNormal(-1.0, 0.4))
    # Observation vector a_t = [-wn^2, -2*xi*wn] . z_t, built once with
    # torch.stack so gradients flow back to xi and wn (wrapping them in
    # torch.tensor([...]) would copy the values and cut the autograd graph)
    H = torch.stack([-torch.square(wn), -(2 * xi * wn)])
    # Initial state and observation
    z_t = pyro.sample('z_0', dist.MultivariateNormal(torch.zeros(2), torch.diag(Qnoise)))
    pyro.sample('a_0', dist.Normal(torch.dot(z_t, H), Rnoise), obs=respAcc[0])
    for t in range(1, T):
        ## Transition ##
        z_t = pyro.sample(f"z_{t}", dist.MultivariateNormal(
            fx(z_t, xi, wn, delta, inpAcc[t-1]), torch.diag(Qnoise)))
        ## Observation ##
        pyro.sample(f"a_{t}", dist.Normal(torch.dot(z_t, H), Rnoise), obs=respAcc[t])
For the inference, I use pyro.infer.SVI with the pyro.optim.Adam optimizer and AutoNormal as the guide.