Hi, from previous posts, setting some of the priors to tensors on cuda should allow me to run pyro on gpu. but doesn’t seem to work for me. could you pl shelp?
import pyro
from pyro.nn import PyroSample
import torch
from torch import nn
from pyro.nn import PyroModule
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import SVI, Trace_ELBO
import pyro.distributions as dist
import torch.nn.functional as F
import torch.multiprocessing as mp
from torch.utils.data.sampler import Sampler
class BayesianRegression(PyroModule):
def __init__(self, in_features, out_features):
super().__init__()
self.linear = PyroModule[nn.Linear](in_features, out_features)
self.linear.weight = PyroSample(dist.Normal(torch.tensor(0.0, device="cuda"), torch.tensor(1.0)).expand([out_features, in_features]).to_event(2))
self.linear.bias = PyroSample(dist.Normal(torch.tensor(0.0, device="cuda"), 10.).expand([out_features]).to_event(1))
def forward(self, x, y=None):
sigma = pyro.sample("sigma", dist.Uniform(torch.tensor(0., device = "cuda"), 10.))
mean = self.linear(x).squeeze(-1).cuda()
with pyro.plate("data", x.shape[0]):
obs = pyro.sample("obs", dist.Normal(mean.cuda(), sigma), obs=y)
return mean
#make x,y dataset
def make_data(N, P, mu, sd):
X = dist.Normal(0.0, 1.0).sample((N,P))
B = dist.Normal(mu, sd).sample((P,)) #define betas
eps = dist.Normal(0.0, 1.0).sample((N,))
Y = X @ B + eps
print(f"X.shape= {X.shape}; Y.shape{Y.shape}")
return X, Y
def train(model, guide, X, Y, lr=0.05, n_steps=201):
adam_params = {"lr": lr}
adam = pyro.optim.Adam(adam_params)
svi = SVI(model, guide, adam, loss=Trace_ELBO())
pyro.clear_param_store()
for j in range(n_steps):
# calculate the loss and take a gradient step
loss = svi.step(X, Y)
if j % 100 == 0:
print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(Y)))
if __name__ == "__main__":
device = torch.device("cuda")
X,Y = make_data(N=6, P=3, mu=0.0, sd=1.0)
X,Y = X.to(device), Y.to(device)
# #model
model = BayesianRegression(X.shape[1], 1)
print(f"\nmodel: {model}\n")
# guide
guide = AutoDiagonalNormal(model)
print(f"\guide: {guide}\n")
#inference with SVI
train(model, guide, X, Y)
error
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/infer/autoguide/guides.py", line 759, in forward
self._setup_prototype(*args, **kwargs)
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/infer/autoguide/guides.py", line 938, in _setup_prototype
super()._setup_prototype(*args, **kwargs)
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/infer/autoguide/guides.py", line 636, in _setup_prototype
super()._setup_prototype(*args, **kwargs)
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/infer/autoguide/guides.py", line 157, in _setup_prototype
self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 32, in _context_wrap
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py", line 216, in get_trace
self(*args, **kwargs)
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py", line 198, in __call__
raise exc from e
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py", line 191, in __call__
ret = self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 32, in _context_wrap
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 32, in _context_wrap
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/nn/module.py", line 450, in __call__
result = super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/data/greenbaum/users/ahunos/projects/Methyl2Expression/scripts/models/main_pyro.py", line 23, in forward
mean = self.linear(x).squeeze(-1).cuda()
^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/nn/module.py", line 450, in __call__
result = super().__call__(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 116, in forward
return F.linear(input, self.weight, self.bias)
^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/nn/module.py", line 533, in __getattr__
value = pyro.sample(fullname, prior)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/primitives.py", line 189, in sample
apply_stack(msg)
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/runtime.py", line 378, in apply_stack
frame._process_message(msg)
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 189, in _process_message
method(msg)
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/infer/autoguide/initialization.py", line 237, in _pyro_sample
value = self.init_fn(msg)
^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/infer/autoguide/initialization.py", line 98, in init_to_median
return fallback(site)
^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/pyro/infer/autoguide/initialization.py", line 43, in init_to_feasible
value = site["fn"].sample().detach()
^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/distributions/independent.py", line 101, in sample
return self.base_dist.sample(sample_shape)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ahunos/.conda/envs/methyl_ONT/lib/python3.11/site-packages/torch/distributions/normal.py", line 70, in sample
return torch.normal(self.loc.expand(shape), self.scale.expand(shape))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument std in method wrapper_CUDA_Tensor_Tensor_normal)
Trace Shapes:
Param Sites:
Sample Sites:
sigma dist |
value |
Trace Shapes:
Param Sites:
Sample Sites: