Sure. Here is a reduced version of the code from the tutorial:
import numpy as np
import torch
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist
import traceback
import warnings
import sys
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
log = file if hasattr(file,'write') else sys.stderr
traceback.print_stack(file=log)
log.write(warnings.formatwarning(message, category, filename, lineno, line))
warnings.showwarning = warn_with_traceback
def scale(guess):
weight = pyro.sample("weight", dist.Normal(guess, 1.0))
return pyro.sample("measurement", dist.Normal(weight, 0.75))
conditioned_scale = pyro.condition(scale, data={"measurement": 9.5})
def scale_parametrized_guide(guess):
a = pyro.param("a", torch.tensor(guess))
b = pyro.param("b", torch.tensor(1.))
return pyro.sample("weight", dist.Normal(a, torch.abs(b)))
guess = torch.tensor(8.5)
pyro.clear_param_store()
svi = pyro.infer.SVI(model=conditioned_scale,
guide=scale_parametrized_guide,
optim=pyro.optim.SGD({"lr": 0.001, "momentum":0.1}),
loss=pyro.infer.Trace_ELBO())
svi.step(guess)
and here is the error, where it looks like the issue is the way the scale_parametrized_guide()
function defines a
:
File “C:\Users\beldaz\Anaconda3\lib\runpy.py”, line 193, in _run_module_as_main
“main”, mod_spec)
File “C:\Users\beldaz\Anaconda3\lib\runpy.py”, line 85, in _run_code
exec(code, run_globals)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\ipykernel_launcher.py”, line 16, in
app.launch_new_instance()
File “C:\Users\beldaz\Anaconda3\lib\site-packages\traitlets\config\application.py”, line 658, in launch_instance
app.start()
File “C:\Users\beldaz\Anaconda3\lib\site-packages\ipykernel\kernelapp.py”, line 486, in start
self.io_loop.start()
File “C:\Users\beldaz\Anaconda3\lib\site-packages\tornado\platform\asyncio.py”, line 127, in start
self.asyncio_loop.run_forever()
File “C:\Users\beldaz\Anaconda3\lib\asyncio\base_events.py”, line 422, in run_forever
self._run_once()
File “C:\Users\beldaz\Anaconda3\lib\asyncio\base_events.py”, line 1432, in _run_once
handle._run()
File “C:\Users\beldaz\Anaconda3\lib\asyncio\events.py”, line 145, in _run
self._callback(*self._args)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\tornado\platform\asyncio.py”, line 117, in _handle_events
handler_func(fileobj, events)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\tornado\stack_context.py”, line 276, in null_wrapper
return fn(*args, **kwargs)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\zmq\eventloop\zmqstream.py”, line 450, in _handle_events
self._handle_recv()
File “C:\Users\beldaz\Anaconda3\lib\site-packages\zmq\eventloop\zmqstream.py”, line 480, in _handle_recv
self._run_callback(callback, msg)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\zmq\eventloop\zmqstream.py”, line 432, in _run_callback
callback(*args, **kwargs)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\tornado\stack_context.py”, line 276, in null_wrapper
return fn(*args, **kwargs)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\ipykernel\kernelbase.py”, line 283, in dispatcher
return self.dispatch_shell(stream, msg)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\ipykernel\kernelbase.py”, line 233, in dispatch_shell
handler(stream, idents, msg)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\ipykernel\kernelbase.py”, line 399, in execute_request
user_expressions, allow_stdin)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\ipykernel\ipkernel.py”, line 208, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\ipykernel\zmqshell.py”, line 537, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py”, line 2662, in run_cell
raw_cell, store_history, silent, shell_futures)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py”, line 2785, in _run_cell
interactivity=interactivity, compiler=compiler, result=result)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py”, line 2909, in run_ast_nodes
if self.run_code(code, result):
File “C:\Users\beldaz\Anaconda3\lib\site-packages\IPython\core\interactiveshell.py”, line 2963, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File “”, line 38, in
svi.step(guess)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\pyro\infer\svi.py”, line 99, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\pyro\infer\trace_elbo.py”, line 125, in loss_and_grads
for model_trace, guide_trace in self._get_traces(model, guide, *args, **kwargs):
File “C:\Users\beldaz\Anaconda3\lib\site-packages\pyro\infer\elbo.py”, line 164, in _get_traces
yield self._get_trace(model, guide, *args, **kwargs)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\pyro\infer\trace_elbo.py”, line 52, in _get_trace
“flat”, self.max_plate_nesting, model, guide, *args, **kwargs)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\pyro\infer\enum.py”, line 42, in get_importance_trace
guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\pyro\poutine\trace_messenger.py”, line 169, in get_trace
self(*args, **kwargs)
File “C:\Users\beldaz\Anaconda3\lib\site-packages\pyro\poutine\trace_messenger.py”, line 147, in call
ret = self.fn(*args, **kwargs)
File “”, line 27, in scale_parametrized_guide
a = pyro.param(“a”, torch.tensor(guess))
File “C:\Users\beldaz\Anaconda3\lib\warnings.py”, line 99, in showwarnmsg
msg.file, msg.line)
File “”, line 15, in warn_with_traceback
traceback.print_stack(file=log)
C:\Users\beldaz\Anaconda3\lib\site-packages\ipykernel_launcher.py:27: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad(True), rather than torch.tensor(sourceTensor).