Best practices and performance for time series models?

I work with time series models in MCMC. I’ve been experimenting with Pyro for implementation (as opposed to Stan). The performance I’m seeing is… well, at the beginning of sampling, it’s about 2 seconds per iteration. Then it drops quickly to multiple iterations per second. And then it climbs back up: by iteration 24 it’s at 95 seconds per iteration, and by iteration 30 it’ll be at 240+ seconds per iteration.

I have a suspicion about the cause. I’m hoping someone can validate it and propose an alternative approach.

In structural time series modeling with MCMC, it’s been established that a key to getting good exploration of the posterior is to parameterize the model so that all of the samples are zero-centered. As a result, I have a lot of structures that look something like this:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def lineartrend(ar, M, T):  # ar: AR order, M: dimensionality, T: number of periods
    # ... skipping the priors (alpha_ar, beta_ar, L_Omega_ar, delta_t0, and the
    # prototype tensor proto used for dtype/device)

    mu = proto.new_zeros(M)
    with pyro.plate("lineartrend", T):
        nu = pyro.sample("nu_ar", dist.MultivariateNormal(mu, scale_tril=L_Omega_ar))
    delta_stack = [delta_t0]
    for t in poutine.markov(range(1, T), history=ar):
        if ar == 1:
            delta = alpha_ar + (beta_ar * (delta_stack[-1] - alpha_ar)) + nu[t]
        else:
            past_delta = torch.stack(delta_stack[max(0, t - ar):t], 0)
            delta = alpha_ar + (beta_ar[-past_delta.shape[0]:] * past_delta).sum(0) + nu[t]
        delta_stack.append(delta)
    return torch.stack(delta_stack, 0)

My suspicion is that the performance trouble is caused by all the calls to torch.stack fragmenting memory. I can’t pre-allocate a matrix, because then I get an error about in-place modification when backward() is called. And I can’t keep only the recent history and toss the rest, because I need the values at each point in time.
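Here’s a minimal sketch of the pattern I mean (toy tensors, not my actual model) showing why the pre-allocated version blows up at backward():

import torch

w = torch.randn(3, requires_grad=True)
buf = w.new_empty(4, 3)   # pre-allocated buffer, as I'd like to do with delta
buf[0] = w
for t in range(1, 4):
    # the multiply saves buf[t - 1] for the backward pass; writing into buf on
    # the next iteration mutates that same storage in place
    buf[t] = buf[t - 1] * w
buf.sum().backward()      # RuntimeError: variable modified by an inplace operation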

Is there some other approach to coding a model like this that can offer improved performance?

(In the time it’s taken to write this message, the experiment has made it from iteration 24 to 27, and is now at 220 seconds/iteration.)

Edit: I think it’s appropriate to elaborate on the performance issue by comparing against the same model in Stan. I’m not doing this to be critical, but I think it’s important for clarifying the issue. The performance I’ve described is with Pyro running a single chain. Because PyTorch uses a multithreaded BLAS, the model fully consumes all 6 cores on this machine (I see CPU usage of around 850-900%, including hyperthreading). By comparison, Stan is able to execute 2000 iterations of the same model, using the same dataset, on a single core, in about an hour.

I only have limited experience with MCMC on time series data, so I’d like to learn from your experiment. In the meantime, I suggest the following (a minimal configuration sketch follows the list):

  • set max_tree_depth = 8 or smaller to improve speed (this has some theoretical trade-offs, but it can help in some cases)
  • use jit_compile=True
  • about the in-place modification, you can run your model under torch.autograd.detect_anomaly() to see where the problem is; one solution is to clone some variables
  • I can’t recall exactly how to limit CPU usage with PyTorch, but you can call torch.set_num_threads(...) or prepend MKL_NUM_THREADS=1 OMP_NUM_THREADS=1 to your command
  • I don’t see an observation variable in your code; you might want to add one, I guess
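
For example (just a sketch, with model and y as placeholders for your model function and data), the first two suggestions plus the thread limit would look like:

import torch
from pyro.infer.mcmc import MCMC, NUTS

torch.set_num_threads(1)  # keep PyTorch on one core, like Stan's single-chain run

# `model` and `y` are placeholders for your model function and observed data
kernel = NUTS(model, max_tree_depth=8, jit_compile=True, ignore_jit_warnings=True)
mcmc = MCMC(kernel, num_samples=1000, warmup_steps=1000)
mcmc.run(y)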

Hi @amos, nice model! I don’t know much about Pyro’s HMC, but one trick I’ve used to avoid torch.stack() is to call .clone() each time I read from the big tensor. Does this work:

...
delta = proto.new_empty(T, M)
delta[0] = delta_t0
for t in poutine.markov(range(1, T), history=ar):
    if ar == 1:
        delta[t] = alpha_ar + (beta_ar * (delta[t-1].clone() - alpha_ar)) + nu[t]  # note the clone
    else:
        past_delta = delta[max(0, t-ar):t].clone()  # note the clone
        delta[t] = ...
return delta

Also, after tuning the HMC parameters following @fehiepsi’s advice, you might try jitting your model.

@amos - We have observed that Pyro’s NUTS performance is an order of magnitude slower than Stan’s, especially on smaller models. Most of the overhead after JITting is in PyTorch itself, which is hard to optimize on Pyro’s end.

Two things that I have found helpful: using JIT as much as possible, and using a lower tree depth, e.g. 6-8 (ideally we should do this by default during adaptation). This is really important because if you are in a regime where you need to build a tree of depth 10, the Python and PyTorch overhead will be so high that you won’t make much progress. It is better to have shorter, faster trajectories until you get into the typical set and have a mass matrix and step size that allow you to explore efficiently.

Btw, if you are using HMC, you won’t need poutine.markov since HMC doesn’t make use of the markov context (only SVI does).
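That is, the loop in your snippet can just be a plain Python range under HMC/NUTS, e.g.:

# plain loop; the markov context only matters for enumeration under SVI
for t in range(1, T):
    ...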

Thanks for the suggestions.

Unfortunately, models like these are known to require higher tree depths to properly explore the posterior. Typically in Stan I would set a max_treedepth of 15 and adapt_delta to 0.99 to hold down the step size.
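(For reference, if I’m reading the NUTS docs correctly, the Pyro-side analogue of those Stan settings, with model as a placeholder, would be something like:)

from pyro.infer.mcmc import NUTS

# max_tree_depth ~ Stan's max_treedepth, target_accept_prob ~ Stan's adapt_delta
kernel = NUTS(model, max_tree_depth=15, target_accept_prob=0.99)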

That’s especially true if you use hierarchical shrinkage priors, which was one of the goals of the experiment. (Another goal was to see how Pyro would enable better code organization; the code I posted was not a complete model. A BSTS model has 5 or so distinct time series components that get combined into the full model, and what I posted was one of those components.)

If you guys decide to start looking deeply into Pyro’s NUTS/HMC performance, let me know - this model, which is a pretty vanilla implementation of an MBSTS, might be a good test case for you.