Hello,
I am test-running a Pyro model like below:
model = nn.Sequential(
nn.Linear(28 * 28, 100),
nn.Sigmoid(),
nn.Linear(100, 100),
nn.Sigmoid(),
nn.Linear(100, 10),
)
module.to_pyro_module_(model)
for m in model.modules():
for name, value in list(m.named_parameters(recurse=False)):
setattr(m, name, module.PyroSample(prior=dist.Normal(0, 1)
.expand(value.shape)
.to_event(value.dim())))
guide_diag_normal = guides.AutoDiagonalNormal(model)
optimizer_1 = Adam({"lr": 0.000000055})
scheduler_1 = pyro.optim.StepLR({'optimizer': optimizer_1, 'optim_args': {'lr': 0.000000055}})
svi_diag_normal = SVI(model, guide_diag_normal, optimizer_1, loss=Trace_ELBO())
input_ids=torch.tensor(random.normal(0,1,28*28))
# aritificial y value
y=torch.tensor(3.562684)
# error is generated here
svi_diag_normal.step(input_ids,y)
The error messages are:
Traceback (most recent call last):
File "<ipython-input-23-9f299eedcfa7>", line 1, in <module>
svi_diag_normal.step(input_ids,y)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/svi.py", line 128, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/trace_elbo.py", line 126, in loss_and_grads
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/elbo.py", line 170, in _get_traces
yield self._get_trace(model, guide, args, kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/trace_elbo.py", line 53, in _get_trace
"flat", self.max_plate_nesting, model, guide, args, kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/enum.py", line 44, in get_importance_trace
guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 185, in get_trace
self(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 165, in __call__
ret = self.fn(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/nn/module.py", line 290, in __call__
return super().__call__(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py", line 679, in forward
self._setup_prototype(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py", line 819, in _setup_prototype
super()._setup_prototype(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py", line 577, in _setup_prototype
super()._setup_prototype(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/infer/autoguide/guides.py", line 156, in _setup_prototype
self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 11, in _context_wrap
return fn(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 185, in get_trace
self(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py", line 165, in __call__
ret = self.fn(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 11, in _context_wrap
return fn(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/poutine/messenger.py", line 11, in _context_wrap
return fn(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/pyro/nn/module.py", line 290, in __call__
return super().__call__(*args, **kwargs)
File "/Users/hyunjindominiquecho/opt/anaconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
result = self.forward(*input, **kwargs)
TypeError: forward() takes 2 positional arguments but 3 were given
How can I avoid this error?
Thank you,