I was trying to use JIT tracing (torch.jit) to save and serve a GP model. Can anyone give me some hints?
I started from the sv-dkl.py example (pyro/sv-dkl.py at commit 7692a503b64dc8200049e1c24a775d2170ba8039 in pyro-ppl/pyro on GitHub) and added a Predict module identical to the one in the Bayesian regression tutorial (pyro/bayesian_regression.ipynb on the dev branch of pyro-ppl/pyro on GitHub):
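```python
import torch
from pyro import poutine
from pyro.poutine.util import prune_subsample_sites


class Predict(torch.nn.Module):
    def __init__(self, model, guide):
        super().__init__()
        self.model = model
        self.guide = guide

    def forward(self, *args, **kwargs):
        samples = {}
        # Run the guide, then replay the model against the guide's samples
        guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(poutine.replay(
            self.model, guide_trace)).get_trace(*args, **kwargs)
        # Collect the value at every (non-subsample) stochastic site
        for site in prune_subsample_sites(model_trace).stochastic_nodes:
            samples[site] = model_trace.nodes[site]['value']
        # Return values ordered by site name so the traced output is deterministic
        return tuple(v for _, v in sorted(samples.items()))
```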
The save_module function is also similar to the one in the Bayesian regression tutorial:
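```python
def save_module(args, gpmodule):
    # Wrap the GP module's model and guide so that calling it returns sampled site values
    predict_fn = Predict(gpmodule.model, gpmodule.guide)
    # Trace forward() using the training inputs as the example input
    predict_module = torch.jit.trace_module(
        predict_fn, {"forward": (gpmodule.X,)}, check_trace=False)
    torch.jit.save(predict_module, '/tmp/reg_predict.pt')
```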
This is the error I got:
```
Traceback (most recent call last):
  File "sv-dkl.py", line 221, in <module>
    main(args)
  File "sv-dkl.py", line 187, in main
    save_module(args, gpmodule)
  File "sv-dkl.py", line 117, in save_module
    predict_fn, {"forward": (gpmodule.X,)}, check_trace=False)
  File "/home/{user}/miniconda3/envs/py36-pyro-tutorial/lib/python3.6/site-packages/torch/jit/__init__.py", line 997, in trace_module
    module._c._create_method_from_trace(method_name, func, example_inputs, var_lookup_fn, _force_outplace)
  File "/home/{user}/miniconda3/envs/py36-pyro-tutorial/lib/python3.6/site-packages/torch/nn/modules/module.py", line 539, in __call__
    result = self._slow_forward(*input, **kwargs)
  File "/home/{user}/miniconda3/envs/py36-pyro-tutorial/lib/python3.6/site-packages/torch/nn/modules/module.py", line 525, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "sv-dkl.py", line 51, in forward
    guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
  File "/home/{user}/miniconda3/envs/py36-pyro-tutorial/lib/python3.6/site-packages/pyro/poutine/trace_messenger.py", line 177, in get_trace
    self(*args, **kwargs)
  File "/home/{user}/miniconda3/envs/py36-pyro-tutorial/lib/python3.6/site-packages/pyro/poutine/trace_messenger.py", line 157, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/{user}/miniconda3/envs/py36-pyro-tutorial/lib/python3.6/site-packages/pyro/nn/module.py", line 471, in cached_fn
    return fn(self, *args, **kwargs)
TypeError: guide() takes 1 positional argument but 2 were given
```
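My reading of the last frame is that the GP module's guide() accepts only self: the module keeps its inputs as an attribute (gpmodule.X) rather than taking them as arguments, so when Predict.forward forwards gpmodule.X into guide(*args), the call fails. Below is a minimal pure-Python mock of that mismatch, not Pyro code. MockGPModule, DataBoundPredict and reproduce_mismatch are hypothetical names, and set_data is assumed to behave like the data-setting method on pyro.contrib.gp models. Because it uses neither Pyro nor torch.jit, it says nothing about whether tracing the real GP would succeed afterwards.

```python
class MockGPModule:
    """Hypothetical stand-in for a Pyro GP module (NOT the real pyro.contrib.gp API).

    As in the traceback, model() and guide() take only `self`; the inputs
    live on the object and are replaced through set_data().
    """

    def __init__(self, X):
        self.X = X

    def set_data(self, X):
        # Assumption: mirrors a set_data-style method that stores new inputs.
        self.X = X

    def model(self):
        return ("model", self.X)

    def guide(self):
        return ("guide", self.X)


class DataBoundPredict:
    """Hypothetical adapter: takes X like Predict.forward does, but binds it to
    the module first and then calls guide/model with no positional arguments."""

    def __init__(self, gpmodule):
        self.gpmodule = gpmodule

    def __call__(self, X):
        self.gpmodule.set_data(X)
        return self.gpmodule.guide(), self.gpmodule.model()


def reproduce_mismatch(gpmodule, X):
    """Call guide(X) the way Predict.forward does; return the TypeError text."""
    try:
        gpmodule.guide(X)
    except TypeError as err:
        return str(err)
    return None


gp = MockGPModule(X=[0.0, 1.0])
# The message ends with: guide() takes 1 positional argument but 2 were given
# (newer Python versions prefix it with the class name).
print(reproduce_mismatch(gp, [2.0, 3.0]))
print(DataBoundPredict(gp)([2.0, 3.0]))
# (('guide', [2.0, 3.0]), ('model', [2.0, 3.0]))
```

The design choice in the adapter is to keep a traced-style entry point that still receives X as an example input, while the module's own model and guide stay argument-free.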
Any hints on how to resolve this? I also noticed that the elbo = infer.JitTraceMeanField_ELBO() option does not work, so I filed a bug (pyro-ppl/pyro issue #2255, "[Bug] sv-dkl.py failed when enable jit trace") and used infer.TraceMeanField_ELBO() for training instead. Could that bug be related to the issue above?
Regards,