How to set device for Pyro models?


I am trying to convert a PyTorch neural network into a Bayesian Pyro model.
I want to avoid the StopIteration error (see below) that arises when I execute next(myPyTorchModel.parameters()).device, where myPyTorchModel is the converted Pyro model.

To be more specific, before converting my PyTorch model into a Pyro model, when I execute the command next(myPyTorchModel.parameters()).device, I get the following output:


Whereas when I execute the same command after converting my PyTorch model into a Pyro model (using the loop below), I get the error output shown underneath:


for m in myPyTorchModel.modules():
    for name, value in list(m.named_parameters(recurse=False)):
        # replace each nn.Parameter with a PyroSample prior of the same shape
        setattr(m, name, PyroSample(prior=dist.Normal(0., 1.)
                                    .expand(value.shape)
                                    .to_event(value.dim())))

OUTPUT:
  File "<ipython-input-9-be50535fd794>", line 1, in <module>
StopIteration


How can I prevent this StopIteration error with Pyro models?
Thank you,

My guess is that you can pass a dictionary like the one below

options = dict(dtype=input.dtype, device=input.device)

to the tensors used by the sample statements within the model function, like below:

  prior_loc = torch.zeros(batch_size, dim1, **options)
  prior_scale = torch.ones(batch_size, dim1, **options)
  zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))


Thank you very much for your reply.
Could you be more specific about your code? I am new to Python and also new to Pyro, so I am not sure how to integrate your answer into my code.

What are the input, batch_size, and dim1 in your code?

In your case, pyro.sample("z", dist.Normal(0, 1)) would return a single scalar value. In the example I've given, it returns a tensor of sampled elements. If you want a scalar, you can just set batch_size = 1 and dim1 = 1.