SVI step: update only specified parameters / minibatch ELBO scaling

Hi! I have three questions about a model I’ve implemented using Pyro and SVI (complete toy example code is included below). The observed data for this model are grouped into two matrices, y ([n_yfeatures, n_individuals]) and x ([total_samples, n_xfeatures]), with a nesting structure: each sample comes from a known individual, and each individual contributes multiple samples.

  1. I would like to do minibatch updates by subsampling y by y-features and the x matrix by samples. However, as I currently have it in the code below, the ‘global’ variational parameters are getting updated multiple times per epoch (each iteration of svi.step). Is there some way to specify certain parameters to update each iteration so that I could update “local” (zeta, gamma, sigma), and “global” (xi_g, chi, omega_b) parameters separately?

  2. Because of 1, the loss isn’t scaling correctly: the global parameters are updated on every minibatch iteration. It seems like poutine.scale might fix this, but after reading the documentation I am still not sure how to use it correctly. Or would fixing 1 resolve this on its own, since the plates are aware of the subsampling?

  3. With this model, when I make the number of samples per minibatch smaller (not even that small, e.g. 200 samples; in larger datasets I’ve had this issue with up to 10,000 samples), SVI consistently fails with an error about invalid concentration parameters for phi. This can be reproduced in my code below by changing n_minibatches to 5; it then fails at epoch 229. Is there anything obvious I am missing? I’ve tried clipping the gradient, but that didn’t do much. If I make the learning rate significantly smaller, it still fails, just after many more completed epochs.

Thank you so much in advance for any insights you can share!!

###################################################
import numpy as np
import pandas as pd
import math

import torch
import torch.distributions as tdist
import torch.optim as optim
from torch.distributions import constraints

import pyro
import pyro.optim as poptim
import pyro.distributions as dist
from pyro.ops.indexing import Vindex
from pyro.infer import SVI, JitTrace_ELBO, TraceGraph_ELBO, Trace_ELBO

# Turn on Pyro's extra validation checks for distributions and sample sites.
pyro.enable_validation(True)
# Seed the RNGs so the generated toy data is reproducible.
pyro.set_rng_seed(1)

##################################################################

def model(cur_x, cur_y):
    """Generative model over the y-feature matrix ``y`` and the sample matrix ``x``.

    Reads module-level globals: the sizes ``n_yfeatures``, ``n_inds``, ``k_b``,
    ``k_phi``, ``tot_samps`` and ``n_xfeatures``; the data ``x`` and ``y``; and
    ``samp_to_ind``, which maps each row of ``x`` to its individual.

    Args:
        cur_x: indices of the rows of ``x`` (samples) in this minibatch. Passed
            as ``subsample`` to the ``'samples'`` plate, so Pyro rescales the
            log-probabilities of that plate's sites by ``tot_samps / len(cur_x)``.
        cur_y: indices of the rows of ``y`` (y-features) in this minibatch.
            Passed as ``subsample`` to the ``'yfeatures'`` plate, with the
            analogous rescaling.
    """

    # declare plates
    # Only 'yfeatures' pins its dim (-2); Pyro allocates dims for the other
    # plates as they are entered, so the order of the `with` blocks matters.
    yfeat_plt = pyro.plate('yfeatures', n_yfeatures, dim=-2, subsample=cur_y)
    ind_plt = pyro.plate('inds', n_inds)
    k_b_plt = pyro.plate('k_b', k_b)
    k_phi_plt = pyro.plate('k_phi', k_phi)
    samp_plt = pyro.plate('samples', tot_samps, subsample=cur_x)

    # global
    with k_phi_plt:
        lambda_g = pyro.sample("lambda_g", dist.Dirichlet(torch.ones([n_xfeatures]))) # [k_phi, n_xfeatures]; each row is on the simplex

    with yfeat_plt, k_b_plt:
        lambda_s = pyro.sample("lambda_s", dist.Beta(1., 1.)) # [len(cur_y), k_b]

    with k_b_plt:
        omega = pyro.sample("omega", dist.Dirichlet(torch.ones([k_phi]))) # [k_b, k_phi]; each row is on the simplex

    # local - individuals
    with ind_plt:
        theta_b = pyro.sample("theta_b", dist.Dirichlet(torch.ones([k_b]))) # [n_inds, k_b]; each row is on the simplex

    # Binomial success probability per (y-feature, individual): a theta_b-weighted
    # mix of the lambda_s columns.
    pi_s = torch.mm(lambda_s, torch.t(theta_b)) # [len(cur_y), n_inds]

    # Only the y-feature dim is covered by a plate here; the rightmost n_inds dim
    # of this site is not declared via ind_plt. NOTE(review): confirm intended.
    with yfeat_plt:
        pyro.sample('y', dist.Binomial(2, pi_s), obs=y[cur_y]) # [len(cur_y), n_inds]

    # Rows of theta_b and omega both sum to 1, so every row of mu sums to 1 and
    # Dirichlet(mu) below has total concentration 1 (favours near-sparse phi).
    # NOTE(review): such small concentrations are a plausible source of the
    # numerical trouble seen with small minibatches -- confirm.
    mu = torch.mm(theta_b, omega) # [n_inds, k_phi]

    # local - samples per individual
    # Individual index of each sample in this minibatch.
    sample_subset_inds = samp_to_ind[cur_x]
    with samp_plt:
        phi = pyro.sample("phi", dist.Dirichlet(mu[sample_subset_inds])) # [len(cur_x), k_phi]
        pi_g = torch.mm(phi, lambda_g) # [len(cur_x), n_xfeatures]
        # In the toy data each row of x has its own count total while
        # Multinomial's default total_count is 1, so argument validation is
        # disabled for this site.
        pyro.sample('x', dist.Multinomial(probs=pi_g, validate_args=False), obs=x[cur_x])

##################################################################

def guide(cur_x, cur_y):
    """Mean-field variational guide for ``model``.

    Every latent site of ``model`` gets an independent Dirichlet or Beta factor
    whose concentrations are ``pyro.param``s constrained to be positive.
    ``zeta``/``gamma`` (``[n_yfeatures, k_b]``) and ``sigma``
    (``[tot_samps, k_phi]``) are indexed by the minibatch indices, so only the
    selected rows enter this step's loss. ``xi_g``, ``chi`` and ``omega_b`` are
    used in full at every step.

    Args:
        cur_x: indices of the sample rows in this minibatch (as in ``model``).
        cur_y: indices of the y-features in this minibatch (as in ``model``).
    """

    # declare plates
    # Must mirror the plates in ``model`` (names, sizes, subsample indices).
    yfeat_plt = pyro.plate('yfeatures', n_yfeatures, dim=-2, subsample=cur_y)
    ind_plt = pyro.plate('inds', n_inds)
    k_b_plt = pyro.plate('k_b', k_b)
    k_phi_plt = pyro.plate('k_phi', k_phi)
    samp_plt = pyro.plate('samples', tot_samps, subsample=cur_x)

    # global
    xi_g = pyro.param("xi_g", torch.ones([k_phi, x.shape[1]])*0.5, constraint=constraints.positive)
    with k_phi_plt:
        lambda_g = pyro.sample("lambda_g", dist.Dirichlet(xi_g)) # [k_phi, n_xfeatures]

    # NOTE(review): rows of zeta/gamma not selected by cur_y get zero gradient in
    # this step, but a momentum-based optimizer (e.g. Adam) may still move them.
    zeta = pyro.param("zeta", torch.ones([n_yfeatures, k_b]), constraint=constraints.positive)
    gamma = pyro.param("gamma", torch.ones([n_yfeatures, k_b]), constraint=constraints.positive)
    with yfeat_plt, k_b_plt:
        lambda_s = pyro.sample("lambda_s", dist.Beta(zeta[cur_y], gamma[cur_y])) # [len(cur_y), k_b]

    chi = pyro.param("chi", torch.ones([k_b, k_phi])*0.5, constraint=constraints.positive)
    with k_b_plt:
        omega = pyro.sample("omega", dist.Dirichlet(chi)) # [k_b, k_phi]

    # local - individual level
    omega_b = pyro.param("omega_b", torch.ones([n_inds, k_b])*0.5, constraint=constraints.positive)
    with ind_plt:
        theta_b = pyro.sample("theta_b", dist.Dirichlet(omega_b)) # [n_inds, k_b]

    # local - samples per individual
    # NOTE(review): only a positivity constraint is applied; rows of sigma have
    # been observed to turn NaN with small minibatches.
    sigma = pyro.param("sigma", torch.ones([tot_samps, k_phi])*0.5, constraint=constraints.positive)
    with samp_plt:
        phi = pyro.sample("phi", dist.Dirichlet(sigma[cur_x])) # [len(cur_x), k_phi]


##################################################################

def gen_toy_data():
    """Simulate a toy dataset with the same structure as ``model``.

    Uses the module-level sizes ``k_phi``, ``k_b``, ``n_inds``, ``n_yfeatures``
    and ``n_xfeatures``. The Dirichlet draws use concentration 0.2, i.e. they
    are sparser than the flat priors in ``model``.

    Returns:
        ``(y, x, tot_samps, samp_to_ind)`` where ``y`` is an
        ``[n_yfeatures, n_inds]`` tensor of Binomial(2, .) draws, ``x`` is a
        ``[tot_samps, n_xfeatures]`` count matrix, ``tot_samps`` is the total
        number of samples (an ``int``), and ``samp_to_ind`` is an integer
        tensor mapping each row of ``x`` to the individual it came from.
    """
    # Latent factors. The draw order below is what a fixed RNG seed reproduces,
    # so it must stay as-is.
    lam_g = dist.Dirichlet(0.2 * torch.ones([k_phi, n_xfeatures])).sample()
    beta_shape = [n_yfeatures, k_b]
    lam_s = dist.Beta(torch.ones(beta_shape), torch.ones(beta_shape)).sample()
    omega = dist.Dirichlet(torch.ones([k_b, k_phi]) * 0.2).sample()
    theta = dist.Dirichlet(0.2 * torch.ones([n_inds, k_b])).sample()
    mu = torch.mm(theta, omega)

    # y: Binomial(2, p) per (y-feature, individual).
    y = pyro.sample('y', dist.Binomial(2, torch.mm(lam_s, theta.t())))

    # Each individual contributes a Poisson(20) number of samples; the
    # sample -> individual lookup repeats each individual's index that often.
    counts = dist.Poisson(20).sample(torch.Size([n_inds]))
    n_samples = int(counts.sum())
    owner = torch.repeat_interleave(torch.arange(n_inds), counts.long())

    phi = dist.Dirichlet(mu[owner]).sample()
    probs = torch.mm(phi, lam_g)

    # Row totals vary (Poisson(150)), so draw one Multinomial per row.
    x = torch.zeros([n_samples, n_xfeatures])
    for row, row_probs in enumerate(probs):
        total = int(dist.Poisson(150).sample())
        x[row] = dist.Multinomial(total, row_probs).sample()
    return y, x, n_samples, owner


####################################################################
##################################################################

# declare model parameters
k_phi, k_b = 6, 3
n_inds = 50
n_yfeatures = 200
n_xfeatures = 200

# generate toy data
y, x, tot_samps, samp_to_ind = gen_toy_data()

# set up minibatch partitions: every epoch splits a fresh random permutation of
# the samples (and of the y-features) into n_minibatches disjoint, contiguous
# slices that together cover everything exactly once.
n_minibatches = 4

# set up SVI
pyro.set_rng_seed(1)
pyro.clear_param_store()
svi = SVI(model, guide, poptim.Adam({"lr": 0.015}), loss=Trace_ELBO())
losses = []

for epoch in range(1500):
    print('EPOCH ' + str(epoch), flush=True)
    # naively ignore individuals and just choose random partitions
    x_rand_indx = torch.randperm(tot_samps)
    y_rand_indx = torch.randperm(n_yfeatures)
    elbo = 0.0
    for mb in range(n_minibatches):
        # Integer-division bounds tile [0, n) exactly: each index lands in one
        # minibatch and minibatch sizes differ by at most one (none is empty
        # as long as n >= n_minibatches).
        x_mb_start = (mb * tot_samps) // n_minibatches
        x_mb_end = ((mb + 1) * tot_samps) // n_minibatches
        y_mb_start = (mb * n_yfeatures) // n_minibatches
        y_mb_end = ((mb + 1) * n_yfeatures) // n_minibatches
        mb_elbo = svi.step(x_rand_indx[x_mb_start:x_mb_end],
                           y_rand_indx[y_mb_start:y_mb_end])
        # The subsampled plates rescale their sites to the full data size, so
        # each mb_elbo is already an estimate of the whole loss; average (not
        # sum) them to get the per-epoch value.
        elbo += mb_elbo / n_minibatches
    losses.append(elbo)

Hi,

Since I haven’t gotten a response, I was wondering whether it was because my code example is too long to read through? I’m not sure whether anyone from Pyro dev is planning on responding or whether I should repost or update this question with a simpler model for the same issue (namely the first question about updating certain parameters and not others).

I would appreciate any feedback, thanks!

re: 1 (non-joint-optimization) are you sure you want to do this? there’s generally no advantage to doing this as you’re basically just wasting computation that’s already been done (e.g. sampling).

re: 2 how do you know the loss isn’t scaling correctly? if you’re using plate correctly it should handle all the scaling for you.

re: 3 i don’t know what’s going on but have you tried using a more restrictive constraint than just positivity? what kind of values is mu taking on when it fails?

Thank you for responding!

  1. I was thinking along these lines: in each minibatch iteration, treat the global parameters as fixed and don’t compute their gradients at all, except once per epoch after all the minibatch local updates. Perhaps by having a different optimizer object for each set of parameters?

  2. You’re right here, sorry

  3. when it fails at least one row in mu is all NaN. I haven’t, but I can try a more restrictive constraint.

re: 1 you can do something like that using e.g. something like per_param_callable described in this tutorial, but in my experience there’s usually no reason to expect that sort of thing to perform better (rather, it’ll just converge more slowly)

re: 3 you might want to check that any learned parameters aren’t getting dangerously close to the edge of their allowed values

1 Like

Hi again,

re: 3, the error comes from one row of the guide’s pyro.param(‘sigma’). On the step right before that row turns entirely to NaN, it is tensor([19.1591, 0.0888, 0.0965, 0.0911, 9.1888, 0.1297]), which doesn’t seem close to the edge of the allowed (positive) values. I tried constraints.interval(0.1, 10) instead of constraints.positive, but it still failed 3 iterations later.

When I increase the number of minibatches, it fails sooner (earlier iterations), and with all data in one svi.step I haven’t experienced this error at all. That’s why I thought it might be tied to issue 1), but I’m really not sure where to go from here. Any more thoughts? Thanks!!

@ari your model has a lot of moving parts. Can you please try to reproduce your issue in a simpler context?

Hi, this is the simplest I can reproduce the error:

import numpy as np
import pandas as pd
import math

import torch
import torch.distributions as tdist
import torch.optim as optim
from torch.distributions import constraints

import pyro
import pyro.optim as poptim
import pyro.distributions as dist
from pyro.ops.indexing import Vindex
from pyro.infer import SVI, JitTrace_ELBO, TraceGraph_ELBO, Trace_ELBO

# Turn on Pyro's extra validation checks for distributions and sample sites.
pyro.enable_validation(True)
# Seed the RNGs so the generated toy data is reproducible.
pyro.set_rng_seed(1)

##################################################################

def model(cur_x):
    """Simplified generative model over the sample matrix ``x``.

    Reads module-level globals ``k_phi``, ``tot_samps``, ``n_xfeatures`` and
    the data ``x``.

    Args:
        cur_x: indices of the rows of ``x`` in this minibatch. Passed as
            ``subsample`` to the ``'samples'`` plate, so Pyro rescales that
            plate's log-probabilities by ``tot_samps / len(cur_x)``.
    """

    # declare plates
    k_phi_plt = pyro.plate('k_phi', k_phi)
    samp_plt = pyro.plate('samples', tot_samps, subsample=cur_x)

    # global
    with k_phi_plt:
        lambda_g = pyro.sample("lambda_g", dist.Dirichlet(torch.ones([n_xfeatures]))) # [k_phi, n_xfeatures]
        omega = pyro.sample("omega", dist.Dirichlet(torch.ones([k_phi]))) # [k_phi, k_phi]

    with samp_plt:
        theta_b = pyro.sample("theta_b", dist.Dirichlet(torch.ones([k_phi]))) # [len(cur_x), k_phi]
        # Rows of theta_b and omega are on the simplex, so every row of mu sums
        # to 1 and Dirichlet(mu) has total concentration 1 (near-sparse phi).
        # NOTE(review): a plausible source of the NaN/invalid-concentration
        # failures with small minibatches -- confirm.
        mu = torch.mm(theta_b, omega)
        phi = pyro.sample("phi", dist.Dirichlet(mu))
        pi_g = torch.mm(phi, lambda_g) # [len(cur_x), n_xfeatures]
        # In the toy data each row of x has its own count total while
        # Multinomial's default total_count is 1, so argument validation is
        # disabled for this site.
        pyro.sample('x', dist.Multinomial(probs=pi_g, validate_args=False), obs=x[cur_x])


##################################################################

def guide(cur_x):
    """Mean-field variational guide for the simplified ``model``.

    Every latent site gets an independent Dirichlet whose concentrations are
    positivity-constrained ``pyro.param``s. ``sigma`` and ``omega_b``
    (``[tot_samps, k_phi]``) are indexed by ``cur_x``, so only the minibatch
    rows enter this step's loss. ``xi_g`` and ``chi`` are used in full.

    Args:
        cur_x: indices of the sample rows in this minibatch (as in ``model``).
    """

    # declare plates
    # Must mirror the plates in ``model`` (names, sizes, subsample indices).
    k_phi_plt = pyro.plate('k_phi', k_phi)
    samp_plt = pyro.plate('samples', tot_samps, subsample=cur_x)

    # global
    xi_g = pyro.param("xi_g", torch.ones([k_phi, x.shape[1]])*0.5, constraint=constraints.positive)
    chi = pyro.param("chi", torch.ones([k_phi, k_phi])*0.5, constraint=constraints.positive)
    with k_phi_plt:
        lambda_g = pyro.sample("lambda_g", dist.Dirichlet(xi_g)) # [k_phi, n_xfeatures]
        omega = pyro.sample("omega", dist.Dirichlet(chi)) # [k_phi, k_phi]

    # local - samples per individual
    # NOTE(review): only a positivity constraint is applied; rows of sigma have
    # been observed to turn NaN with small minibatches.
    sigma = pyro.param("sigma", torch.ones([tot_samps, k_phi])*0.5, constraint=constraints.positive)
    omega_b = pyro.param("omega_b", torch.ones([tot_samps, k_phi])*0.5, constraint=constraints.positive)
    with samp_plt:
        theta_b = pyro.sample("theta_b", dist.Dirichlet(omega_b[cur_x])) # [len(cur_x), k_phi]
        phi = pyro.sample("phi", dist.Dirichlet(sigma[cur_x])) # [len(cur_x), k_phi]


##################################################################

def gen_toy_data():
    """Simulate a toy count matrix with the same structure as the simplified ``model``.

    Uses the module-level sizes ``k_phi``, ``n_inds`` and ``n_xfeatures``. The
    Dirichlet draws use concentration 0.2 (sparser than ``model``'s priors).

    Returns:
        ``(x, tot_samps)``: a ``[tot_samps, n_xfeatures]`` count matrix and the
        number of sample rows (an ``int``).
    """
    # The draw order below is what a fixed RNG seed reproduces; keep it as-is.
    lam_g = dist.Dirichlet(0.2 * torch.ones([k_phi, n_xfeatures])).sample()
    omega = dist.Dirichlet(torch.ones([k_phi, k_phi]) * 0.2).sample()
    # Only the total is needed: n_inds individuals with Poisson(20) samples each.
    n_samples = int(dist.Poisson(20).sample(torch.Size([n_inds])).sum())
    theta = dist.Dirichlet(0.2 * torch.ones([n_samples, k_phi])).sample()
    phi = dist.Dirichlet(torch.mm(theta, omega)).sample()
    probs = torch.mm(phi, lam_g)

    # Row totals vary (Poisson(150)), so draw one Multinomial per row.
    x = torch.zeros([n_samples, n_xfeatures])
    for row, row_probs in enumerate(probs):
        total = int(dist.Poisson(150).sample())
        x[row] = dist.Multinomial(total, row_probs).sample()
    return x, n_samples


####################################################################
##################################################################


# declare model parameters
k_phi, k_b = 6, 3
n_inds = 50
n_yfeatures = 200
n_xfeatures = 200

# generate toy data
x, tot_samps = gen_toy_data()

# set up minibatch partitions: every epoch splits a fresh random permutation of
# the sample indices into n_minibatches disjoint, contiguous slices that
# together cover every sample exactly once.
n_minibatches = 10

# set up SVI
pyro.set_rng_seed(1)
pyro.clear_param_store()
svi = SVI(model, guide, poptim.Adam({"lr": 0.01}), loss=Trace_ELBO())
losses = []

#177, index 245 person 13
for epoch in range(1000):
    print('EPOCH ' + str(epoch), flush=True)
    x_rand_indx = torch.randperm(tot_samps)
    elbo = 0.0
    for mb in range(n_minibatches):
        # Integer-division bounds tile [0, tot_samps) exactly: each sample lands
        # in one minibatch and sizes differ by at most one.
        x_mb_start = (mb * tot_samps) // n_minibatches
        x_mb_end = ((mb + 1) * tot_samps) // n_minibatches
        mb_elbo = svi.step(x_rand_indx[x_mb_start:x_mb_end])
        # The subsampled 'samples' plate rescales its sites to the full data
        # size, so each mb_elbo already estimates the whole loss; average them.
        elbo += mb_elbo / n_minibatches
    losses.append(elbo)

@ari it seems to be much more stable if you decorate both the model and guide with scale statements like

# here i use an arbitrary number but in general you might scale by 
# the total number of datapoints in a mini-batch
# times the size of each observation
@pyro.poutine.scale(scale=1.0e-6)
def model(...):

your ELBO values are pretty big so i guess you were just running into problems with numerical precision

(it’s plausible that using double precision might have helped too but i didn’t try that)

1 Like

Oh! Yes a couple weeks ago I had tried using double precision and it did delay the error for some number of iterations but not nearly as well as the poutine.scale-- thank you!

Replying here to additional questions on GitHub:

What to do if, hundreds of iterations in, part of a Pyro param is NaN?

You could run under pdb so that the error opens up a debugger?

I also often sprinkle suspicious functions with Pyro’s helper warn_if_nan which also adds warnings to the backward hooks; that way you may be able to trace a nan gradient before it hits a param.

from pyro.util import warn_if_nan
...
Hinv = rinverse(H, sym=True)
warn_if_nan(Hinv, 'Hinv')

You can grep around Pyro’s codebase for other example usage, e.g. in pyro/ops/newton.py.

1 Like