Running Pyro.plate on GPU

I am having a problem running my model on GPU; it already runs perfectly on CPU.


def forward(self, x, y=None):
    if y is not None:
        y = F.one_hot(y, 2)
    prob = F.softmax(x, dim=-1)
    with pyro.plate("data", self.len_train, subsample=x, device=device):
        p = pyro.sample("p", dist.LogisticNormal(prob[:, -1].reshape(-1, 1), 1))
        obs = pyro.sample("obs", dist.Bernoulli(p[:, 0].reshape(-1, 1)).to_event(1), obs=y.float() if y is not None else None)
    return obs

When I run on GPU I get the following error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_index_select)

Note that the model runs perfectly on GPU when I remove self.len_train from pyro.plate; that argument tells Pyro that the input batch is a subsample of a longer dataset of length len(train).
What could be the problem?
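For reference, this is the subsampling pattern I’m trying to follow, stripped down to a minimal sketch (len_train and the toy batch below are stand-ins for my real values):

import torch
import pyro
import pyro.distributions as dist

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
len_train = 1000                        # size of the full training set
x = torch.randn(32, 2, device=device)   # one mini-batch, already on the GPU

# plate("data", len_train, subsample=x) declares that x is a subsample of a
# dataset of length len_train, so Pyro rescales log-probabilities accordingly
with pyro.plate("data", len_train, subsample=x, device=device):
    p = pyro.sample("p", dist.Normal(x[:, 0], 1.0))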

I’m having a similar runtime error (I don’t really know how to use the GPU properly when coding with Pyro):

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

Unfortunately, removing the argument from plate still raises the same error. I also set

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)

This causes a shape mismatch instead:

---> 47 pyro.render_model(model, model_args=(data,), render_params=True)

File c:\Users\hyf\.conda\envs\torch\lib\site-packages\pyro\infer\inspect.py:584, in render_model(model, model_args, model_kwargs, filename, render_distributions, render_params)
    582 # Get model relations.
    583 if not isinstance(model_args, list) and not isinstance(model_kwargs, list):
--> 584     relations = [get_model_relations(model, model_args, model_kwargs)]
    585 else:  # semisupervised
    586     if isinstance(model_args, list):

File c:\Users\hyf\.conda\envs\torch\lib\site-packages\pyro\infer\inspect.py:305, in get_model_relations(model, model_args, model_kwargs)
    300 if site["type"] != "sample" or site_is_subsample(site):
    301     continue
    303 sample_sample[name] = [
    304     upstream
--> 305     for upstream in get_provenance(site["fn"].log_prob(site["value"]))
    306     if upstream != name and _get_type_from_frozenname(upstream) == "sample"
    307 ]
    309 sample_param[name] = [
    310     upstream
    311     for upstream in get_provenance(site["fn"].log_prob(site["value"]))
    312     if upstream != name and _get_type_from_frozenname(upstream) == "param"
    313 ]
    315 sample_dist[name] = _get_dist_name(site["fn"])

File c:\Users\hyf\.conda\envs\torch\lib\site-packages\torch\distributions\independent.py:99, in Independent.log_prob(self, value)
     98 def log_prob(self, value):
---> 99     log_prob = self.base_dist.log_prob(value)
    100     return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims)

File c:\Users\hyf\.conda\envs\torch\lib\site-packages\torch\distributions\gamma.py:76, in Gamma.log_prob(self, value)
     73 if self._validate_args:
     74     self._validate_sample(value)
     75 return (torch.xlogy(self.concentration, self.rate) +
---> 76         torch.xlogy(self.concentration - 1, value) -
     77         self.rate * value - torch.lgamma(self.concentration))

File c:\Users\hyf\.conda\envs\torch\lib\site-packages\torch\utils\_device.py:62, in DeviceContext.__torch_function__(self, func, types, args, kwargs)
     60 if func in _device_constructors() and kwargs.get('device') is None:
     61     kwargs['device'] = self.device
---> 62 return func(*args, **kwargs)
RuntimeError: The size of tensor a (2) must match the size of tensor b (0) at non-singleton dimension 1

It still doesn’t work, and I’m wondering what’s going on. How should I use Pyro properly with a GPU?

Here is my model.

@config_enumerate 
def model_lkj(data=None, alpha=1.0, T=50, batch_size=100):
    '''
    Mixture model truncated at T components,
    with component covariance being a Wishart(df=d+2, scale=a*I);
    a: multiplicative constant in the variance
    '''
    
    alpha_w = pyro.param("alpha_w", lambda: Gamma(1, 1).sample([1]), constraint=constraints.positive)
    with pyro.plate("sticks", T-1, device=device):
        beta = pyro.sample("beta", Beta(1, alpha))

    with pyro.plate("component", T, device=device):
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(d), torch.eye(d)))
        theta = pyro.sample("theta", Chi2(df=torch.ones(d)*(d+2)).to_event(1))
        omega = pyro.sample('omega', LKJCholesky(d, concentration=1.5))
        Omega = torch.bmm(((1/alpha_w)*theta.sqrt()).diag_embed(), omega)
        tril = pyro.deterministic("tril", torch.tril(torch.linalg.inv(Omega)), event_dim=2)

    with pyro.plate("data", N, subsample_size=batch_size, device=device) as idx:
        z = pyro.sample("z", Categorical(mix_weights(beta)))  # alternatively: infer={'enumerate': 'parallel'}
        pyro.sample("obs", MultivariateNormal(mu[z], scale_tril=tril[z]), obs=data.index_select(0, idx))

model = model_lkj
pyro.render_model(model, model_args=(data,), render_params=True)

Hi @Maryam_Osama. Are x and y on the CUDA device? How do you move them to CUDA? Can you show which line of your code raises this error, or give reproducible code?
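For example, the usual pattern is something like this (just a sketch with toy tensors; adapt it to wherever your x and y come from):

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# toy stand-ins for your actual batch and labels
x = torch.randn(32, 2)
y = torch.randint(0, 2, (32,))

x = x.to(device)                  # inputs on the same device as the parameters
y = y.to(device)
net = nn.Linear(2, 2).to(device)  # move the module's parameters too
out = net(x)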

@Evan can you provide a full reproducible code please?

This may have been fixed in this PR: Bug fix about device of subsample in subsample_messenger.py by hjnnjh · Pull Request #3195 · pyro-ppl/pyro · GitHub

Sure! Here comes the code:

Device setup

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)
device

Data generation

def load_dim_train(n=2, w1=0.01, w2=1, plot=False):
    '''
    load demo data of 2d
    training with different number of clusters
    indicate development trend from overall-similarity to dimensional-identity
    ---
    params:
        n: int in {2,4,8}, number of clusters along each dimension
        w1, w2: float (w1 < w2), dimension weights
        plot: bool, whether to show a 2d plot of the data
        plotting code removed here as it is not relevant
    '''
    cov_diag1 = torch.ones(2) * w1
    cov_diag2 = torch.ones(2) * w1
    cov_diag1[0] = w2
    cov_diag2[-1] = w2
    cov1 = torch.diag(cov_diag1) # -
    cov2 = torch.diag(cov_diag2) # |

    if n == 2:
        loc1 = -5.0
        loc2 = 5.0
        data = torch.cat((
            MultivariateNormal(torch.tensor([0, loc1]), cov1).sample([100]),
            MultivariateNormal(torch.tensor([0, loc2]), cov1).sample([100]),
            MultivariateNormal(torch.tensor([loc1, 0]), cov2).sample([100]),
            MultivariateNormal(torch.tensor([loc2, 0]), cov2).sample([100]),
        ))
    return data

data = load_dim_train()

Model

T=50
@config_enumerate 
def model_lkj(data=None, alpha=1.0, batch_size=100):
    '''
    Mixture model truncated at T components,
    with component covariance being a Wishart(df=d+2, scale=a*I);
    a: multiplicative constant in the variance
    '''
    
    alpha_w = pyro.param("alpha_w", lambda: Gamma(1, 1).sample([1]), constraint=constraints.positive)
    with pyro.plate("sticks", T-1, device=device):
        beta = pyro.sample("beta", Beta(1, alpha))

    with pyro.plate("component", T, device=device):
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(d), torch.eye(d)))
        theta = pyro.sample("theta", Chi2(df=torch.ones(d)*(d+2)).to_event(1))
        omega = pyro.sample('omega', LKJCholesky(d, concentration=1.5))
        Omega = torch.bmm(((1/alpha_w)*theta.sqrt()).diag_embed(), omega)
        tril = pyro.deterministic("tril", torch.tril(torch.linalg.inv(Omega)), event_dim=2)

    with pyro.plate("data", N, subsample_size=batch_size, device=device) as idx:
        z = pyro.sample("z", Categorical(mix_weights(beta)))  # alternatively: infer={'enumerate': 'parallel'}
        pyro.sample("obs", MultivariateNormal(mu[z], scale_tril=tril[z]), obs=data.index_select(0, idx))

model = model_lkj
pyro.render_model(model, model_args=(data,), render_params=True)

I’m not so sure whether my problem is related to the bug mentioned here, but maybe I should update my Pyro? I installed it a month ago and this was merged some two weeks ago lol :rofl:

@Evan can you provide the full code? Some imports are missing and d is not defined.

Try to install Pyro from its dev branch, in which my PR has been merged. :wink:
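For example, with pip you can install straight from the GitHub dev branch (assuming a pip-based environment):

pip install git+https://github.com/pyro-ppl/pyro.git@dev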

Sorry for the late reply! Here’s my code

import pickle

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

import matplotlib.pyplot as plt
import logging
from tqdm import tqdm
from pprint import pprint

import pyro
from pyro.distributions import *
import pyro.distributions.constraints as constraints
from pyro.infer import Predictive, SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete
from pyro.contrib.easyguide import *
from pyro.optim import Adam, ClippedAdam
import pyro.poutine as poutine
from pyro.infer.mcmc import NUTS, HMC
from pyro.infer.mcmc.api import MCMC


# local scripts
from data import *

RANDOM_SEED=42
pyro.set_rng_seed(RANDOM_SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)

def load_dim_train(n=2, w1=0.01, w2=1, plot=False):
    '''
    load demo data of 2d
    training with different number of clusters
    indicate development trend from overall-similarity to dimensional-identity
    ---
    params:
        n: int in {2,4,8}, number of clusters along each dimension
        w1, w2: float (w1 < w2), dimension weights
        plot: bool, whether to show a 2d plot of the data
        plotting code removed here as it is not relevant
    '''
    cov_diag1 = torch.ones(2) * w1
    cov_diag2 = torch.ones(2) * w1
    cov_diag1[0] = w2
    cov_diag2[-1] = w2
    cov1 = torch.diag(cov_diag1) # -
    cov2 = torch.diag(cov_diag2) # |

    if n == 2:
        loc1 = -5.0
        loc2 = 5.0
        data = torch.cat((
            MultivariateNormal(torch.tensor([0, loc1]), cov1).sample([100]),
            MultivariateNormal(torch.tensor([0, loc2]), cov1).sample([100]),
            MultivariateNormal(torch.tensor([loc1, 0]), cov2).sample([100]),
            MultivariateNormal(torch.tensor([loc2, 0]), cov2).sample([100]),
        ))
    return data
data = load_dim_train()
data = data.to(device)
# pre processing
N = data.shape[0]  # sample size
d = data.shape[1]  # data dimension

# DPMM Model 
T_init = 50

def mix_weights(beta):
    # stick-breaking construction: w_k = beta_k * prod_{j<k} (1 - beta_j),
    # with the final weight absorbing the remaining stick mass
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)

@config_enumerate 
def model_lkj(data=None, alpha=1.0, T=T_init, batch_size=100):
    '''
    Mixture model truncated at T components,
    with component covariance being a Wishart(df=d+2, scale=a*I);
    a: multiplicative constant in the variance
    '''
    
    alpha_w = pyro.param("alpha_w", lambda: Gamma(1, 1).sample([1]).to(device), constraint=constraints.positive)
    with pyro.plate("sticks", T-1, device=device):
        beta = pyro.sample("beta", Beta(1, alpha))

    with pyro.plate("component", T, device=device):
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(d), torch.eye(d)))
        theta = pyro.sample("theta", Chi2(df=torch.ones(d)*(d+2)).to_event(1))
        omega = pyro.sample('omega', LKJCholesky(d, concentration=1.5))
        Omega = torch.bmm(((1/alpha_w)*theta.sqrt()).diag_embed(), omega)
        tril = pyro.deterministic("tril", torch.tril(torch.linalg.inv(Omega)), event_dim=2)

    with pyro.plate("data", N, subsample_size=batch_size, device=device) as idx:
        z = pyro.sample("z", Categorical(mix_weights(beta)))  # alternatively: infer={'enumerate': 'parallel'}
        #pyro.sample("obs", MultivariateNormal(mu[z], scale_tril=Omega[z]), obs=data)
        pyro.sample("obs", MultivariateNormal(mu[z], scale_tril=tril[z]), obs=data.index_select(0, idx))

model = model_lkj
pyro.render_model(model, model_args=(data,), render_params=True)

Let me have a try :smiley: thanks!

@Evan it seems like there is a bug related to using torch.set_default_device(torch.device("cuda")) and ProvenanceTensor. I opened an issue about it: `ProvenanceTensor` bug when used with `torch.set_default_device` · Issue #3218 · pyro-ppl/pyro · GitHub.

Meanwhile, can you try to run your code on CPU?
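In the meantime, one way to sidestep the default-device issue is to drop torch.set_default_device and put everything on the device explicitly; a minimal sketch of the pattern (the toy model here is made up, but your model would work the same way):

import torch
import pyro
import pyro.distributions as dist

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# create tensors on the target device explicitly and pass device= to each plate
data = torch.randn(400, 2).to(device)

def toy_model(data):
    loc = torch.zeros(2, device=device)
    cov = torch.eye(2, device=device)
    with pyro.plate("data", data.shape[0], device=device):
        pyro.sample("obs", dist.MultivariateNormal(loc, cov), obs=data)

pyro.render_model(toy_model, model_args=(data,))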

Thanks!
And yes, I tried CPU first and it worked well (on a toy dataset). However, as I’m trying to scale up to larger datasets, I thought using the GPU would help speed up the tensor calculations. (I checked and found that most of the time is spent checking the positive-definiteness of matrices.)

Once you’ve got your model set up and are confident it’s doing what you want, you can get speed-ups by disabling validation, which skips some of these (possibly expensive) checks:

pyro.enable_validation(False)
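For instance, a toy sketch of where the toggle goes in a training loop (the model and guide here are made up just for illustration):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(data):
    loc = pyro.param("loc", torch.zeros(1))
    with pyro.plate("data", data.shape[0]):
        pyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

def guide(data):
    pass  # no latent variables in this toy model

data = torch.randn(100)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=Trace_ELBO())

pyro.enable_validation(True)    # keep the checks on while debugging
svi.step(data)                  # smoke-test one step
pyro.enable_validation(False)   # then skip the checks for the long run
for _ in range(1000):
    svi.step(data)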


@Evan the newest Pyro 1.8.5 release has this issue fixed.

Oh great, thanks for the tip!

Hooray! Thanks a lot :partying_face: