Hi, I'm having issues with the `jit_compile` option of the NUTS sampler and am hoping someone has a workaround!
I get a `RuntimeError` (full error below) when I run the following code block:
import pyro
import pyro.distributions as dist
from pyro.nn import PyroSample, PyroModule
from torch import nn
class BayesianLognormalRegression(PyroModule):
    """Bayesian linear regression with a log-normal likelihood.

    The linear layer's weight, and its bias when ``bias`` is true, are latent
    variables with independent ``Normal(0, 5)`` priors. The likelihood scale
    ``sigma`` is a latent variable with an ``Exponential(1)`` prior.
    ``target`` is modelled as ``LogNormal(loc=mu, scale=sigma)``, i.e.
    ``log(target)`` is Normal with mean ``mu = linear(data)``.

    Args:
        in_features: Size of the last dimension of ``data``.
        out_features: Output size of the linear layer. ``forward`` squeezes
            the last dimension, so the default of 1 is the expected use.
            NOTE(review): other values likely do not broadcast against the
            ``data`` plate. TODO confirm.
        bias: Whether the linear layer has a (sampled) bias term.
    """

    def __init__(self, *, in_features, out_features=1, bias=True):
        super().__init__()
        # Not used by this class (sigma is sampled in ``forward``). It is kept
        # so that existing code reading ``model.sigma`` keeps working.
        self.sigma = None
        # Forward ``bias`` to nn.Linear. Otherwise ``bias=False`` would still
        # create a randomly initialised bias parameter that has no prior.
        self.linear = PyroModule[nn.Linear](in_features, out_features, bias=bias)
        if bias:
            self.linear.bias = PyroSample(
                dist.Normal(0., 5.).expand([out_features]).to_event(1)
            )
        self.linear.weight = PyroSample(
            dist.Normal(0., 5.).expand([out_features, in_features]).to_event(2)
        )

    def forward(self, data, target=None):
        """Run the generative model and return the log-scale location ``mu``.

        Args:
            data: Input features. Presumably shaped (N, in_features), with the
                first dimension enumerated by the ``"data"`` plate. TODO confirm.
            target: Observed positive responses, or ``None`` to sample them.

        Returns:
            ``linear(data)`` with its last dimension squeezed.
        """
        mu = self.linear(data).squeeze(-1)
        # NOTE(review): the report accompanying this code says that
        # NUTS(jit_compile=True) fails with "Unsupported value kind: Tensor"
        # while sigma is sampled here, but runs when sigma is fixed at 1.0.
        # The root cause has not been established. TODO investigate.
        sigma = pyro.sample("sigma", dist.Exponential(1.))
        with pyro.plate("data", data.shape[0]):
            pyro.sample("obs", dist.LogNormal(loc=mu, scale=sigma), obs=target)
        return mu
from pyro.infer import MCMC, NUTS

# Build a single-feature model and draw posterior samples with NUTS.
model = BayesianLognormalRegression(in_features=1)
nuts_kernel = NUTS(
    model,
    jit_compile=True,  # compile the potential function via TorchScript
    ignore_jit_warnings=True,
)
mcmc = MCMC(
    nuts_kernel,
    num_samples=1000,
    warmup_steps=200,
)
# ``data`` and ``target`` are not created here; they must already exist.
mcmc.run(data, target)
Here is the full error:
RuntimeError Traceback (most recent call last)
<ipython-input-76-70e7e5727ba4> in <module>()
4 nuts_kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)
5 mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
----> 6 mcmc.run(data, target)
12 frames
/usr/local/lib/python3.7/dist-packages/pyro/infer/mcmc/util.py in _potential_fn_jit(self, skip_jit_warnings, jit_options, params)
292
293 if self._compiled_fn:
--> 294 return self._compiled_fn(*vals)
295
296 with pyro.validation_enabled(False):
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Unsupported value kind: Tensor
However, if I hard-code `sigma = 1.0` instead of sampling it, the code runs!
Can anyone help? Thanks!