Pyro not working on GPU even after setting cuda tensor priors

Hi, from previous posts, setting some of the priors to tensors on cuda should allow me to run pyro on gpu. but doesn’t seem to work for me. could you pl shelp?

import pyro
from pyro.nn import PyroSample
import torch
from torch import nn
from pyro.nn import PyroModule
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist

import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.utils.data.sampler import Sampler

class BayesianRegression(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(torch.tensor(0.0, device="cuda"), torch.tensor(1.0)).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(torch.tensor(0.0, device="cuda"), 10.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(torch.tensor(0., device = "cuda"), 10.))
        mean = self.linear(x).squeeze(-1).cuda()
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean.cuda(), sigma), obs=y)
        return mean
#make x,y dataset
def make_data(N, P, mu, sd):
    X = dist.Normal(0.0, 1.0).sample((N,P))
    B = dist.Normal(mu, sd).sample((P,)) #define betas
    eps = dist.Normal(0.0, 1.0).sample((N,))
    Y = X @ B + eps
    print(f"X.shape= {X.shape}; Y.shape{Y.shape}")
    return X, Y

def train(model, guide, X, Y, lr=0.05, n_steps=201):
    adam_params = {"lr": lr}
    adam = pyro.optim.Adam(adam_params)
    svi = SVI(model, guide, adam, loss=Trace_ELBO())
    pyro.clear_param_store()
    for j in range(n_steps):
        # calculate the loss and take a gradient step
        loss = svi.step(X, Y)
        if j % 100 == 0:
            print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(Y)))

if __name__ == "__main__":
    device = torch.device("cuda")
    X,Y = make_data(N=6, P=3, mu=0.0, sd=1.0)
    X,Y = X.to(device), Y.to(device)

    # #model
    model = BayesianRegression(X.shape[1], 1)
    print(f"\nmodel: {model}\n")
    # guide
    guide = AutoDiagonalNormal(model)
    print(f"\guide: {guide}\n")
    #inference with SVI
    train(model, guide, X, Y)

error

             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/infer/autoguide/guides.py", line 759, in forward
    self._setup_prototype(*args, **kwargs)
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/infer/autoguide/guides.py", line 938, in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/infer/autoguide/guides.py", line 636, in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/infer/autoguide/guides.py", line 157, in _setup_prototype
    self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 32, in _context_wrap
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py", line 216, in get_trace
    self(*args, **kwargs)
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py", line 198, in __call__
    raise exc from e
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py", line 191, in __call__
    ret = self.fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 32, in _context_wrap
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 32, in _context_wrap
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/nn/module.py", line 450, in __call__
    result = super().__call__(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/greenbaum/users/ahunos/projects/Methyl2Expression/scripts/models/main_pyro.py", line 23, in forward
    mean = self.linear(x).squeeze(-1).cuda()
           ^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/nn/module.py", line 450, in __call__
    result = super().__call__(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
                           ^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/nn/module.py", line 533, in __getattr__
    value = pyro.sample(fullname, prior)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/primitives.py", line 189, in sample
    apply_stack(msg)
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/runtime.py", line 378, in apply_stack
    frame._process_message(msg)
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 189, in _process_message
    method(msg)
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/infer/autoguide/initialization.py", line 237, in _pyro_sample
    value = self.init_fn(msg)
            ^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/infer/autoguide/initialization.py", line 98, in init_to_median
    return fallback(site)
           ^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/infer/autoguide/initialization.py", line 43, in init_to_feasible
    value = site["fn"].sample().detach()
            ^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/distributions/independent.py", line 101, in sample
    return self.base_dist.sample(sample_shape)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/distributions/normal.py", line 70, in sample
    return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument std in method wrapper_CUDA_Tensor_Tensor_normal)
Trace Shapes:  
 Param Sites:  
Sample Sites:  
   sigma dist |
        value |
Trace Shapes:
 Param Sites:
Sample Sites:

I have had success with torch.set_default_device('cuda') as a way around this. Just set at top of script before creating new tensors.

thanks for the heads up!