Pyro MCMC on GPU?

Is there support for running MCMC with Pyro models on GPUs? My data tensor is on a GPU, but MCMC seems to run on the CPU regardless:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def extended_dirichlet_mixture_model(data, num_components=3, num_models=3):
    data = data.to(torch.float32)
    # Global variables.
    weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(num_components)))
    # print("weights device: ", weights.device)

    # Hierarchical prior for alphas
    with pyro.plate('components', num_components):
        alpha_base = pyro.sample('alpha_base', dist.Gamma(1.0, 1.0).expand([data.shape[-1]]).to_event(1))
        # print(alpha_base.shape, alpha_base)
        # print(alpha_base.device)

        # Separate alphas for each model within each component
        alphas = torch.empty(num_components, num_models, data.shape[-1])
        # alphas = alphas.to(device)
        for i in range(num_models):
            alphas[:, i, :] = pyro.sample(f'alphas_{i}', dist.Gamma(alpha_base, 1.0).to_event(1))

    print(data.device, alphas.device) # alphas is on cpu, but data is on GPU! 
    with pyro.plate('data', len(data)):
        for i in range(num_models):
            # Local variables for each model's prediction.
            assignment = pyro.sample(f'assignment_{i}', dist.Categorical(weights))
            pyro.sample(f'obs_{i}', dist.Dirichlet(alphas[assignment, i, :]), obs=data[:, i, :])

# Example data: replace this with your actual data
N = 100  # Number of data points
data = torch.rand(N, 3, 4)  # Each data point is a 3x4 matrix
data = data / data.sum(-1, keepdim=True)  # Dirichlet observations: each row must sum to 1

# MCMC settings
num_samples = 500
warmup_steps = 200

# Using the NUTS sampler
nuts_kernel = NUTS(extended_dirichlet_mixture_model)

data = data.to(device)

# Initialize and run MCMC
mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps)
mcmc.run(data)

Do I have to move everything to the GPU manually? Even when I do, I get errors that some tensors are on the CPU and others on the GPU. Is there no Pyro analogue of model = model.to(device)?

You either need to take care to instantiate distributions with parameters defined on the target device (e.g. torch.ones(num_components, device=...)), and similarly for pyro.param, or use an API like torch.set_default_tensor_type(torch.cuda.DoubleTensor). If you're using PyroModule, you can use .to(device), but you're not using PyroModule.

Even after moving all of the parameters to the target device, I still get errors saying that some tensors are on the CPU. Would you recommend using PyroModule instead?

You're probably just missing something, e.g. your torch.empty statement, which creates alphas on the CPU. Triple-check each line of code first. There's probably no need for PyroModule for a model this simple.