I have a Bayesian NN that works when I define the weight prior like this:
self.fc1.weight = PyroSample(dist.Normal(0., 1.).expand([h1, in_features]).to_event(2))
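But when I try to make the spread of that prior a random variable itself (note that the second argument of `dist.Normal` is the scale, i.e. the standard deviation, even though I named it `var`), like this: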
var = PyroSample(dist.Exponential(1.))
self.fc1.weight = PyroSample(dist.Normal(0., var).expand([h1, in_features]).to_event(2))
I get the following error:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-42-1960cbe051d0> in <module>
1 features = ['net_ortg','net_drtg','lead','sec_rem_period','total_pts', 'p_1','p_2','p_3','p_4','p_5','p_6']
----> 2 model = ThreeLayerNNModel(11, 1, 24, 24, 24)
3 losses, guide = inference(model, features=features, epochs=20)
4 plt.plot(losses)
<ipython-input-40-19d2ca07f372> in __init__(self, in_features, out_features, h1, h2, h3)
5 self.fc1 = PyroModule[nn.Linear](in_features, h1)
6 var = PyroSample(dist.Exponential(1.))
----> 7 self.fc1.weight = PyroSample(dist.Normal(0., var).expand([h1, in_features]).to_event(2))
8 self.fc1.bias = PyroSample(dist.Normal(0., 10.).expand([h1]).to_event(1))
9 self.fc2 = PyroModule[nn.Linear](h1, h2)
~/pyro/xpm/pyro/lib/python3.7/site-packages/pyro/distributions/distribution.py in __call__(cls, *args, **kwargs)
15 if result is not None:
16 return result
---> 17 return super().__call__(*args, **kwargs)
18
19
~/pyro/xpm/pyro/lib/python3.7/site-packages/torch/distributions/normal.py in __init__(self, loc, scale, validate_args)
42
43 def __init__(self, loc, scale, validate_args=None):
---> 44 self.loc, self.scale = broadcast_all(loc, scale)
45 if isinstance(loc, Number) and isinstance(scale, Number):
46 batch_shape = torch.Size()
~/pyro/xpm/pyro/lib/python3.7/site-packages/torch/distributions/utils.py in broadcast_all(*values)
22 """
23 if not all(isinstance(v, torch.Tensor) or isinstance(v, Number) for v in values):
---> 24 raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.')
25 if not all([isinstance(v, torch.Tensor) for v in values]):
26 options = dict(dtype=torch.get_default_dtype())
ValueError: Input arguments must all be instances of numbers.Number or torch.tensor.
Any idea what the issue is?