How to set device for Pyro models?

Hello,

I am trying to convert a PyTorch neural network model into a Bayesian Pyro model.
I want to avoid the StopIteration error (see below) that arises when I run next(myPyTorchModel.parameters()).device, where myPyTorchModel is my model after conversion to a Pyro model.

To be more specific, before converting my PyTorch model into a Pyro model, when I execute the command next(myPyTorchModel.parameters()).device, I get the following output:

device(type='cpu')

However, when I run the same command after converting my PyTorch model into a Pyro model as follows, I get a StopIteration error:


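import pyro.distributions as dist
from pyro.nn import module

# Convert the network in place so that it becomes a PyroModule
module.to_pyro_module_(myPyTorchModel)

# Replace every nn.Parameter with a standard Normal prior of the same shape
for m in myPyTorchModel.modules():
    for name, value in list(m.named_parameters(recurse=False)):
        setattr(m, name, module.PyroSample(prior=dist.Normal(0, 1)
                                           .expand(value.shape)
                                           .to_event(value.dim())))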


next(myPyTorchModel.parameters()).device

Output:
  File "<ipython-input-9-be50535fd794>", line 1, in <module>
    next(myPyTorchModel.parameters()).device
StopIteration

How can I prevent this StopIteration error with Pyro models?
Thank you,

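My guess is that you can build a dictionary of tensor options from your input data, like this:

options = dict(dtype=input.dtype, device=input.device)

and pass it to the tensor constructors that create the prior parameters for the sample statements inside the model function: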

  prior_loc = torch.zeros(batch_size, dim1, **options)
  prior_scale = torch.ones(batch_size, dim1, **options)
  zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))

Hello,

Thank you very much for your reply.
Could you be more specific about your code? I am new to both Python and Pyro, so I am not sure how to integrate your answer into my code.

What are input, batch_size, and dim1 in your code?

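In your case, pyro.sample("w", dist.Normal(0, 1)) would return a single scalar value. In the example I've given, it returns a batch of sampled vectors: batch_size rows, each a vector of length dim1. If you want a single value, you can set batch_size = 1 and dim1 = 1, which gives a 1-by-1 tensor holding one sample.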