JIT save issue for GP module

I am trying to use JIT tracing to save and serve a GP model. Can anyone give me some hints?

I started from the sv-dkl.py example at https://github.com/pyro-ppl/pyro/blob/7692a503b64dc8200049e1c24a775d2170ba8039/examples/contrib/gp/sv-dkl.py and added a Predict module identical to the one in the Bayesian regression tutorial: https://github.com/pyro-ppl/pyro/blob/dev/tutorial/source/bayesian_regression.ipynb

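import torch
from pyro import poutine
from pyro.poutine.util import prune_subsample_sites

class Predict(torch.nn.Module):
    def __init__(self, model, guide):
        super().__init__()  # required so torch.nn.Module is set up before tracing
        self.model = model
        self.guide = guide

    def forward(self, *args, **kwargs):
        samples = {}
        # Run the guide, then replay the model against the guide's samples.
        guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(poutine.replay(
            self.model, guide_trace)).get_trace(*args, **kwargs)
        # Collect the value at every stochastic site, in a fixed (sorted) order.
        for site in prune_subsample_sites(model_trace).stochastic_nodes:
            samples[site] = model_trace.nodes[site]['value']
        return tuple(v for _, v in sorted(samples.items()))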

The save function is also similar to the one in the Bayesian regression tutorial:

def save_module(args, gpmodule):
    predict_fn = Predict(gpmodule.model, gpmodule.guide)
    predict_module = torch.jit.trace_module(
        predict_fn, {"forward": (gpmodule.X,)}, check_trace=False)
    torch.jit.save(predict_module, '/tmp/reg_predict.pt')

This is the error I got:

Traceback (most recent call last):
  File "sv-dkl.py", line 221, in <module>
  File "sv-dkl.py", line 187, in main
    save_module(args, gpmodule)
  File "sv-dkl.py", line 117, in save_module
    predict_fn, {"forward": (gpmodule.X,)}, check_trace=False)
  File "/home/{user}/miniconda3/envs/py36-pyro-tutorial/lib/python3.6/site-packages/torch/jit/__init__.py", line 997, in trace_module
    module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, _force_outplace)
  File "/home/{user}/miniconda3/envs/py36-pyro-tutorial/lib/python3.6/site-packages/torch/nn/modules/module.py", line 539, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/home/{user}/miniconda3/envs/py36-pyro-tutorial/lib/python3.6/site-packages/torch/nn/modules/module.py", line 525, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "sv-dkl.py", line 51, in forward
    guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
  File "/home/{user}/miniconda3/envs/py36-pyro-tutorial/lib/python3.6/site-packages/pyro/poutine/trace_messenger.py", line 177, in get_trace
    self(*args, **kwargs)
  File "/home/{user}/miniconda3/envs/py36-pyro-tutorial/lib/python3.6/site-packages/pyro/poutine/trace_messenger.py", line 157, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/{user}/miniconda3/envs/py36-pyro-tutorial/lib/python3.6/site-packages/pyro/nn/module.py", line 471, in cached_fn
    return fn(self, *args, **kwargs)
TypeError: guide() takes 1 positional argument but 2 were given

Any hints on how to resolve this issue? I also noticed that the elbo = infer.JitTraceMeanField_ELBO() option does not work, so I filed a bug at https://github.com/pyro-ppl/pyro/issues/2255 and am using infer.TraceMeanField_ELBO() for training instead. Could that bug be related to the issue above?


Can you try inspecting the args? I wonder if self is being passed as an arg to the guide.
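
One way to inspect the args without touching Pyro is sketched below. It uses a hypothetical stand-in class (ToyGP) and a hypothetical helper (describe_call), neither of which comes from the original script. It assumes the guide is a bound method that takes no inputs beyond self, which is worth checking against the real gpmodule.guide signature. If that assumption holds, the traceback is explained by Predict.forward forwarding gpmodule.X into a guide that expects nothing.

```python
import inspect


class ToyGP:
    """Hypothetical stand-in for a module whose guide takes no inputs."""

    def __init__(self, X):
        self.X = X  # data is stored on the object, not passed per call

    def guide(self):
        return self.X


def describe_call(fn, *args, **kwargs):
    """Report whether fn accepts the given arguments, without calling it."""
    sig = inspect.signature(fn)  # for bound methods, self is already excluded
    try:
        sig.bind(*args, **kwargs)
        return "ok: {}{}".format(fn.__name__, sig)
    except TypeError as exc:
        return "mismatch: {}{} <- {}".format(fn.__name__, sig, exc)


gp = ToyGP(X=[1.0, 2.0])
print(describe_call(gp.guide))        # no extra arguments
print(describe_call(gp.guide, gp.X))  # same call shape as Predict.forward(X)
```

In the real script, the same check could be run as describe_call(gpmodule.guide, gpmodule.X) just before trace_module. If it reports a mismatch, the extra positional argument is the example input rather than a stray self.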