I work with time series models in MCMC. I've been experimenting with Pyro for implementation (as opposed to Stan). The performance I'm seeing is… well, at the beginning of sampling, I see about 2 seconds per iteration. It then drops quickly to multiple iterations per second, and then climbs back up. By iteration 24, it's at 95 seconds per iteration; by iteration 30, it'll be at 240+ seconds per iteration.
I have a suspicion about the cause. I’m hoping someone can validate it and propose an alternative approach.
In structural time series modeling with MCMC, it's well established that a key to good exploration of the posterior is parameterizing the model so that all of the samples are 0-centered. As a result, I have a lot of structures that look something like this:
```python
def lineartrend(ar, M, T):  # M: dimensionality, T: number of periods
    # ... skipping the priors ...
    mu = proto.new_zeros(M)
    with pyro.plate("lineartrend", T):
        nu = pyro.sample("nu_ar", dist.MultivariateNormal(mu, scale_tril=L_Omega_ar))
    delta_stack = [delta_t0]
    for t in poutine.markov(range(1, T), history=ar):
        if ar == 1:
            delta = alpha_ar + (beta_ar * (delta_stack[-1] - alpha_ar)) + nu[t]
        else:
            past_delta = torch.stack(delta_stack[max(0, t - ar):t], 0)
            delta = alpha_ar + (beta_ar[-past_delta.shape[0]:] * past_delta).sum(0) + nu[t]
        delta_stack.append(delta)
    return torch.stack(delta_stack, 0)
```
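(For context on the "0-centered" point above: it's the usual non-centered reparameterization. A minimal standalone sketch in plain torch, with illustrative values, separate from the model:)

```python
import torch

# Non-centered ("0-centered") reparameterization sketch: instead of sampling
# x ~ Normal(mu, sigma) directly, sample a standard-normal raw variable and
# shift/scale it deterministically, so the sampler only ever explores a
# unit-scale, zero-centered space.
torch.manual_seed(0)
mu, sigma = 2.0, 0.5            # illustrative values, not from the model
x_raw = torch.randn(10000)      # x_raw ~ Normal(0, 1): what the sampler sees
x = mu + sigma * x_raw          # implied   x ~ Normal(mu, sigma)
```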
My suspicion is that the performance trouble is caused by all the calls to `torch.stack` fragmenting memory. I can't pre-allocate a matrix, because then I get an error about in-place modification when `backward()` is called. And I can't keep just the recent history and discard the rest, because I need the values at each moment in time.
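To make the in-place failure concrete, here's a minimal standalone repro (toy shapes, not my actual model) of what happens when I pre-allocate the trajectory:

```python
import torch

# Pre-allocating the trajectory and writing each step in place looks natural,
# but each `delta[t] = ...` is an in-place write into a tensor whose earlier
# slices autograd has already saved for the backward pass, so the version-
# counter check fails when backward() runs.
T, M = 5, 3
beta = torch.randn(M, requires_grad=True)
delta = torch.zeros(T, M)               # pre-allocated buffer
delta[0] = torch.randn(M)
for t in range(1, T):
    delta[t] = beta * delta[t - 1]      # in-place write; invalidates saved slices

caught = None
try:
    delta.sum().backward()
except RuntimeError as e:               # "... modified by an inplace operation"
    caught = e
```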
Is there some other approach to coding a model like this that can offer improved performance?
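One direction I've wondered about, sketched here for the ar == 1 case only (illustrative names and shapes, O(T²) memory, not tested in the full model): unroll the AR(1) recursion in closed form and replace the Python loop with a single lower-triangular matmul.

```python
import torch

def ar1_vectorized(alpha, beta, delta0, nu):
    """Closed-form AR(1) path: delta_t = alpha + beta * (delta_{t-1} - alpha) + nu_t.

    Unrolled: delta_t - alpha = beta**t * (delta_0 - alpha)
                                + sum_{k=1..t} beta**(t-k) * nu_k,
    i.e. one (T, T-1) lower-triangular matmul instead of a Python loop.
    alpha, beta: scalar tensors; delta0: (M,); nu: (T-1, M) innovations for t = 1..T-1.
    """
    T = nu.shape[0] + 1
    t = torch.arange(T, dtype=nu.dtype)
    i = torch.arange(T).unsqueeze(1)        # row index: time t
    j = torch.arange(1, T).unsqueeze(0)     # column index: innovation time k
    expo = (i - j).clamp(min=0).to(nu.dtype)
    # W[t, k] = beta**(t-k) for k <= t, else 0
    W = torch.where(i >= j, beta ** expo, nu.new_zeros(()))
    deviation = (beta ** t).unsqueeze(-1) * (delta0 - alpha) + W @ nu
    return alpha + deviation
```

If this is a sensible direction at all, I assume the AR(p) case would need a banded generalization of the same weight matrix.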
(In the time it's taken to write this message, the experiment has made it from iteration 24 to 27, and is now at 220 seconds/iteration.)
Edit: I think it's appropriate to elaborate on the performance issue by comparison with the same model in Stan. I'm not doing this to be critical, but I think it's important to clarify the scale of the issue. The performance I've described is with Pyro running a single chain. Because PyTorch uses a multithreaded BLAS, the model fully consumes all 6 cores on this machine (I see CPU usage of around 850-900%, including hyperthreading). By comparison, Stan is able to execute 2000 iterations of the same model, using the same dataset, on a single core, in about an hour.
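(For what it's worth, I believe PyTorch's intra-op thread pool can be capped for a fairer apples-to-apples, single-core comparison; `torch.set_num_threads` is standard PyTorch API:)

```python
import torch

# Cap PyTorch's intra-op parallelism so a single chain runs on one core,
# matching Stan's single-core setup for a fair timing comparison.
torch.set_num_threads(1)
```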