How to implement one-vs-each approximation for big softmax

Hi, I'm new to Pyro and struggling to figure out whether Pyro is appropriate for replicating the paper 'SHOPPER: A Probabilistic Model of Consumer Choice with Substitutes and Complements' by Francisco J. R. Ruiz, Susan Athey, and David M. Blei (https://arxiv.org/pdf/1711.03560.pdf).

Right now I'm stuck on one specific question about inference in the model: the one-vs-each bound used to approximate a big softmax (page 31).

where eta is defined on page 10. (Sorry, new users are only allowed to post one figure, so I combined them into a single image, with the eta definition on the right.)

Regardless of the broader context, I believe the key question is: how do we implement this bound in Pyro?
(The one-vs-each bound is introduced in the paper 'One-vs-each approximation to softmax for scalable estimation of probabilities', https://arxiv.org/pdf/1609.07410.pdf.)
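
To pin down the math, my understanding of the bound is: for logits eta and an observed class y, log softmax(eta)[y] >= sum over k != y of log sigmoid(eta_y - eta_k). Here is a quick numeric sanity check in plain PyTorch (my own sketch, not code from either paper):

import torch

K, y = 5, 2
eta = torch.randn(K)

exact = torch.log_softmax(eta, dim=0)[y]               # exact log p(y) under the softmax
pairwise = eta[y] - torch.cat([eta[:y], eta[y + 1:]])  # eta_y - eta_k for all k != y
bound = torch.nn.functional.logsigmoid(pairwise).sum() # one-vs-each lower bound

print(exact.item(), bound.item())
assert bound <= exact  # the bound never exceeds the exact log-probability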

Usually we would do:

pyro.sample("eta", dist.Categorical(logits=logits), obs=data)

but with the one-vs-each bound, we might instead do something like (I'm not sure):

pyro.sample("eta", dist.Binomial(logits=logits), obs=data)

So, how should we scale the negative-sampled loss?
Should we use pyro.plate()? Or do we need to customize the ELBO?

I hope I've made myself understood. Any help is appreciated~

Hi Andrew, did you end up figuring this out?

The main complication in doing this arguably lies in the relevant PyTorch tensor-indexing gymnastics. Here's an example demonstrating that the one-vs-each bound is in fact a lower bound for a single categorical data point; generalizing it should be pretty straightforward (if tedious).

import numpy as np
import torch

import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO

def onehot_model(onehot_data, onehot_logits):
    # exact categorical likelihood (the full softmax)
    pyro.sample('data', dist.OneHotCategorical(logits=onehot_logits), obs=onehot_data)

def bernoulli_model(data, oneversus_logits):
    # one-vs-each bound: K - 1 Bernoulli terms with logits eta_y - eta_k, all observed as 1
    pyro.sample('data', dist.Bernoulli(logits=oneversus_logits), obs=data)

def subsample_model(data, oneversus_logits):
    # same bound, but with the K - 1 Bernoulli terms subsampled; pyro.plate
    # rescales the subsampled log-prob so the estimator is unbiased
    with pyro.plate("data_plate", oneversus_logits.size(0), subsample_size=1) as idx:
        pyro.sample('data', dist.Bernoulli(logits=oneversus_logits[idx]), obs=data[idx])

def guide(*args):
    # no latent variables here, so the guide is empty and the ELBO is just the log-likelihood
    pass

K = 4
onehot_logits = torch.randn(K)
data = torch.tensor([1])  # a single observed class index
onehot_data = torch.nn.functional.one_hot(data, num_classes=K).float()  # one-hot obs for OneHotCategorical

loss_fn = Trace_ELBO().loss
onehot_elbo = -loss_fn(onehot_model, guide, onehot_data, onehot_logits)
print("vanilla onehot elbo", onehot_elbo)

# one-vs-each logits: eta_y - eta_k for every k != y (drop the k == y entry)
oneversus_logits = onehot_logits[data] - onehot_logits
oneversus_logits = torch.cat([oneversus_logits[:data], oneversus_logits[data + 1:]])

# here we demonstrate the one-vs-each lower bound without subsampling:
# all K - 1 Bernoulli terms are observed as ones
oneversus_data = torch.ones(K - 1)
bernoulli_elbo = -loss_fn(bernoulli_model, guide, oneversus_data, oneversus_logits)
print("bernoulli_elbo", bernoulli_elbo)

# here we demonstrate the one-vs-each lower bound with subsampling; averaging
# over several draws shows the estimator matches the full bound in expectation
num_samples = 20
subsample_elbo = np.mean([-loss_fn(subsample_model, guide, oneversus_data, oneversus_logits)
                          for _ in range(num_samples)])
print("subsample_elbo", subsample_elbo)

The output should look something like:

vanilla onehot elbo -1.049773097038269
bernoulli_elbo -1.4169166088104248
subsample_elbo -1.4012184381484984
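
To answer the original question about scaling: you shouldn't need a custom ELBO. pyro.plate with subsample_size already rescales the subsampled Bernoulli terms by (K - 1) / subsample_size, so the negative-sampled bound can be dropped straight into a regular SVI loop. Here's a rough sketch of what that might look like with a latent eta (the names, the Normal prior/guide, and the subsample size are my own placeholders, not the SHOPPER model):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

K = 1000  # number of classes / items

def model(y):
    # placeholder prior over the per-class logits eta; SHOPPER would instead
    # build eta from item embeddings, prices, etc.
    eta = pyro.sample("eta", dist.Normal(torch.zeros(K), 1.0).to_event(1))
    pairwise = eta[y] - torch.cat([eta[:y], eta[y + 1:]])  # eta_y - eta_k for k != y
    with pyro.plate("neg_samples", K - 1, subsample_size=32) as idx:
        # one-vs-each term: observe 1 under Bernoulli(sigmoid(eta_y - eta_k));
        # the plate rescales the subsampled log-prob by (K - 1) / 32
        pyro.sample("ovr", dist.Bernoulli(logits=pairwise[idx]),
                    obs=torch.ones(len(idx)))

def guide(y):
    loc = pyro.param("loc", torch.zeros(K))
    scale = pyro.param("scale", 0.1 * torch.ones(K),
                       constraint=dist.constraints.positive)
    pyro.sample("eta", dist.Normal(loc, scale).to_event(1))

svi = SVI(model, guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())
for step in range(100):
    svi.step(torch.tensor(3))  # a single observed class, just for illustration

Since the subsampled term equals the full one-vs-each sum in expectation, this still gives a valid stochastic lower bound on the original categorical log-likelihood.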