Hi! I have three questions related to a model I’ve implemented using Pyro and SVI (complete toy example code is included below). The observed data for this model is grouped into 2 matrices: y ([n_yfeatures, n_individuals]), and x ([total_samples, n_xfeatures]) with a nesting structure that each sample comes from a known associated individual, and each individual contributes multiple samples.

I would like to do minibatch updates by subsampling y by yfeatures and the x matrix by samples. However, as I currently have it in the code below, the ‘global’ variational parameters are getting updated multiple times per epoch (each iteration of svi.step). Is there some way to specify certain parameters to update each iteration so that I could update “local” (zeta, gamma, sigma), and “global” (xi_g, chi, omega_b) parameters separately?

Due to 1, the loss isn’t scaling correctly because of the global updates each minibatch iteration. It seems like poutine.scale might be able to fix this, but I have read through the documentation and I still am not sure how to use that correctly. Would fixing naturally fix this issue because the plates are aware of the subsampling?

With this model when I try to make the number of samples per minibatch smaller (not even that small for example 200 samples, but in larger datasets I’ve had this issue with up to 10,000 samples), SVI consistently ends up failing with an error about invalid concentration parameters for phi. This can be recreated in my code below by changing n_minibatches to be 5 – it will then fails at epoch 229. Is there anything obvious I am missing? I’ve tried clipping the gradient but that didn’t do much. If I make the learning rate significantly smaller, it still fails just after many more completed epochs.
Thank you so much in advance for any insights you can share!!
###################################################
import numpy as np
import pandas as pd
import math
import torch
import torch.distributions as tdist
import torch.optim as optim
from torch.distributions import constraints
import pyro
import pyro.optim as poptim
import pyro.distributions as dist
from pyro.ops.indexing import Vindex
from pyro.infer import SVI, JitTrace_ELBO, TraceGraph_ELBO, Trace_ELBO
pyro.enable_validation(True)
pyro.set_rng_seed(1)
##################################################################
def model(cur_x, cur_y):
# declare plates
yfeat_plt = pyro.plate('yfeatures', n_yfeatures, dim=2, subsample=cur_y)
ind_plt = pyro.plate('inds', n_inds)
k_b_plt = pyro.plate('k_b', k_b)
k_phi_plt = pyro.plate('k_phi', k_phi)
samp_plt = pyro.plate('samples', tot_samps, subsample=cur_x)
# global
with k_phi_plt:
lambda_g = pyro.sample("lambda_g", dist.Dirichlet(torch.ones([n_xfeatures]))) # [k_phi, n_xfeatures]
with yfeat_plt, k_b_plt:
lambda_s = pyro.sample("lambda_s", dist.Beta(1., 1.)) # [n_yfeatures, k_b]
with k_b_plt:
omega = pyro.sample("omega", dist.Dirichlet(torch.ones([k_phi]))) # [k_b, k_phi]
# local  individuals
with ind_plt:
theta_b = pyro.sample("theta_b", dist.Dirichlet(torch.ones([k_b]))) # [n_inds, k_b]
pi_s = torch.mm(lambda_s, torch.t(theta_b)) # [n_yfeatures, n_inds]
with yfeat_plt:
pyro.sample('y', dist.Binomial(2, pi_s), obs=y[cur_y]) # [n_yfeatures, n_inds]
mu = torch.mm(theta_b, omega) # [n_inds, k_phi]
# local  samples per individual
sample_subset_inds = samp_to_ind[cur_x]
with samp_plt:
phi = pyro.sample("phi", dist.Dirichlet(mu[sample_subset_inds]))
pi_g = torch.mm(phi, lambda_g) # [total_samples, n_xfeatures]
pyro.sample('x', dist.Multinomial(probs=pi_g, validate_args=False), obs=x[cur_x])
##################################################################
def guide(cur_x, cur_y):
# declare plates
yfeat_plt = pyro.plate('yfeatures', n_yfeatures, dim=2, subsample=cur_y)
ind_plt = pyro.plate('inds', n_inds)
k_b_plt = pyro.plate('k_b', k_b)
k_phi_plt = pyro.plate('k_phi', k_phi)
samp_plt = pyro.plate('samples', tot_samps, subsample=cur_x)
# global
xi_g = pyro.param("xi_g", torch.ones([k_phi, x.shape[1]])*0.5, constraint=constraints.positive)
with k_phi_plt:
lambda_g = pyro.sample("lambda_g", dist.Dirichlet(xi_g)) # [k_phi, n_xfeatures]
zeta = pyro.param("zeta", torch.ones([n_yfeatures, k_b]), constraint=constraints.positive)
gamma = pyro.param("gamma", torch.ones([n_yfeatures, k_b]), constraint=constraints.positive)
with yfeat_plt, k_b_plt:
lambda_s = pyro.sample("lambda_s", dist.Beta(zeta[cur_y], gamma[cur_y])) # [n_yfeatures, k_b]
chi = pyro.param("chi", torch.ones([k_b, k_phi])*0.5, constraint=constraints.positive)
with k_b_plt:
omega = pyro.sample("omega", dist.Dirichlet(chi)) # [k_b, k_phi]
# local  individual level
omega_b = pyro.param("omega_b", torch.ones([n_inds, k_b])*0.5, constraint=constraints.positive)
with ind_plt:
theta_b = pyro.sample("theta_b", dist.Dirichlet(omega_b)) # [n_inds, k_b]
# local  samples per individual
sigma = pyro.param("sigma", torch.ones([tot_samps, k_phi])*0.5, constraint=constraints.positive)
with samp_plt:
phi = pyro.sample("phi", dist.Dirichlet(sigma[cur_x])) # [total_samples, k_phi]
##################################################################
def gen_toy_data():
lambda_g = dist.Dirichlet(0.2 * torch.ones([k_phi, n_xfeatures])).sample()
lambda_s = dist.Beta(torch.ones([n_yfeatures, k_b]), torch.ones([n_yfeatures, k_b])).sample()
omega = dist.Dirichlet(torch.ones([k_b, k_phi]) * 0.2).sample()
theta_b = dist.Dirichlet(0.2 * torch.ones([n_inds, k_b])).sample()
mu = torch.mm(theta_b, omega)
pi_s = torch.mm(lambda_s, torch.t(theta_b))
y = pyro.sample('y', dist.Binomial(2, pi_s))
# generate the number of samples per individual
ind_samp_counts = dist.Poisson(20).sample(torch.Size([n_inds]))
tot_samps = int(ind_samp_counts.sum())
# to match my real data, have a mapping for which individual each sample is from
samp_to_ind = []
for idx, i in enumerate(ind_samp_counts.numpy()):
samp_to_ind.extend( np.repeat(idx, int(i)) )
phi = dist.Dirichlet(mu[samp_to_ind]).sample()
pi_g = torch.mm(phi, lambda_g)
x = torch.zeros([tot_samps, n_xfeatures])
# each sample has a variable count sum, so sample from Multinomial sequentially
for c in range(tot_samps):
cur_sample_count = int(dist.Poisson(150).sample())
x[c] = dist.Multinomial(cur_sample_count, pi_g[c]).sample()
return(y, x, tot_samps, torch.tensor(samp_to_ind))
####################################################################
##################################################################
# declare model parameters
k_phi, k_b = 6, 3
n_inds = 50
n_yfeatures = 200
n_xfeatures = 200
# generate toy data
y, x, tot_samps, samp_to_ind = gen_toy_data()
# set up mb partitions
n_minibatches = 4
mb_samp_size = math.floor(tot_samps / n_minibatches)
yfeats_per_mb = int(np.ceil(n_yfeatures / n_minibatches))
# set up SVI
pyro.set_rng_seed(1)
pyro.clear_param_store()
svi = SVI(model, guide, poptim.Adam({"lr": 0.015}), loss=Trace_ELBO())
losses = []
for epoch in range(1500):
print('EPOCH ' + str(epoch),flush=True)
#naively ignore individuals and just choose random partitions
x_rand_indx = torch.randperm(tot_samps)
y_rand_indx = torch.randperm(n_yfeatures)
elbo = 0.0
for mb in range(n_minibatches):
x_mb_start = (mb * mb_samp_size)
x_mb_end = np.max([(mb +1) * mb_samp_size, tot_samps])
y_mb_start = (mb * yfeats_per_mb)
y_mb_end = np.max([(mb+1) * yfeats_per_mb, n_yfeatures])
mb_elbo = svi.step(x_rand_indx[x_mb_start:x_mb_end], y_rand_indx[y_mb_start:y_mb_end])
elbo += mb_elbo # not scaled correctly bc 'global' params updated each mb step
losses.append(elbo)