Problem with the Dirichlet process mixture tutorial on GPU

Hi all,

I am trying to run the tutorial on Dirichlet process mixture models on my GPU, but I keep getting errors about tensors not being on the same device. At this point I have no clue where I am going wrong. Here is my modified code for a data set of size N x M:


def mix_weights(beta):
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)


def model(data):
    with pyro.plate("beta_plate", T - 1, device=device):
        beta = pyro.sample("beta", Beta(1, alpha))

    with pyro.plate("mu_plate", T, device=device):
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(M, device=device), 5 * torch.eye(M, device=device)))

    with pyro.plate("data", N, device=device):
        z = pyro.sample("z", Categorical(mix_weights(beta)))
        pyro.sample("obs", MultivariateNormal(mu[z], torch.eye(M, device=device)), obs=data)

def guide(data):
    kappa = pyro.param('kappa', lambda: Uniform(torch.tensor(0.0, device=device), torch.tensor(2.0, device=device)).sample([T - 1]), constraint=constraints.positive)
    tau = pyro.param('tau', lambda: MultivariateNormal(torch.zeros(M, device=device), 3 * torch.eye(M, device=device)).sample([T]))
    phi = pyro.param('phi', lambda: Dirichlet(1 / T * torch.ones(T, device=device)).sample([N]), constraint=constraints.simplex)

    with pyro.plate("beta_plate", T - 1, device=device):
        q_beta = pyro.sample("beta", Beta(torch.ones(T - 1, device=device), kappa))

    with pyro.plate("mu_plate", T, device=device):
        q_mu = pyro.sample("mu", MultivariateNormal(tau, torch.eye(M, device=device)))

    with pyro.plate("data", N, device=device):
        z = pyro.sample("z", Categorical(phi))

T = 5
auto_guide = pyro.infer.autoguide.AutoNormal(pyro.poutine.block(model, hide=['z']))
auto_guide.to(device)
optim = Adam({"lr": 0.05})
svi = SVI(model, auto_guide, optim, loss=Trace_ELBO())
losses = []

def train(num_iterations, data):
    pyro.clear_param_store()
    for j in tqdm(range(num_iterations)):
        loss = svi.step(data)
        losses.append(loss)

def truncate(alpha, centers, weights):
    threshold = alpha**-1 / 100.
    true_centers = centers[weights > threshold]
    true_weights = weights[weights > threshold] / torch.sum(weights[weights > threshold])
    return true_centers, true_weights

alpha = torch.tensor([0.1], device=device)
train(1000, data)

I added an autoguide just to see whether the problem was in the guide or in the model, and it seems to be in the model. Any suggestions?

Thank you and best regards

Did you put the data variable on the GPU?
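
If you are not sure, a quick way to check (assuming data is a torch.Tensor):

print(data.device)  # should print cuda:0 if the tensor was actually moved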

Thank you for replying, @pavleb. Yes, I sent the data to the GPU.

Can you provide your complete script? Here is my MWE, based on your code and the data from the tutorial; I have no issues with the GPU:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.distributions import constraints

import pyro
from pyro.distributions import *
from pyro.infer import Predictive, SVI, Trace_ELBO
from pyro.optim import Adam

# assert pyro.__version__.startswith('1.8.5')
pyro.set_rng_seed(0)
device = torch.device("cuda:0")

data = torch.cat((MultivariateNormal(-8 * torch.ones(2), torch.eye(2)).sample([50]),
                  MultivariateNormal(8 * torch.ones(2), torch.eye(2)).sample([50]),
                  MultivariateNormal(torch.tensor([1.5, 2]), torch.eye(2)).sample([50]),
                  MultivariateNormal(torch.tensor([-0.5, 1]), torch.eye(2)).sample([50])))


data = data.to(device)

N = data.shape[0]
M = data.shape[1]
def mix_weights(beta):
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)


def model(data):
    with pyro.plate("beta_plate", T - 1, device=device):
        beta = pyro.sample("beta", Beta(1, alpha))

    with pyro.plate("mu_plate", T, device=device):
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(M, device=device), 5 * torch.eye(M, device=device)))

    with pyro.plate("data", N, device=device):
        z = pyro.sample("z", Categorical(mix_weights(beta)))
        pyro.sample("obs", MultivariateNormal(mu[z], torch.eye(M, device=device)), obs=data)

def guide(data):
    kappa = pyro.param('kappa', lambda: Uniform(torch.tensor(0.0, device=device), torch.tensor(2.0, device=device)).sample([T - 1]), constraint=constraints.positive)
    tau = pyro.param('tau', lambda: MultivariateNormal(torch.zeros(M, device=device), 3 * torch.eye(M, device=device)).sample([T]))
    phi = pyro.param('phi', lambda: Dirichlet(1 / T * torch.ones(T, device=device)).sample([N]), constraint=constraints.simplex)

    with pyro.plate("beta_plate", T - 1, device=device):
        q_beta = pyro.sample("beta", Beta(torch.ones(T - 1, device=device), kappa))

    with pyro.plate("mu_plate", T, device=device):
        q_mu = pyro.sample("mu", MultivariateNormal(tau, torch.eye(M, device=device)))

    with pyro.plate("data", N, device=device):
        z = pyro.sample("z", Categorical(phi))

T = 5
# auto_guide = pyro.infer.autoguide.AutoNormal(pyro.poutine.block(model, hide=['z']))
# auto_guide.to(device)
optim = Adam({"lr": 0.05})
svi = SVI(model, guide, optim, loss=Trace_ELBO())
losses = []

def train(num_iterations, data):
    pyro.clear_param_store()
    for j in tqdm(range(num_iterations)):
        loss = svi.step(data)
        losses.append(loss)

def truncate(alpha, centers, weights):
    threshold = alpha**-1 / 100.
    true_centers = centers[weights > threshold]
    true_weights = weights[weights > threshold] / torch.sum(weights[weights > threshold])
    return true_centers, true_weights

alpha = torch.tensor([0.1], device=device)
train(1000, data)

Here is my complete script:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.distributions import constraints

import pyro
from pyro.distributions import *
from pyro.infer import Predictive, SVI, Trace_ELBO
from pyro.optim import Adam

assert pyro.__version__.startswith('1.8.5')
pyro.set_rng_seed(0)
M = df.values.shape[1]
N = df.values.shape[0]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.set_default_device(device)

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
# Prepare your data
data = df

# Center and scale the data
scaler = StandardScaler()
scaled_data = scaler.fit_transform(data)

# Apply PCA and obtain explained variances
pca = PCA()
principal_components = pca.fit_transform(scaled_data)

data = torch.from_numpy(principal_components[:, 0:20])
data.to(device)
M = 20

def mix_weights(beta):
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)


def model(data):
    with pyro.plate("beta_plate", T - 1, device=device):
        beta = pyro.sample("beta", Beta(1, alpha))

    with pyro.plate("mu_plate", T, device=device):
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(M, device=device), 5 * torch.eye(M, device=device)))

    with pyro.plate("data", N, device=device):
        z = pyro.sample("z", Categorical(mix_weights(beta)))
        pyro.sample("obs", MultivariateNormal(mu[z], torch.eye(M, device=device)), obs=data)

def guide(data):
    kappa = pyro.param('kappa', lambda: Uniform(torch.tensor(0.0, device=device), torch.tensor(2.0, device=device)).sample([T - 1]), constraint=constraints.positive)
    tau = pyro.param('tau', lambda: MultivariateNormal(torch.zeros(M, device=device), 3 * torch.eye(M, device=device)).sample([T]))
    phi = pyro.param('phi', lambda: Dirichlet(1 / T * torch.ones(T, device=device)).sample([N]), constraint=constraints.simplex)

    with pyro.plate("beta_plate", T - 1, device=device):
        q_beta = pyro.sample("beta", Beta(torch.ones(T - 1, device=device), kappa))

    with pyro.plate("mu_plate", T, device=device):
        q_mu = pyro.sample("mu", MultivariateNormal(tau, torch.eye(M, device=device)))

    with pyro.plate("data", N, device=device):
        z = pyro.sample("z", Categorical(phi))

T = 5
auto_guide = pyro.infer.autoguide.AutoNormal(pyro.poutine.block(model, hide=['z']))
optim = Adam({"lr": 0.05})
svi = SVI(model, auto_guide, optim, loss=Trace_ELBO())
losses = []

def train(num_iterations, data):
    pyro.clear_param_store()
    for j in tqdm(range(num_iterations)):
        loss = svi.step(data)
        losses.append(loss)

def truncate(alpha, centers, weights):
    threshold = alpha**-1 / 100.
    true_centers = centers[weights > threshold]
    true_weights = weights[weights > threshold] / torch.sum(weights[weights > threshold])
    return true_centers, true_weights

alpha = torch.tensor([0.1], device=device)
train(1000, data)


# We make a point-estimate of our model parameters using the posterior means of tau and phi for the centers and weights
Bayes_Centers_01, Bayes_Weights_01 = truncate(alpha, pyro.param("tau").detach(), torch.mean(pyro.param("phi").detach(), dim=0))

alpha = 1.5
train(1000)

and here is the error:

 0%|                                                  | 0/1000 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[36], line 49
     46     return true_centers, true_weights
     48 alpha = torch.tensor([0.1],device=device)
---> 49 train(1000, data)
     52 # We make a point-estimate of our model parameters using the posterior means of tau and phi for the centers and weights
     53 Bayes_Centers_01, Bayes_Weights_01 = truncate(alpha, pyro.param("tau").detach(), torch.mean(pyro.param("phi").detach(), dim=0))

Cell In[36], line 39, in train(num_iterations, data)
     37 pyro.clear_param_store()
     38 for j in tqdm(range(num_iterations)):
---> 39     loss = svi.step(data)
     40     losses.append(loss)

File /opt/conda/lib/python3.10/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
    143 # get loss and compute gradients
    144 with poutine.trace(param_only=True) as param_capture:
--> 145     loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    147 params = set(
    148     site["value"].unconstrained() for site in param_capture.trace.nodes.values()
    149 )
    151 # actually perform gradient steps
    152 # torch.optim objects gets instantiated for any params that haven't been seen yet

File /opt/conda/lib/python3.10/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
    138 loss = 0.0
    139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
    141     loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
    142         model_trace, guide_trace
    143     )
    144     loss += loss_particle / self.num_particles

File /opt/conda/lib/python3.10/site-packages/pyro/infer/elbo.py:237, in ELBO._get_traces(self, model, guide, args, kwargs)
    235 else:
    236     for i in range(self.num_particles):
--> 237         yield self._get_trace(model, guide, args, kwargs)

File /opt/conda/lib/python3.10/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
     52 def _get_trace(self, model, guide, args, kwargs):
     53     """
     54     Returns a single trace from the guide, and the model that is run
     55     against it.
     56     """
---> 57     model_trace, guide_trace = get_importance_trace(
     58         "flat", self.max_plate_nesting, model, guide, args, kwargs
     59     )
     60     if is_validation_enabled():
     61         check_if_enumerated(guide_trace)

File /opt/conda/lib/python3.10/site-packages/pyro/infer/enum.py:75, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     72 guide_trace = prune_subsample_sites(guide_trace)
     73 model_trace = prune_subsample_sites(model_trace)
---> 75 model_trace.compute_log_prob()
     76 guide_trace.compute_score_parts()
     77 if is_validation_enabled():

File /opt/conda/lib/python3.10/site-packages/pyro/poutine/trace_struct.py:230, in Trace.compute_log_prob(self, site_filter)
    228 if "log_prob" not in site:
    229     try:
--> 230         log_p = site["fn"].log_prob(
    231             site["value"], *site["args"], **site["kwargs"]
    232         )
    233     except ValueError as e:
    234         _, exc_value, traceback = sys.exc_info()

File /opt/conda/lib/python3.10/site-packages/torch/distributions/multivariate_normal.py:215, in MultivariateNormal.log_prob(self, value)
    213 if self._validate_args:
    214     self._validate_sample(value)
--> 215 diff = value - self.loc
    216 M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
    217 half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)

File /opt/conda/lib/python3.10/site-packages/torch/utils/_device.py:62, in DeviceContext.__torch_function__(self, func, types, args, kwargs)
     60 if func in _device_constructors() and kwargs.get('device') is None:
     61     kwargs['device'] = self.device
---> 62 return func(*args, **kwargs)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!


Found the error: I was not properly doing

data = data.to(device)

Sorry, and thanks for your patience.

I see, the problematic line is

data.to(device)

It returns a copy and leaves the original tensor where it was. To see what is happening, try the following:

print(data.device)  # device(type='cpu')
data.to(device)     # returns a new tensor on the GPU, but the variable data is not changed
print(data.device)  # still device(type='cpu'), even though you were expecting the GPU

Change that line to:

data = data.to(device)

Check the note here: https://pytorch.org/docs/stable/generated/torch.Tensor.to.html
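
One more note, since your script calls torch.set_default_device(device): that call only affects tensor factory functions such as torch.zeros, while torch.from_numpy always wraps the existing CPU memory of the NumPy array, so the explicit, rebinding move is still needed for your data. A minimal sketch of both behaviours, assuming a CUDA device is available:

import numpy as np
import torch

torch.set_default_device("cuda:0")

a = torch.zeros(3)                 # factory function -> created on cuda:0
b = torch.from_numpy(np.zeros(3))  # wraps existing CPU memory -> stays on cpu
print(a.device, b.device)          # cuda:0 cpu

b = b.to("cuda:0")                 # .to() returns a copy, so rebind the name
print(b.device)                    # cuda:0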

I hope this helps.
