RuntimeError: The following operation failed in the TorchScript interpreter. RuntimeError: Unsupported value kind: Tensor

Hi, I’m having issues with the jit_compile option on the NUTS sampler; hoping someone has a workaround!

I get this runtime error when I run the following code block:

from torch import nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroSample, PyroModule
from pyro.infer import MCMC, NUTS

class BayesianLognormalRegression(PyroModule):
    def __init__(self, *, in_features, out_features=1, bias=True):
        super().__init__()  # PyroModule must be initialized before attributes are assigned
        self.sigma = None
        self.linear = PyroModule[nn.Linear](in_features, out_features, bias=bias)
        # Replace the nn.Linear parameters with priors
        if bias:
            self.linear.bias = PyroSample(dist.Normal(0., 5.).expand([out_features]).to_event(1))
        self.linear.weight = PyroSample(dist.Normal(0., 5.).expand([out_features, in_features]).to_event(2))

    def forward(self, data, target=None):
        mu = self.linear(data).squeeze(-1)
        sigma = pyro.sample("sigma", dist.Exponential(1.))
        with pyro.plate("data", data.shape[0]):
            obs = pyro.sample("obs", dist.LogNormal(loc=mu, scale=sigma), obs=target)
        return mu

model = BayesianLognormalRegression(in_features=1)
nuts_kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
# `data` and `target` are the training tensors defined earlier (the name `data` is assumed)
mcmc.run(data, target)

Full error here:

RuntimeError                              Traceback (most recent call last)
<ipython-input-76-70e7e5727ba4> in <module>()
      4 nuts_kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True) 
      5 mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
----> 6, target)

12 frames
/usr/local/lib/python3.7/dist-packages/pyro/infer/mcmc/ in _potential_fn_jit(self, skip_jit_warnings, jit_options, params)
    293         if self._compiled_fn:
--> 294             return self._compiled_fn(*vals)
    296         with pyro.validation_enabled(False):

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: Unsupported value kind: Tensor

However, if I hardcode sigma=1.0, the code runs!

Can anyone help? Thanks!

Torch JIT is finicky and probably doesn’t like some aspect of PyroModule. I suggest you write your model as a simple Python function, as is done e.g. here.

OK, I won’t worry too much about it, then. Thanks.