Last time I used Pyro was few years back when the PyroModule
interface wasn’t there. Now I am back to Pyro again and trying to use it to replicate the coin fairness example. I am following the examples here for understanding the PyroModule
interface.
device = torch.device('cpu')
data = torch.rand(1000, device=device)
data[data>0.7] = 1.; data[data<=0.7] = 0.
class Model(PyroModule):
def __init__(self):
super().__init__()
self.fairness = PyroSample(prior=dist.Beta(10., 10.))
def forward(self, data=None):
with pyro.plate("obs", len(data)):
return pyro.sample("data", self.fairness, obs=data)
class Guide(PyroModule):
def __init__(self):
super().__init__()
self.alpha = PyroParam(torch.tensor(15.), constraint=dist.constraints.positive)
self.beta = PyroParam(torch.tensor(15.), constraint=dist.constraints.positive)
self.fairness = PyroSample(prior=dist.Beta(self.alpha, self.beta))
def forward(self, _):
return self.fairness
model = Model()
guide = Guide()
svi = infer.SVI(model, guide, optimizer, loss=infer.Trace_ELBO())
But I am pretty sure I am making some stupid mistake and that’s why upon svi.step(data)
, I am getting this error
Traceback (most recent call last):
File "coin_fairness.py", line 69, in <module>
L.append(svi.step(data))
File "/home/ayan/anaconda3/lib/python3.8/site-packages/pyro/infer/svi.py", line 128, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
File "/home/ayan/anaconda3/lib/python3.8/site-packages/pyro/infer/trace_elbo.py", line 126, in loss_and_grads
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
File "/home/ayan/anaconda3/lib/python3.8/site-packages/pyro/infer/elbo.py", line 170, in _get_traces
yield self._get_trace(model, guide, args, kwargs)
File "/home/ayan/anaconda3/lib/python3.8/site-packages/pyro/infer/trace_elbo.py", line 52, in _get_trace
model_trace, guide_trace = get_importance_trace(
File "/home/ayan/anaconda3/lib/python3.8/site-packages/pyro/infer/enum.py", line 55, in get_importance_trace
model_trace.compute_log_prob()
File "/home/ayan/anaconda3/lib/python3.8/site-packages/pyro/poutine/trace_struct.py", line 216, in compute_log_prob
log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
AttributeError: 'Tensor' object has no attribute 'log_prob'
Can someone point me out what am I doing wrong ?