Hi,
This is my first post here. First of all, thank you and your team for your great work on Pyro!
I modified the code from the DPMM tutorial as follows:
- the code runs on GPU,
- more than a single particle can be used during training (with num_particles=10),
- TraceEnum_ELBO is employed instead of Trace_ELBO.
My questions are:
-
The way I index the variables inside the model’s data plate is not elegant (and it does not work if num_particles=1). I am aware that more elegant solutions exist using Vindex, but I could not come up with one myself. I would appreciate any suggestions.
-
I am aware that I cannot use mini-batching (PyTorch’s DataLoader(shuffle=True)) in each training iteration, since each z is a local latent variable assigned to a specific x instance. Subsampling the whole dataset via the plate’s indices (with subsample_size) is not an option in my use case, since the whole dataset does not necessarily fit into GPU memory. Is there any solution, example, or tutorial related to this issue?
-
From time to time I observe mode collapse on a bigger dataset, just by increasing the num_particles parameter. I tried to prevent this with a smart initialization in the guide:
tau = pyro.param('tau', init_tensor=centers_init)
and by reducing the variance around that initialization:
q_mu = pyro.sample("mu", MultivariateNormal(tau, 0.1*torch.eye(subspace_dim, device=device)))
Should I correspondingly make my prior more informative, e.g. in the model where mu is sampled? Any suggestions in this respect?
Code:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import torch.nn.functional as F
from torch.distributions import constraints
import pyro
from pyro.distributions import *
from pyro.infer import SVI, config_enumerate, TraceEnum_ELBO
from pyro.ops.indexing import Vindex
from pyro.optim import Adam
assert pyro.__version__.startswith('1.8.6')
pyro.set_rng_seed(0)
# Use the GPU when one is present; otherwise fall back to the CPU instead of
# crashing at ``data.to(device)`` below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Synthetic data: four 2-D Gaussian blobs (identity covariance), 50 points each.
# The blobs are sampled on the CPU in a fixed order (so the seeded RNG stream
# is reproducible) and moved to ``device`` afterwards.
data = torch.cat((
    MultivariateNormal(-8 * torch.ones(2), torch.eye(2)).sample([50]),
    MultivariateNormal(8 * torch.ones(2), torch.eye(2)).sample([50]),
    MultivariateNormal(torch.tensor([1.5, 2]), torch.eye(2)).sample([50]),
    MultivariateNormal(torch.tensor([-0.5, 1]), torch.eye(2)).sample([50]),
))
data = data.to(device)
# plt.scatter(data[:, 0], data[:, 1])
# plt.title("Data Samples from Mixture of 4 Gaussians")
# plt.show()

N = data.shape[0]  # number of observations (size of the "data" plate)
num_particles = 10  # ELBO particles, vectorized by TraceEnum_ELBO
########################################
def mix_weights(beta):
    """Turn stick-breaking fractions into mixture weights.

    Given K fractions ``beta`` along the last dimension, return K + 1 weights
    ``pi_k = beta_k * prod_{j<k} (1 - beta_j)``, where the final fraction is
    taken to be 1. The weights along the last dimension therefore sum to one.
    Any leading dimensions are treated as batch dimensions.
    """
    # prod_{j<=k} (1 - beta_j) for every k.
    remaining = (1 - beta).cumprod(-1)
    # Append the implicit last fraction, which is 1.
    fractions = F.pad(beta, (0, 1), value=1)
    # Shift right by one so position k holds prod_{j<k} (1 - beta_j).
    leftover = F.pad(remaining, (1, 0), value=1)
    return fractions * leftover
########################################
def model(data):
    """Generative model of the truncated stick-breaking Dirichlet-process mixture.

    Samples T-1 stick-breaking fractions ``beta`` and T cluster means ``mu``.
    Then, inside the "data" plate, it samples a cluster assignment ``z`` per
    observation and scores ``data`` under an identity-covariance
    MultivariateNormal centred at the assigned mean.

    Reads the module-level globals ``T``, ``alpha``, ``N`` and ``device``.

    Args:
        data: observed points. The likelihood is 2-dimensional, so ``data`` is
            expected to have 2 columns; the plate size is the global ``N``.
    """
    with pyro.plate("beta_plate", T - 1, device=device):
        beta = pyro.sample("beta", Beta(1, alpha))

    with pyro.plate("mu_plate", T, device=device):
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(2, device=device),
                                                  5 * torch.eye(2, device=device)))

    with pyro.plate("data", N, device=device):
        # Batch dims, rightmost first: -1 is the "data" plate; -2 is the
        # vectorized-particle plate that the ELBO adds when num_particles > 1;
        # further left is z's enumeration dim. The event dim (2) of mu/obs sits
        # to the right of all of these.
        z = pyro.sample("z", Categorical(mix_weights(beta).unsqueeze(-2)))
        # mu has shape (*particle_dims, T, 2), where particle_dims is empty for a
        # single particle. unsqueeze(-3) inserts a singleton slot for the "data"
        # plate, so Vindex broadcasts mu's leading dims against z's batch dims
        # from the right. That yields loc[..., p, n, :] = mu[p, z[..., p, n], :]
        # for any number of particles, with no dependency on the global
        # num_particles.
        loc = Vindex(mu.unsqueeze(-3))[..., z, :]
        pyro.sample("obs", MultivariateNormal(loc, torch.eye(2, device=device)), obs=data)
########################################
def guide(data):
    """Mean-field variational family for the DPMM.

    Registers three variational parameters. Their initializers run only when
    the name is not yet in the param store:
      * ``kappa`` - positive, T-1 entries; the concentration0 parameter of the
        Beta distribution over each stick-breaking fraction;
      * ``tau``   - T x 2 means of the (identity-covariance) Gaussians over the
        cluster centres;
      * ``phi``   - N x T rows on the simplex; per-observation assignment
        probabilities for the enumerated site ``z``.

    ``data`` is not read here. All sizes come from the globals ``T``, ``N``
    and ``device``.
    """
    kappa = pyro.param(
        'kappa',
        lambda: Uniform(torch.tensor(0., device=device),
                        torch.tensor(2., device=device)).sample([T - 1]),
        constraint=constraints.positive,
    )
    tau = pyro.param(
        'tau',
        lambda: MultivariateNormal(torch.zeros(2, device=device),
                                   3 * torch.eye(2, device=device)).sample([T]),
    )
    phi = pyro.param(
        'phi',
        lambda: Dirichlet(1 / T * torch.ones(T, device=device)).sample([N]),
        constraint=constraints.simplex,
    )

    # The sample sites only need to be recorded in the trace, so their return
    # values are not kept.
    with pyro.plate("beta_plate", T - 1, device=device):
        pyro.sample("beta", Beta(torch.ones(T - 1, device=device), kappa))
    with pyro.plate("mu_plate", T, device=device):
        pyro.sample("mu", MultivariateNormal(tau, torch.eye(2, device=device)))
    with pyro.plate("data", N, device=device):
        pyro.sample("z", Categorical(phi))
T = 6  # truncation level: number of stick-breaking components / clusters
optim = Adam({"lr": 0.05})
# No plates are nested in model/guide, so max_plate_nesting is 1. The guide is
# enumerated in parallel, so TraceEnum_ELBO sums out z exactly. Particles are
# vectorized into an extra plate dimension.
svi = SVI(
    model,
    config_enumerate(guide, 'parallel'),
    optim,
    loss=TraceEnum_ELBO(
        max_plate_nesting=1,
        num_particles=num_particles,
        vectorize_particles=True,
    ),
)
losses = []


def train(num_iterations):
    """Clear the param store, then run ``num_iterations`` SVI steps on ``data``.

    Every step's loss is appended to the module-level ``losses`` list. That
    list is not reset here, so repeated calls keep accumulating into it.
    """
    pyro.clear_param_store()
    for _ in tqdm(range(num_iterations)):
        losses.append(svi.step(data))
def truncate(alpha, centers, weights):
    """Drop mixture components whose weight is at most ``1 / (100 * alpha)``.

    Args:
        alpha: DP concentration (a scalar or one-element tensor).
        centers: per-component centres, indexed along the first dimension.
        weights: per-component weights, aligned with ``centers``.

    Returns:
        ``(kept_centers, kept_weights)``, where the kept weights are
        renormalized to sum to one.
    """
    keep = weights > alpha ** -1 / 100.
    kept_weights = weights[keep]
    return centers[keep], kept_weights / torch.sum(kept_weights)
# DP concentration, read as a global by model() and used by truncate().
alpha = torch.tensor([0.1], device=device)
train(1000)

# Point estimate of the mixture: posterior means tau serve as the centres, and
# phi averaged over observations serves as the weights.
Bayes_Centers_01, Bayes_Weights_01 = truncate(
    alpha,
    pyro.param("tau").detach(),
    pyro.param("phi").detach().mean(dim=0),
)

# Only the left panel of the 1x2 grid is drawn: the data in blue and the
# surviving cluster centres in red.
plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.scatter(*data.detach().cpu().numpy()[:, :2].T, color="blue")
plt.scatter(*Bayes_Centers_01.detach().cpu().numpy()[:, :2].T, color="red")
plt.show()