import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.distributions import constraints
import pyro
from pyro.distributions import *
from pyro.infer import Predictive, SVI, Trace_ELBO
from pyro.optim import Adam
assert pyro.__version__.startswith('1.8.5')
pyro.set_rng_seed(0)
M = df.values.shape[1]
N = df.values.shape[0]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
# Prepare your data
data = df
# Center and scale the data
scaler = StandardScaler()
scaled_data = scaler.fit_transform(data)
# Apply PCA and obtain explained variances
pca = PCA()
principal_components = pca.fit_transform(scaled_data)
data = torch.from_numpy(principal_components[:,0:20])
data.to(device)
M=20
def mix_weights(beta):
beta1m_cumprod = (1 - beta).cumprod(-1)
return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)
def model(data):
with pyro.plate("beta_plate", T-1,device=device):
beta = pyro.sample("beta", Beta(1, alpha))
with pyro.plate("mu_plate", T,device=device):
mu = pyro.sample("mu", MultivariateNormal(torch.zeros(M,device=device), 5 * torch.eye(M,device=device)))
with pyro.plate("data", N,device=device):
z = pyro.sample("z", Categorical(mix_weights(beta)))
pyro.sample("obs", MultivariateNormal(mu[z], torch.eye(M,device=device)), obs=data)
def guide(data):
kappa = pyro.param('kappa', lambda: Uniform(torch.tensor(0.0,device=device), torch.tensor(2.0,device=device)).sample([T-1]), constraint=constraints.positive)
tau = pyro.param('tau', lambda: MultivariateNormal(torch.zeros(M, device=device), 3 * torch.eye(M,device=device)).sample([T]))
phi = pyro.param('phi', lambda: Dirichlet(1/T * torch.ones(T,device=device)).sample([N]), constraint=constraints.simplex)
with pyro.plate("beta_plate", T-1, device=device):
q_beta = pyro.sample("beta", Beta(torch.ones(T-1, device=device), kappa))
with pyro.plate("mu_plate", T, device=device):
q_mu = pyro.sample("mu", MultivariateNormal(tau, torch.eye(M,device=device)))
with pyro.plate("data", N, device=device):
z = pyro.sample("z", Categorical(phi))
T = 5
auto_guide = pyro.infer.autoguide.AutoNormal(pyro.poutine.block(model, hide=['z']))
optim = Adam({"lr": 0.05})
svi = SVI(model, auto_guide, optim, loss=Trace_ELBO())
losses = []
def train(num_iterations, data):
pyro.clear_param_store()
for j in tqdm(range(num_iterations)):
loss = svi.step(data)
losses.append(loss)
def truncate(alpha, centers, weights):
threshold = alpha**-1 / 100.
true_centers = centers[weights > threshold]
true_weights = weights[weights > threshold] / torch.sum(weights[weights > threshold])
return true_centers, true_weights
alpha = torch.tensor([0.1],device=device)
train(1000, data)
# We make a point-estimate of our model parameters using the posterior means of tau and phi for the centers and weights
Bayes_Centers_01, Bayes_Weights_01 = truncate(alpha, pyro.param("tau").detach(), torch.mean(pyro.param("phi").detach(), dim=0))
alpha = 1.5
train(1000)
and the error
0%| | 0/1000 [00:00<?, ?it/s]
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[36], line 49
46 return true_centers, true_weights
48 alpha = torch.tensor([0.1],device=device)
---> 49 train(1000, data)
52 # We make a point-estimate of our model parameters using the posterior means of tau and phi for the centers and weights
53 Bayes_Centers_01, Bayes_Weights_01 = truncate(alpha, pyro.param("tau").detach(), torch.mean(pyro.param("phi").detach(), dim=0))
Cell In[36], line 39, in train(num_iterations, data)
37 pyro.clear_param_store()
38 for j in tqdm(range(num_iterations)):
---> 39 loss = svi.step(data)
40 losses.append(loss)
File /opt/conda/lib/python3.10/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
143 # get loss and compute gradients
144 with poutine.trace(param_only=True) as param_capture:
--> 145 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
147 params = set(
148 site["value"].unconstrained() for site in param_capture.trace.nodes.values()
149 )
151 # actually perform gradient steps
152 # torch.optim objects gets instantiated for any params that haven't been seen yet
File /opt/conda/lib/python3.10/site-packages/pyro/infer/trace_elbo.py:140, in Trace_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
138 loss = 0.0
139 # grab a trace from the generator
--> 140 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
141 loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
142 model_trace, guide_trace
143 )
144 loss += loss_particle / self.num_particles
File /opt/conda/lib/python3.10/site-packages/pyro/infer/elbo.py:237, in ELBO._get_traces(self, model, guide, args, kwargs)
235 else:
236 for i in range(self.num_particles):
--> 237 yield self._get_trace(model, guide, args, kwargs)
File /opt/conda/lib/python3.10/site-packages/pyro/infer/trace_elbo.py:57, in Trace_ELBO._get_trace(self, model, guide, args, kwargs)
52 def _get_trace(self, model, guide, args, kwargs):
53 """
54 Returns a single trace from the guide, and the model that is run
55 against it.
56 """
---> 57 model_trace, guide_trace = get_importance_trace(
58 "flat", self.max_plate_nesting, model, guide, args, kwargs
59 )
60 if is_validation_enabled():
61 check_if_enumerated(guide_trace)
File /opt/conda/lib/python3.10/site-packages/pyro/infer/enum.py:75, in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
72 guide_trace = prune_subsample_sites(guide_trace)
73 model_trace = prune_subsample_sites(model_trace)
---> 75 model_trace.compute_log_prob()
76 guide_trace.compute_score_parts()
77 if is_validation_enabled():
File /opt/conda/lib/python3.10/site-packages/pyro/poutine/trace_struct.py:230, in Trace.compute_log_prob(self, site_filter)
228 if "log_prob" not in site:
229 try:
--> 230 log_p = site["fn"].log_prob(
231 site["value"], *site["args"], **site["kwargs"]
232 )
233 except ValueError as e:
234 _, exc_value, traceback = sys.exc_info()
File /opt/conda/lib/python3.10/site-packages/torch/distributions/multivariate_normal.py:215, in MultivariateNormal.log_prob(self, value)
213 if self._validate_args:
214 self._validate_sample(value)
--> 215 diff = value - self.loc
216 M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
217 half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
File /opt/conda/lib/python3.10/site-packages/torch/utils/_device.py:62, in DeviceContext.__torch_function__(self, func, types, args, kwargs)
60 if func in _device_constructors() and kwargs.get('device') is None:
61 kwargs['device'] = self.device
---> 62 return func(*args, **kwargs)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
pyro.param("phi").detach().numpy().shape