Device mismatch in AutoNormalizingFlow

I am trying to use AutoNormalizingFlow for one subset of my parameters, with the others handled by several other AutoGuides, all wrapped in an AutoGuideList. My data all live on the GPU, and I have set torch.set_default_device("cuda"). Additionally, I have manually moved the guide to the GPU via guide.to("cuda").

When I actually run my model, I get a "two devices" RuntimeError:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[463], line 7
      4 gc.collect()
      5 torch.cuda.empty_cache()
----> 7 fit2_resx = run_x(a, b, b2, R, c, 
      8                      beta1=0.999, tol=1e-9, num_iter=500, lr0=5e-3)

Cell In[462], line 56, in fit_x(a, b, b2, c, R, num_iter, lr0, gamma, beta1, beta2, bh, verbose, tol)
     53 p = a.shape[0]
---> 56 arn = AutoRegressiveNN(p, [p])
     57 arn.to('cuda')

File ~/mambaforge/envs/x/lib/python3.11/site-packages/pyro/nn/auto_reg_nn.py:338, in AutoRegressiveNN.__init__(self, input_dim, hidden_dims, param_dims, permutation, skip_connections, nonlinearity)
    329 def __init__(
    330     self,
    331     input_dim,
   (...)
    336     nonlinearity=nn.ReLU(),
    337 ):
--> 338     super(AutoRegressiveNN, self).__init__(
    339         input_dim,
    340         0,
    341         hidden_dims,
    342         param_dims=param_dims,
    343         permutation=permutation,
    344         skip_connections=skip_connections,
    345         nonlinearity=nonlinearity,
    346     )

File ~/mambaforge/envs/x/lib/python3.11/site-packages/pyro/nn/auto_reg_nn.py:212, in ConditionalAutoRegressiveNN.__init__(self, input_dim, context_dim, hidden_dims, param_dims, permutation, skip_connections, nonlinearity)
    209 self.register_buffer("permutation", P)
    211 # Create masks
--> 212 self.masks, self.mask_skip = create_mask(
    213     input_dim=input_dim,
    214     context_dim=context_dim,
    215     hidden_dims=hidden_dims,
    216     permutation=self.permutation,
    217     output_dim_multiplier=self.output_multiplier,
    218 )
    220 # Create masked layers
    221 layers = [MaskedLinear(input_dim + context_dim, hidden_dims[0], self.masks[0])]

File ~/mambaforge/envs/x/lib/python3.11/site-packages/pyro/nn/auto_reg_nn.py:77, in create_mask(input_dim, context_dim, hidden_dims, permutation, output_dim_multiplier)
     71 mask_skip = (output_indices.unsqueeze(-1) > input_indices.unsqueeze(0)).type_as(
     72     var_index
     73 )
     75 # Create mask from input to first hidden layer, and between subsequent hidden layers
     76 masks = [
---> 77     (hidden_indices[0].unsqueeze(-1) >= input_indices.unsqueeze(0)).type_as(
     78         var_index
     79     )
     80 ]
     81 for i in range(1, len(hidden_dims)):
     82     masks.append(
     83         (
     84             hidden_indices[i].unsqueeze(-1) >= hidden_indices[i - 1].unsqueeze(0)
     85         ).type_as(var_index)
     86     )

File ~/mambaforge/envs/x/lib/python3.11/site-packages/torch/utils/_device.py:77, in DeviceContext.__torch_function__(self, func, types, args, kwargs)
     75 if func in _device_constructors() and kwargs.get('device') is None:
     76     kwargs['device'] = self.device
---> 77 return func(*args, **kwargs)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

The guide creation is

guide = AutoGuideList(model)
guide.append(AutoGaussian(poutine.block(
    model, expose=['z'], hide=['w']
)))
p = w.shape[0]
arn = AutoRegressiveNN(p, [p])
arn.to('cuda')
transform = AffineAutoregressive(arn)
transform.to('cuda')
guide.append(AutoNormalizingFlow(poutine.block(
    model, expose=['w'], hide=['z']
), transform))
guide.to('cuda')
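As a sanity check, an ordinary torch module built under a non-CPU default device does land on that device, so the mismatch seems specific to how AutoRegressiveNN constructs its internals. A minimal sketch of that check, using the "meta" device in place of "cuda" so it runs without a GPU:

```python
import torch
import torch.nn as nn

# A plain torch module honors the default device set by the device context,
# so the mixed-device failure seems specific to AutoRegressiveNN's internals.
# The "meta" device stands in for "cuda" here so no GPU is needed.
with torch.device("meta"):
    layer = nn.Linear(4, 4)

print(layer.weight.device.type)  # meta
```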

I would greatly appreciate advice on how to debug this error; I hope I am simply doing something wrong that is easy to fix.

That was kind of a deep dive. The error is actually occurring in pyro/nn/auto_reg_nn.py on the dev branch of pyro-ppl/pyro on GitHub. Specifically, in sample_mask_indices() the line

    indices = torch.linspace(1, input_dim, steps=hidden_dim, device="cpu").to(
        torch.Tensor().device
    )

will only ever create the indices on the CPU: the legacy torch.Tensor() constructor is not routed through the default-device machinery, so torch.Tensor().device is cpu even after torch.set_default_device("cuda"), which causes problems if, like me, you have set the default device to be the GPU. Replacing torch.Tensor().device with torch.tensor(0.).device in this case yields the desired behavior. I will submit a PR later today.
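The underlying behavior can be demonstrated without Pyro. A minimal sketch, again using the "meta" device in place of "cuda" so no GPU is needed: the legacy torch.Tensor() constructor ignores the default device, while factory functions such as torch.tensor() honor it.

```python
import torch

# Under a non-CPU default device ("meta" stands in for "cuda"), the legacy
# torch.Tensor() constructor still allocates on the CPU, while factory
# functions such as torch.tensor() are routed through the default-device
# context. This is why `.to(torch.Tensor().device)` pins the linspace
# result to the CPU.
with torch.device("meta"):
    legacy = torch.Tensor()      # bypasses the default-device context
    factory = torch.tensor(0.0)  # honors the default-device context

print(legacy.device.type)   # cpu
print(factory.device.type)  # meta
```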