Is there support to run MCMC with pyro models on GPUs? I have my data tensor on a GPU, but it seems to be using CPU regardless:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def extended_dirichlet_mixture_model(data, num_components=3, num_models=3):
data = data.to(torch.float32)
# Global variables.
weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(num_components)))
# print("weights device: ", weights.device)
# Hierarchical prior for alphas
with pyro.plate('components', num_components):
alpha_base = pyro.sample('alpha_base', dist.Gamma(1.0, 1.0).expand([data.shape[-1]]).to_event(1))
# print(alpha_base.shape, alpha_base)
# print(alpha_base.device)
# Separate alphas for each model within each component
alphas = torch.empty(num_components, num_models, data.shape[-1])
# alphas = alphas.to(device)
for i in range(num_models):
alphas[:, i, :] = pyro.sample(f'alphas_{i}', dist.Gamma(alpha_base, 1.0).to_event(1))
print(data.device, alphas.device) # alphas is on cpu, but data is on GPU!
with pyro.plate('data', len(data)):
for i in range(num_models):
# Local variables for each model's prediction.
assignment = pyro.sample(f'assignment_{i}', dist.Categorical(weights))
pyro.sample(f'obs_{i}', dist.Dirichlet(alphas[assignment, i, :]), obs=data[:, i, :])
# Example data: replace this with your actual data
N = 100 # Number of data points
data = torch.rand(N, 3, 4) # Each data point is a 3x4 matrix
# MCMC settings
num_samples = 500
warmup_steps = 200
# Using the NUTS sampler
nuts_kernel = NUTS(extended_dirichlet_mixture_model)
data = data.to(device)
# Initialize and run MCMC
mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps)
mcmc.run(data)
Do I have to manually move everything to GPU (even then I am getting errors that some tensors are on CPU and others on GPU)? Is there no pyro analogue for model = model.to(device)?