PyTorch 1.1 *extreme* SVI slowness

Has anyone else noticed extreme slowness when moving from PyTorch 1.0 to 1.1? I moved my Pyro code using SVI from PyTorch 1.0 to 1.1 and found astonishing delays. I'm using the dev branch of pyro-ppl, and I'm using the Google "deep learning" PyTorch containers (so the PyTorch builds come from someone who knows what they're doing).

The worst of the delays is in JIT compilation. The JIT compile step is almost unbearably slow now: it took almost 180 seconds on PyTorch 1.0 and takes over 1000 seconds on 1.1. Disabling optimization during JIT provides little relief. SVI steps subsequent to the JIT step are also slower in 1.1 - about twice as slow as in 1.0.

I’m also seeing about a 17% slowdown from 1.0 to 1.1 without JIT compiling.

Has anyone else seen this slowdown? I presume the Google folks built their PyTorch containers the same way for both versions, so hopefully the PyTorch build itself isn't the issue. I can roll back to 1.0 for now, but I'm wondering whether something structural changed that might need to be addressed. I'm guessing this isn't limited to me?
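
(For reference, Timer in the snippets below is just a small wall-clock context manager, roughly along the lines of this minimal sketch - my actual helper may differ slightly:)

import time

class Timer:
    """Print elapsed wall-clock time for a labeled block."""
    def __init__(self, label):
        self.label = label
    def __enter__(self):
        self.start = time.time()
        return self
    def __exit__(self, *exc):
        print(f"[{self.label}]  Elapsed: {time.time() - self.start:.3f}s")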

Here are times from the 1.0 run:

print(F"torch version {torch.__version__}")
torch version 1.0.1.post2


reset_params()
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model.model_fn, model.guide, optim, elbo)
with Timer('first step'):
    svi.step(yseq, xseq, length=len(yseq), **model_kwargs)
with Timer('100 more steps'):
    for i in range(100):
        svi.step(yseq, xseq, length=len(yseq), **model_kwargs)
[first step]  Elapsed: 2.387s
[100 more steps]  Elapsed: 232.172s


reset_params()
elbo = JitTraceEnum_ELBO(max_plate_nesting=1, jit_options={"optimize": True})
svi = SVI(model.model_fn, model.guide, optim, elbo)
with Timer('first step'):
    svi.step(yseq, xseq, length=len(yseq), **model_kwargs)
with Timer('100 more steps'):
    for i in range(100):
        svi.step(yseq, xseq, length=len(yseq), **model_kwargs)
[first step]  Elapsed: 176.492s
[100 more steps]  Elapsed: 83.862s


reset_params()
elbo = JitTraceEnum_ELBO(max_plate_nesting=1, jit_options={"optimize": False})
svi = SVI(model.model_fn, model.guide, optim, elbo)
with Timer('first step'):
    svi.step(yseq, xseq, length=len(yseq), **model_kwargs)
with Timer('100 more steps'):
    for i in range(100):
        svi.step(yseq, xseq, length=len(yseq), **model_kwargs)
[first step]  Elapsed: 164.959s
[100 more steps]  Elapsed: 163.283s

And here are the same times from the 1.1 run:

print(F"torch version {torch.__version__}")
torch version 1.1.0


reset_params()
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model.model_fn, model.guide, optim, elbo)
with Timer('first step'):
    svi.step(yseq, xseq, length=len(yseq), **model_kwargs)
with Timer('100 more steps'):
    for i in range(100):
        svi.step(yseq, xseq, length=len(yseq), **model_kwargs)
[first step]  Elapsed: 3.982s
[100 more steps]  Elapsed: 271.525s


reset_params()
elbo = JitTraceEnum_ELBO(max_plate_nesting=1, jit_options={"optimize": True})
svi = SVI(model.model_fn, model.guide, optim, elbo)
with Timer('first step'):
    svi.step(yseq, xseq, length=len(yseq), **model_kwargs)
with Timer('100 more steps'):
    for i in range(100):
        svi.step(yseq, xseq, length=len(yseq), **model_kwargs)
[first step]  Elapsed: 1013.402s
[100 more steps]  Elapsed: 190.331s


reset_params()
elbo = JitTraceEnum_ELBO(max_plate_nesting=1, jit_options={"optimize": False})
svi = SVI(model.model_fn, model.guide, optim, elbo)
with Timer('first step'):
    svi.step(yseq, xseq, length=len(yseq), **model_kwargs)
with Timer('100 more steps'):
    for i in range(100):
        svi.step(yseq, xseq, length=len(yseq), **model_kwargs)
[first step]  Elapsed: 1041.228s
[100 more steps]  Elapsed: 185.235s

I’m not sure what the next step is for this. It seems like a general issue that could hit a lot of people. Not sure if the lack of response means that no one else is seeing this, or if it’s just my usual knack for writing posts that don’t get replies :slight_smile:

I suppose I can post to the PyTorch forum, but I'm not sure if they'll just treat it as an application-specific issue.

Any suggestions from this forum first, as you know the use case best?

Thanks for sharing! Some other people have noticed a significant slowdown with JIT on pytorch-1.1.0, and I noticed it too recently on an HMM example. I think it will be best to bring this issue up with the PyTorch developers. I am trying to create a minimal example that I can share with them, ideally something that is agnostic of Pyro. Your input and the sleuthing you have already done will be really helpful. :slightly_smiling_face: If you were to JIT your model/guide code directly (i.e. outside the context of inference), do you still notice the slowdown?
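
By "directly", I mean something roughly like the following sketch (just an illustration; model and args here stand in for your own model function and its tensor arguments):

import torch
from pyro import poutine

def model_log_prob(*args):
    # log joint of one execution trace, outside of any inference machinery
    return poutine.trace(model).get_trace(*args).log_prob_sum()

traced = torch.jit.trace(model_log_prob, args, check_trace=False)
out = traced(*args)  # compare the timing of this call against plain model_log_prob(*args)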

Hi @chchchain, I too am disappointed by the extreme slowdown to torch.jit.trace in PyTorch 1.1, even when passing optimize=False. If you do open a PyTorch issue, please cc @neerajprad and me so we can join the discussion and share our profiling results.

Longer term we are exploring alternatives to relying on torch.jit.trace for large Pyro models. These include (1) JAX as a possible back-end for Pyro (numpyro), and (2) doing our own PPL-specific jitting including op fusion so we can feed torch.jit.trace with smaller programs operating on larger tensors (funsor).

OK, good to hear it’s not just me. Thanks for the forward looking view as well - I was getting a little unnerved to see performance going backward after spending so much time ramping up on Pyro.

I will run a few more tests to see if anything pops out as being especially affected (I lean heavily on pyro.markov and enumeration in this model). If so, I'll report back here. I'll probably file an issue with PyTorch on Monday or Tuesday.


I was speaking with some pytorch devs and they suggested that we try out pytorch’s nightly build. With PyTorch nightly, the compilation time is improved, but I still notice that the execution time is quite a bit slower than 1.0.0. I am trying to profile to see where the bottleneck is. If you find something worth sharing, let us know!

PyTorch nightly

| Min (sec) | Mean (sec) | Max (sec) | python -O examples/hmm.py … |
| --- | --- | --- | --- |
| 40.6 | 43.0 | 44.0 | --model=4 --num-steps=50 --jit --time-compilation |
| 14.8 | 15.6 | 16.6 | --model=4 --num-steps=50 --jit --time-compilation (compilation time) |

PyTorch 1.0.0

| Min (sec) | Mean (sec) | Max (sec) | python -O examples/hmm.py … |
| --- | --- | --- | --- |
| 29.6 | 30.3 | 31.0 | --model=4 --num-steps=50 --jit --time-compilation |
| 12.6 | 12.8 | 12.9 | --model=4 --num-steps=50 --jit --time-compilation (compilation time) |

I definitely will - it’s near the top of my list. Right now my plan is to:

  1. Confirm that triggering JIT outside of inference is also slow (by simply calling the model and guide a la carte with the usual args, right?)
  2. Try to determine if either pyro.markov or enumeration are particularly problematic for JIT by stripping them out of the model as possible
  3. I'll see if I can put together a PyTorch nightly container to run in parallel as well, but I'm not sure how well that will go…

Sounds great! Regarding (1), on the few toy models that I tried, I didn't see any particular slowdown when directly running the JITed model, so my guess is that the JIT regression is in some operation related to our effect handling machinery, e.g. computing the trace's log prob. It will be great to get an independent evaluation from you on (1) and (2). Thanks a lot!

  1. I'll see if I can put together a PyTorch nightly container to run in parallel as well, but I'm not sure how well that will go…

I found that the current nightly package does not include some modules, but last week's package seems to work: pip install torch_nightly==1.2.0dev20190718 -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html

Here’s the result of checking against the nightly build. (Other tests still on the way)

I tried to consolidate the results in a table, with each cell containing the time of the first iteration and the time of the next 100 iterations, formatted as first…next-100. There's obviously randomness in a test like this (to the tune of several seconds), but some things seem pretty clear.

| | No JIT | JIT - optim True | JIT - optim False |
| --- | --- | --- | --- |
| 1.0.1.post2 | 3.733…237.5 | 181.1…91.1 | 168.6…166.9 |
| 1.1 | 3.917…261.9 | 1010.1…191.1 | 1044…186.5 |
| 1.2.0.dev20190718 | 3.931…265.8 | 242.4…186.7 | 237.9…185.2 |

TL;DR: the nightly build is slower than 1.0 - sometimes twice as slow - but nowhere near the JITastrophe of 1.1.


Thanks for sharing the profiling results; this is consistent with what I have been observing. I will try to isolate smaller code snippets (ideally independent of Pyro) that we can share with the PyTorch team. The regression in the execution time of the compiled code since 1.0 seems concerning, but it is good to see that the compilation time isn't quite as bad as in the last release.
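
A Pyro-independent repro might look something like the sketch below (the particular ops - indexing plus logsumexp in a loop - are only a guess at the kind of tensor work our enumeration code does, not a confirmed culprit):

import time
import torch

def kernel(x, probs):
    # a chain of indexing + logsumexp ops, loosely mimicking enumeration-style message passing
    for t in range(x.size(0)):
        probs = torch.logsumexp(probs.unsqueeze(-1) + x[t], dim=0)
    return probs.sum()

x = torch.randn(100, 8, 8)
probs = torch.zeros(8)

start = time.time()
traced = torch.jit.trace(kernel, (x, probs), check_trace=False)
print("compilation time", time.time() - start)

start = time.time()
for _ in range(100):
    traced(x, probs)
print("execution time (compiled)", time.time() - start)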

More details on my earlier comment regarding profiling results for just running the model. The following is a profiling script for a simple HMM-style model that I am using. Please feel free to change the model and play around with it. The script doesn't run inference; it merely computes the log density of the execution trace.

import torch

import pyro
from pyro import poutine
import pyro.distributions as dist
from pyro.util import ignore_jit_warnings, timed

D = 4


def model(data):
    initialize = pyro.sample("initialize", dist.Dirichlet(torch.ones(D)))
    with pyro.plate("states", 4):
        transition = pyro.sample("transition", dist.Dirichlet(torch.ones(D, D)))
        emission_loc = pyro.sample("emission_loc", dist.Normal(torch.zeros(D), torch.ones(D)))
        emission_scale = pyro.sample("emission_scale", dist.LogNormal(torch.zeros(D), torch.ones(D)))
    x = None
    with ignore_jit_warnings([("Iterating over a tensor", RuntimeWarning)]):
        for t, y in pyro.markov(enumerate(data)):
            x = pyro.sample("x_{}".format(t),
                            dist.Categorical(initialize if x is None else transition[x]),
                            infer={"enumerate": "parallel"})
            pyro.sample("y_{}".format(t), dist.Normal(emission_loc[x], emission_scale[x]), obs=y)
            
            
def _generate_data():
    transition_probs = torch.rand(D, D)
    emissions_loc = torch.arange(D, dtype=torch.Tensor().dtype)
    emissions_scale = 1.
    state = torch.tensor(1)
    obs = [dist.Normal(emissions_loc[state], emissions_scale).sample()]
    for _ in range(500):
        state = dist.Categorical(transition_probs[state]).sample()
        obs.append(dist.Normal(emissions_loc[state], emissions_scale).sample())
    return torch.stack(obs)


def fn(data):
    with ignore_jit_warnings():
        return poutine.trace(model).get_trace(data).log_prob_sum()


data = _generate_data()

print("**Torch version: {}**".format(torch.__version__))

with timed() as t:
    traced = torch.jit.trace(fn, (data,), check_trace=False)

print("compilation time", t.elapsed)

with timed() as t:
    for i in range(100):
        d = _generate_data()
        d.requires_grad_(True)
        out = fn(d)
        out.backward()

print("execution time (python)", t.elapsed)

with timed() as t:
    for i in range(100):
        d = _generate_data()
        d.requires_grad_(True)
        out = traced(d)
        out.backward()

print("execution time (compiled)", t.elapsed)

Results

Torch version: 1.0.0
compilation time 4.170818455982953
execution time (python) 15.539759353967384
execution time (compiled) 28.37991355289705

Torch version: 1.2.0.dev20190718
compilation time 4.694262339966372
execution time (python) 18.25298989005387
execution time (compiled) 28.691653871908784

At least on this model (for the restricted set of operations), there does not seem to be any regression when simply running the model forward. So either the regression is specific to certain models, or lies somewhere else in our inference machinery.

@chchchain - could you check if this holds true for your model too?

@neerajprad, here’s my best crack at this. If this is off target I’ll need some more guidance, as I’m still coming up to speed on Pyro/torch.

import torch
from pyro import poutine
from pyro.util import ignore_jit_warnings, timed

# yseq, xseq, model, and model_kwargs come from my existing notebook setup
print("**Torch version: {}**".format(torch.__version__))

xseq.requires_grad_(True)
yseq.requires_grad_(True)

def model_wrapper(yseq, xseq):
    with ignore_jit_warnings():
        return poutine.trace(model.model_fn).get_trace(yseq, xseq, length=len(yseq), **model_kwargs).log_prob_sum()

with timed() as t:
    traced = torch.jit.trace(model_wrapper, (yseq, xseq), check_trace=False)
print(type(traced))
print("compilation time", t.elapsed)

with timed() as t:
    for i in range(10):
        out = model_wrapper(yseq, xseq)
        out.backward()
print("execution time (python)", t.elapsed)

with timed() as t:
    for i in range(10):
        out = traced(yseq, xseq)
        out.backward()
print("execution time (compiled)", t.elapsed)

Here are the results:
Torch version: 1.0.1.post2
<class 'torch.jit.TopLevelTracedModule'>
compilation time 26.324513617902994
execution time (python) 8.057717148214579
execution time (compiled) 32.86182566732168

Torch version: 1.1.0
<class 'torch._C.Function'>
compilation time 129.31766010262072
execution time (python) 7.960676901042461
execution time (compiled) 241.89978348650038

Torch version: 1.2.0.dev20190718
<class 'torch._C.Function'>
compilation time 24.93151024915278
execution time (python) 7.681708190590143
execution time (compiled) 28.462130896747112

I suppose this could mostly be read as good news - the 1.2 performance is about the same as 1.0. (Perhaps a shade faster, but within the variance of the several runs I made.)

Notes:

  • If comparing runs with yours, note the scale - I brought the loop down to 10 iterations because it was taking so long
  • I stuck all of the kwargs in a wrapper fn since the trace only wanted tensors
  • The new timed() is nice, but it eats exceptions, which masked a lot of issues when I was trying to set this up (see the sketch after this list)
  • I spent a lot of time trying to figure out what was wrong with my setup, since my compiled times were slower than the python times. Only when coming back here to post the results did I see that yours were too. Do you have any thoughts on that?
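
(Something like this minimal sketch would avoid that - it's not Pyro's timed, just an illustration of a timer that lets exceptions propagate:)

import time
from contextlib import contextmanager

@contextmanager
def timed_loud(label):
    # report elapsed time but re-raise any exception from the timed block
    start = time.time()
    try:
        yield
    finally:
        print(f"{label}: {time.time() - start:.3f}s")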

Thanks for sharing!

I suppose this could mostly be read as good news - the 1.2 performance is about the same as 1.0. (Perhaps a shade faster, but within the variance of the several runs I made.)

I think there is still the question of how to reconcile this with the slower per-step inference times that we observe (e.g. 91 s vs 186 s in your last set of numbers). It does seem, then, that some other operation (or operations) in SVI is slower when running the compiled code.

  • I spent a lot of time trying to figure out what was wrong with my setup, since my compiled times were slower than the python times. Only when coming back here to post the results did I see that yours were too. Do you have any thoughts on that?

I noticed that too, and it definitely seems odd that JIT would end up making things worse. We will probably need to isolate this in PyTorch and share it with their team.

I checked again and found that almost all of the additional time over torch-1.0 is taken in the very first iteration; that is the only place that is currently slower compared to pytorch-1.0. Adam Paszke alluded to this as well, and it seems that this is expected given all the additional complexity that has gone into JIT. On our end, this just means that JIT may not be worth it unless we call the compiled code enough times to amortize the additional cost of compilation and the first iteration.
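
An easy way to see this is to time the first few compiled calls individually rather than in one aggregate loop, e.g. along these lines (a sketch; traced and data are from the profiling script above):

import time

data.requires_grad_(True)
for i in range(5):
    start = time.time()
    out = traced(data)
    out.backward()
    print(f"iteration {i}: {time.time() - start:.3f}s")  # iteration 0 dominates in my runs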

Sorry for the late reply, @neerajprad - I was out last week.

A couple notes:

  • While the compiled code seems slower than the python times in this test, I definitely find that the JIT is much faster for my research runs. It's so important that I'd almost call it make-or-break. I'm not sure how to reconcile this test showing the compiled code slower than the python code with my experience in the wild. I wish I were further along and could say something useful about what differs between our test and a usual run. I'm willing to spend some time on it if you think it's key (I'd need a little guidance).
  • To elaborate on that a bit, my use case involves frequently compiling new models and then running to some level of convergence. The JIT + run-compiled-code option is much faster. So hearing “JIT may not be worth it” was worth a wince here.
  • To that end, the fact that the nightly build was falling back in line with 1.0 times was a bit heartening. My worst case scenario would be to be stuck on an aging branch if PyTorch had headed off in a direction that didn’t serve Pyro well.
  • But I'm also reminding myself of the results I put in the 8th post on this thread (Jul 23). They showed that the lion's share of the slowdown was in the first iteration, but definitely not "almost all" of it (as you found). The compiled code was over twice as slow as the 1.0 version. Sadly, this was also true for the 1.2 nightly.


@neerajprad - just dropping a line to say I haven’t given up on this. With the dev branch moving off of 1.0 so soon it’s becoming a bigger deal to get this info to PyTorch and get it addressed.

But… this is peak summer travel and I’ll be AFK here again until late next week. I will circle back shortly after to see what you think would be most productive to send upstream and will get on it. Top of my list is why the 1.2 nightly is running compiled code twice as slow as 1.0.

Thanks for the update! Btw, pytorch-1.2 got released today, and we will shortly have a Pyro release that uses pytorch-1.2. Once that is done, could you check your benchmarks again and see how they compare with pytorch-1.0?

The JIT + run-compiled-code option is much faster. So hearing “JIT may not be worth it” was worth a wince here.

:slightly_smiling_face: I use JIT for HMC in Pyro almost all the time, because it easily leads to a 4x difference in performance, but then HMC falls precisely in that sweet spot where some additional time spent in model compilation and first run is easily amortized over many thousands of forward / backward passes through the model.
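
(For reference, enabling it there is just a flag on the kernel - a minimal sketch with a toy model and placeholder data; the flags shown are jit_compile and ignore_jit_warnings on NUTS:)

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS

def model(data):
    loc = pyro.sample("loc", dist.Normal(0., 10.))
    with pyro.plate("data", data.size(0)):
        pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

data = torch.randn(100) + 3.
# the traced potential-energy computation is reused across the many gradient
# evaluations NUTS performs, which is where the amortization comes from
kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)
MCMC(kernel, num_samples=500, warmup_steps=500).run(data)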

They showed that the lion's share of the slowdown was in the first iteration, but definitely not "almost all" of it (as you found).

It would be great if you could isolate and share the code that demonstrates the compiled code being slower past the first iteration. In all the examples that I played around with, I only found the first iteration to be slow; subsequent iterations were actually faster. Btw, are you running this on CPU or GPU?