How to set device for Pyro models?

My guess is that you can build a dictionary like the one below from the input tensor:

options = dict(dtype=input.dtype, device=input.device)

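and then unpack it into the tensor constructors inside the model function. The distribution parameters, and therefore the samples drawn by the sample statement, then share the input's dtype and device: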

  prior_loc = torch.zeros(batch_size, dim1, **options)
  prior_scale = torch.ones(batch_size, dim1, **options)
  zs = pyro.sample("z", dist.Normal(prior_loc, prior_scale).to_event(1))