Invalid index in gather for model relying on auto-enumeration

Dear Pyro developers,

I have been trying to use discrete site enumeration, for a model with multiple interactions between discrete variables.
When running SVI on the model, I get an "Invalid index" error (see below) inside the log_prob calculation of the Categorical distribution. I am wondering whether I have misspecified my model, or whether I am making a wrong assumption about how discrete site enumeration works.

I have tried debugging for some time, but I could not find an immediate solution.
Could you suggest how I might fix this when you have time?

Thank you very much in advance!

Debugging information:
I am running the latest PyTorch 1.0 preview and the latest commit on the pytorch-1.0 branch of Pyro.

Relevant model/guide code:

    import pyro
    import pyro.distributions as dist
    from torch.distributions import constraints

    def mg_gen(is_guide=False):
        def mg(data):
            clust, corr, vals, lengths = data
            clust = clust.float()
            vals = vals.t().float()
            corr = corr.float()
            n_clusts = clust.size(1)
            n_corr = 2
            n_class = 4
            # Pseudo-counts for each cluster and correctness choice, specifying Dirichlet arguments
            # for each class of output
            pseudocounts = clust.new_ones(n_clusts, n_corr, n_class).float() * 0.5
            if is_guide:
                pseudocounts = pyro.param('pc_q', pseudocounts)
            # Dirichlet prior over output classes (0-3)
            obs_dir = pyro.sample('obs_dir', dist.Dirichlet(pseudocounts).independent(2))
            feature_plate = pyro.plate('feature', vals.size(0), dim=-2)
            with pyro.plate('data', vals.size(1), dim=-1) as idx:
                c0corr = vals.new_tensor(4.0).float()
                c1corr = vals.new_tensor(1.0).float()
                if is_guide:
                    c0corr = pyro.param('c0corr', c0corr, constraint=constraints.greater_than(1e-3))
                    c1corr = pyro.param('c1corr', c1corr, constraint=constraints.greater_than(1e-3))
                # Beta prior on the probability that an output is correct
                corrpr = pyro.sample('corrpr', dist.Beta(concentration0=c0corr, concentration1=c1corr))
                corrobs = dict(obs=corr[idx]) if not is_guide else dict()
                # Bernoulli coin flip (observed in the model) determining whether the output is
                # "correct" or "incorrect" according to our specification
                corrch = pyro.sample('corrch', dist.Bernoulli(corrpr),
                                     infer=dict(enumerate='parallel', is_auxiliary=True), **corrobs).long()
                # Choice of cluster for each data point, given the prior distribution over clusterings
                clustch = pyro.sample('clustch', dist.Categorical(clust.index_select(0, idx)),
                                      infer=dict(enumerate='parallel'))
                with feature_plate as fidx:
                    # Observed values given the choice of cluster and correctness
                    if not is_guide:
                        pyro.sample('obs', dist.Categorical(obs_dir[clustch, corrch]), obs=vals)

        return mg
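
For completeness, the surrounding inference setup is roughly the following (a sketch; the learning rate, number of epochs, and `dataloader` are placeholders for what is in the full script):

    import pyro
    from pyro.infer import SVI, TraceEnum_ELBO
    from pyro.optim import Adam

    model = mg_gen(is_guide=False)
    guide = mg_gen(is_guide=True)

    # Two nested plates ('feature' at dim=-2, 'data' at dim=-1), hence max_plate_nesting=2
    elbo = TraceEnum_ELBO(max_plate_nesting=2)
    svi = SVI(model, guide, Adam({'lr': 1e-2}), loss=elbo)

    for epoch in range(n_epochs):
        epoch_loss = 0.0
        for data in dataloader:
            epoch_loss += svi.step(data)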

Error:

Traceback (most recent call last):
  File "/home/asal/src/project/pyro/poutine/trace_struct.py", line 149, in compute_log_prob
    site["log_prob"]
KeyError: 'log_prob'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "pyro_debug_enum.py", line 128, in <module>
    main(sys.argv)
  File "pyro_debug_enum.py", line 124, in main
    infer(dataloader)
  File "pyro_debug_enum.py", line 75, in infer
    total_epoch_loss_train = _train(svi, dataloader)
  File "pyro_debug_enum.py", line 62, in _train
    epoch_loss += svi.step(data)
  File "/home/asal/src/project/pyro/infer/svi.py", line 96, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
  File "/home/asal/src/project/pyro/infer/traceenum_elbo.py", line 336, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
  File "/home/asal/src/project/pyro/infer/traceenum_elbo.py", line 279, in _get_traces
    yield self._get_trace(model, guide, *args, **kwargs)
  File "/home/asal/src/project/pyro/infer/traceenum_elbo.py", line 233, in _get_trace
    "flat", self.max_plate_nesting, model, guide, *args, **kwargs)
  File "/home/asal/src/project/pyro/infer/enum.py", line 51, in get_importance_trace
    model_trace.compute_log_prob()
  File "/home/asal/src/project/pyro/poutine/trace_struct.py", line 151, in compute_log_prob
    log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
  File "/home/asal/anaconda2/envs/pyro-env/lib/python3.7/site-packages/torch/distributions/categorical.py", line 119, in log_prob
    return log_pmf.gather(-1, value).squeeze(-1)
RuntimeError: Invalid index in gather at /opt/conda/conda-bld/pytorch-nightly-cpu_1540796783109/work/aten/src/TH/generic/THTensorEvenMoreMath.cpp:453

You can also find the complete offending source file here:

Actually, I found out what the issue was.
I had an observed value that was outside the support of the Categorical distribution.

I apologize for creating an issue so eagerly. I will see whether it is possible to make a PR that adds constraints on the categorical values so that out-of-range observations produce a nicer error message.

EDIT:
The fix is to change n_class to 5 instead of 4.
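
In case it helps anyone else hitting the same cryptic gather error: a quick check along these lines (a sketch, using the `vals` tensor and `n_class` from the model above) makes the problem obvious before it ever reaches Categorical.log_prob:

    import torch
    from torch.distributions import constraints

    n_class = 5  # must exceed the largest observed class label
    # The Categorical over obs_dir[...] has support {0, ..., n_class - 1}
    obs_support = constraints.integer_interval(0, n_class - 1)

    # Plain range check on the observed values...
    assert 0 <= int(vals.min()) and int(vals.max()) < n_class, \
        'observed values outside the support of the Categorical'
    # ...or the equivalent constraint-based check (the kind of validation a PR could add)
    assert bool(obs_support.check(vals.long()).all())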

The original reason I built this model was to ask how I could improve the performance of a similar, more complex model. Even this relatively “simple” model seems to take around 10 seconds per iteration of SVI, and so over an hour to run around 500 iterations.
I was wondering whether I have specified my model incorrectly and am hitting a pathological case here, and how I could speed up execution.

Thanks again!

When you say β€œiteration”, it looks from your gist like you actually mean β€œepoch”, where an epoch is a sequence of 1000 calls to svi.step with a batch size of 10. That means each gradient step is taking ~10ms. Some simple things to try are increasing your batch size, disabling Pyro’s validation, and, since you’re using PyTorch 1.0, using JitTraceEnum_ELBO instead of TraceEnum_ELBO, though I think the JIT may not work correctly here right now due to this issue: trace's log_prob of models with Gamma/Dirichlet distributions incompatible with JIT Β· Issue #1487 Β· pyro-ppl/pyro Β· GitHub.
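
Concretely, the knobs I mean look something like this (a sketch; `model`, `guide`, and the learning rate stand in for whatever you already use):

    import pyro
    from pyro.infer import SVI, JitTraceEnum_ELBO
    from pyro.optim import Adam

    # Turn off Pyro's per-step distribution/shape validation
    pyro.enable_validation(False)

    # Swap TraceEnum_ELBO for its JIT-compiled counterpart, keeping the same max_plate_nesting
    elbo = JitTraceEnum_ELBO(max_plate_nesting=2)
    svi = SVI(model, guide, Adam({'lr': 1e-2}), loss=elbo)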

It would also be helpful if you could profile one epoch of your code with and without JitTraceEnum_ELBO and post some results here. We still have a lot of profiling and performance work to do on enumeration, and there may still be issues in the PyTorch side (e.g. torch.einsum 400x slower than numpy.einsum on a simple contraction Β· Issue #10661 Β· pytorch/pytorch Β· GitHub ). However, the reality is that without the JIT, very small models have a performance ceiling determined mostly by CPython and PyTorch overhead.
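
Something along these lines would give us numbers we can compare (a sketch using the standard-library cProfile, assuming your existing `svi` object and `dataloader`):

    import cProfile
    import pstats

    def one_epoch():
        loss = 0.0
        for data in dataloader:
            loss += svi.step(data)
        return loss

    profiler = cProfile.Profile()
    profiler.enable()
    one_epoch()
    profiler.disable()
    pstats.Stats(profiler).sort_stats('cumulative').print_stats(30)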

Awesome, thank you very much for all the suggestions! :grinning:

I will try to do some profiling with your suggestions and see how far I can get.

Hi again,

So I did some profiling, and around 50% of the time seems to be spent in the gather method used by TraceEnum_ELBO.
Larger batch sizes did help with runtime, but accuracy surprisingly goes down if I use too large a batch.
For some reason, the JIT version runs almost twice as slowly, which was surprising to me.

I have included the relevant files and debug information below. I would appreciate it if you could take a look at it when you have time.

Thanks!

Configuration runtimes:

Python file: pyro_debug_enum.py Β· GitHub

Debug runtime by calls:

Hi @ahmadsalim,
I have found the performance of the preview version of PyTorch to be quite poor compared to the 0.4.0 version, especially with regard to the distributions’ log_prob methods. See this issue that I filed with the PyTorch devs - Slowdown in distributions log_prob methods · Issue #12190 · pytorch/pytorch · GitHub. Many of these methods take twice as long on PyTorch master as on earlier versions of PyTorch. If it’s not too much trouble, could you run your profiling code with the 0.4.0 version of PyTorch? I would be curious to see whether that runs faster. If you find that the gather operation is slower in the current preview version, it would be great if you could comment on the issue above so that it gets more visibility.

As @eb8680_2 mentioned, the JIT is far from mature, so I wouldn’t have expected it to work perfectly; but at least for SVI it shouldn’t be running twice as slowly, so this is something we will look into.

Thanks for the response!
I will try running the same benchmarks against the 0.4.0 version of PyTorch and report back.

It does indeed run on average around 0.3–0.5 seconds faster per iteration with batch size 250, and 0.7–1.0 seconds faster per iteration with batch size 1000, on 0.4.0 compared to 1.0.

I still have to try the latest dev version after PR #1507 was merged, to see if there is a difference there.

EDIT: As far as I could see, there was no significant difference in performance with the latest revision.

Thanks for digging in! From your profiler results, does it look like the additional time is only taken by categorical.log_prob() or is there any other source of slowdown?

It seems that a major part of the slowdown between 0.4.0 and 1.0 is indeed categorical.log_prob().
It is taking on average almost twice as long (profiler numbers are the median of 3 runs).

The auto-enumeration part is taking the rest of the time, which is still significant but not much different across versions.
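
For reference, the slowdown also shows up in an isolated microbenchmark outside of Pyro (a sketch; the shapes below are just stand-ins for the enumerated batch shapes in my model, not the exact ones from the profiler runs):

    import timeit
    import torch

    # Stand-in shapes: a 4 x 250 batch of categorical sites with 64 classes each
    probs = torch.rand(4, 250, 64)
    probs = probs / probs.sum(-1, keepdim=True)
    value = torch.randint(0, 64, (4, 250))
    d = torch.distributions.Categorical(probs=probs)

    # Time the full log_prob as well as the underlying gather on its own
    print('log_prob:', timeit.timeit(lambda: d.log_prob(value), number=1000))
    print('gather:  ', timeit.timeit(
        lambda: probs.log().gather(-1, value.unsqueeze(-1)), number=1000))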

Thanks for investigating and reporting back, @ahmadsalim! It seems that torch.gather is actually thrice as slow, which is really concerning. When you have the time, could you post your profiling results to Slowdown in distributions log_prob methods Β· Issue #12190 Β· pytorch/pytorch Β· GitHub as another example that might help the PyTorch devs debug this issue?

Maybe gather issues are also related to https://github.com/pytorch/pytorch/pull/13420 and the issue mentioned in that PR?

Sure, of course! :smiley:
I will add my profiling results to the issue.

Yeah, this seems more closely related to 13420. I asked on the issue; let’s see.