I’m having a similar runtime error (I don’t really know how to use GPU properly when coding with pyro):
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
Unfortunately, removing the `device` argument from `pyro.plate`
still raises the same error. I also set
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)
Doing this instead causes a shape-mismatch error:
---> 47 pyro.render_model(model, model_args=(data,), render_params=True)
File c:\Users\hyf\.conda\envs\torch\lib\site-packages\pyro\infer\inspect.py:584, in render_model(model, model_args, model_kwargs, filename, render_distributions, render_params)
582 # Get model relations.
583 if not isinstance(model_args, list) and not isinstance(model_kwargs, list):
--> 584 relations = [get_model_relations(model, model_args, model_kwargs)]
585 else: # semisupervised
586 if isinstance(model_args, list):
File c:\Users\hyf\.conda\envs\torch\lib\site-packages\pyro\infer\inspect.py:305, in get_model_relations(model, model_args, model_kwargs)
300 if site["type"] != "sample" or site_is_subsample(site):
301 continue
303 sample_sample[name] = [
304 upstream
--> 305 for upstream in get_provenance(site["fn"].log_prob(site["value"]))
306 if upstream != name and _get_type_from_frozenname(upstream) == "sample"
307 ]
309 sample_param[name] = [
310 upstream
311 for upstream in get_provenance(site["fn"].log_prob(site["value"]))
312 if upstream != name and _get_type_from_frozenname(upstream) == "param"
313 ]
315 sample_dist[name] = _get_dist_name(site["fn"])
File c:\Users\hyf\.conda\envs\torch\lib\site-packages\torch\distributions\independent.py:99, in Independent.log_prob(self, value)
98 def log_prob(self, value):
---> 99 log_prob = self.base_dist.log_prob(value)
100 return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)
File c:\Users\hyf\.conda\envs\torch\lib\site-packages\torch\distributions\gamma.py:76, in Gamma.log_prob(self, value)
73 if self._validate_args:
74 self._validate_sample(value)
75 return (torch.xlogy(self.concentration, self.rate) +
---> 76 torch.xlogy(self.concentration - 1, value) -
77 self.rate * value - torch.lgamma(self.concentration))
File c:\Users\hyf\.conda\envs\torch\lib\site-packages\torch\utils\_device.py:62, in DeviceContext.__torch_function__(self, func, types, args, kwargs)
60 if func in _device_constructors() and kwargs.get('device') is None:
61 kwargs['device'] = self.device
---> 62 return func(*args, **kwargs)
RuntimeError: The size of tensor a (2) must match the size of tensor b (0) at non-singleton dimension 1
It still doesn’t work, and I’m wondering what’s going on. How should I use Pyro properly with a GPU?
Here is my model.
@config_enumerate
def model_lkj(data=None, alpha=1.0, T=50, batch_size=100):
    """Truncated stick-breaking mixture model with T components.

    Each component has mean ``mu`` and a covariance built from a
    Chi2 (df = d + 2) / LKJ-Cholesky decomposition, globally scaled by the
    learnable positive constant ``alpha_w``.

    Args:
        data: (N, d) tensor of observations. NOTE(review): it must already
            live on ``device`` for ``index_select`` with a CUDA ``idx`` to
            work -- move it once with ``data.to(device)`` before inference.
        alpha: stick-breaking (Beta) concentration.
        T: truncation level (number of mixture components).
        batch_size: minibatch size for the data plate.

    Relies on module-level names: ``device``, ``d``, ``N``, ``mix_weights``.
    """
    # Bug fix: every tensor handed to a distribution must be created on
    # `device`. Plain Python scalars (Gamma(1, 1), Beta(1, alpha),
    # concentration=1.5) and device-less factories (torch.zeros(d), ...)
    # produce CPU tensors, which then mix with CUDA plate samples and
    # raise "Expected all tensors to be on the same device".
    one = torch.ones(1, device=device)
    alpha_w = pyro.param(
        "alpha_w",
        lambda: Gamma(one, one).sample(),  # shape (1,), same as before
        constraint=constraints.positive,
    )
    with pyro.plate("sticks", T - 1, device=device):
        beta = pyro.sample(
            "beta",
            Beta(torch.tensor(1.0, device=device),
                 torch.tensor(alpha, device=device)),
        )
    with pyro.plate("component", T, device=device):
        mu = pyro.sample(
            "mu",
            MultivariateNormal(torch.zeros(d, device=device),
                               torch.eye(d, device=device)),
        )
        theta = pyro.sample(
            "theta",
            Chi2(df=torch.ones(d, device=device) * (d + 2)).to_event(1),
        )
        omega = pyro.sample(
            "omega",
            LKJCholesky(d, concentration=torch.tensor(1.5, device=device)),
        )
        # Precision Cholesky factor per component; invert (and re-lower-
        # triangularize) to get the covariance scale_tril the likelihood uses.
        Omega = torch.bmm(((1 / alpha_w) * theta.sqrt()).diag_embed(), omega)
        tril = pyro.deterministic(
            "tril", torch.tril(torch.linalg.inv(Omega)), event_dim=2
        )
    with pyro.plate("data", N, subsample_size=batch_size, device=device) as idx:
        # Component assignment; enumerated in parallel via @config_enumerate.
        z = pyro.sample("z", Categorical(mix_weights(beta)))
        pyro.sample(
            "obs",
            MultivariateNormal(mu[z], scale_tril=tril[z]),
            obs=data.index_select(0, idx),
        )
# Alias the model and render its graphical structure (with learnable params).
# render_model traces the model once with (data,), so any device/shape
# mismatch inside model_lkj surfaces at this call.
model = model_lkj
pyro.render_model(model, model_args=(data,), render_params=True)