Thank you. It ran. Should I also pass full_size to the Predictive? Prediction doesn't seem to work.
When I run
predictive = Predictive(model, guide=guide, num_samples=1000,
                        return_sites=("linear.weight", "obs", "_RETURN"))
samples = predictive(x_discrete)
I get the following:
Traceback (most recent call last):
  File "<ipython-input-156-c6d7a45b4038>", line 1, in <module>
    samples = predictive(x_discrete)
  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\predictive.py", line 201, in forward
    parallel=self.parallel, model_args=args, model_kwargs=kwargs)
  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\predictive.py", line 53, in _predictive
    max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\predictive.py", line 21, in _guess_max_plate_nesting
    model_trace = poutine.trace(model).get_trace(*args, **kwargs)
  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\poutine\trace_messenger.py", line 187, in get_trace
    self(*args, **kwargs)
  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\poutine\trace_messenger.py", line 165, in __call__
    ret = self.fn(*args, **kwargs)
  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\nn\module.py", line 413, in __call__
    return super().__call__(*args, **kwargs)
  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\torch\nn\modules\module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'x'
I thought I might need to pass full_size to the predictive call as well, but that gives me:
RuntimeError: Expected object of scalar type Float but got scalar type Double for argument #2 'mat1' in call to _th_addmm
     Trace Shapes:
      Param Sites:
     Sample Sites:
        sigma dist |
             value |
linear.weight dist | 1 82
             value | 1 82
  linear.bias dist | 1
             value | 1
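The RuntimeError says the linear layer received a Double (float64) input while its parameters are Float (float32). One common cause, assumed here rather than confirmed by the post, is that x_discrete was built from a NumPy array, since NumPy defaults to float64. A minimal PyTorch-only sketch with hypothetical random data (x_np, x_double, and the plain nn.Linear are stand-ins) reproduces the mismatch and the cast that avoids it:

```python
import numpy as np
import torch

# Hypothetical data built with NumPy: NumPy arrays default to float64,
# so torch.from_numpy yields a Double tensor.
x_np = np.random.rand(5, 82)
x_double = torch.from_numpy(x_np)
print(x_double.dtype)  # torch.float64

# Parameters use torch's default dtype, float32.
linear = torch.nn.Linear(82, 1)

mismatch_raised = False
try:
    linear(x_double)  # Double input vs Float weights
except RuntimeError as err:
    mismatch_raised = True
    print("dtype mismatch:", err)

# Casting the input to float32 makes it match the parameters.
x_float = x_double.float()
out = linear(x_float)
print(out.dtype, out.shape)  # torch.float32 torch.Size([5, 1])
```

Under that assumption, passing x_discrete.float() (or creating the tensor with dtype=torch.float32) instead of x_discrete would be the corresponding change; the exact error text differs between PyTorch versions.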