Hi, I’m new to Pyro and struggling to figure out whether Pyro is appropriate for replicating the paper ‘SHOPPER: A Probabilistic Model of Consumer Choice with Substitutes and Complements’ by Francisco J. R. Ruiz, Susan Athey, and David M. Blei (https://arxiv.org/pdf/1711.03560.pdf).
Right now I’m stuck on one specific question about inference in the model: the one-vs-each bound used to approximate a large softmax (page 31),
where eta is defined on page 10. (Sorry, a new user is only allowed to post one figure, so I combined them into a single image.)
Within this context, I believe the important question is: how do I do this in Pyro?
(The one-vs-each bound is introduced in the paper ‘One-vs-each approximation to softmax for scalable estimation of probabilities’, https://arxiv.org/pdf/1609.07410.pdf.)
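For reference, the bound in question (from the Titsias paper) lower-bounds the log-softmax probability of the observed class y by a sum of logistic log-likelihoods over the other classes:

log softmax_y(eta) = eta_y - log(sum_k exp(eta_k)) >= sum_{k != y} log sigmoid(eta_y - eta_k).

Each term on the right is the log-likelihood of a Bernoulli with logit eta_y - eta_k and observation 1, which is how the example below encodes it.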
The main complications in doing this arguably lie in the relevant PyTorch tensor indexing gymnastics. Here’s an example that demonstrates that the lower bound is in fact a lower bound for a single categorical data point. Generalizing this should be pretty straightforward (if tedious).
import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Trace_ELBO
def onehot_model(onehot_data, onehot_logits):
    # exact categorical likelihood of the observed one-hot data point
    pyro.sample('data', dist.OneHotCategorical(logits=onehot_logits), obs=onehot_data)

def bernoulli_model(data, oneversus_logits):
    # one-vs-each bound: a product of Bernoulli likelihoods, one per non-observed class
    pyro.sample('data', dist.Bernoulli(logits=oneversus_logits), obs=data)

def subsample_model(data, oneversus_logits):
    # same bound, but subsampling a single Bernoulli term; the plate rescales it by K - 1
    with pyro.plate("data_plate", oneversus_logits.size(0), subsample_size=1) as idx:
        pyro.sample('data', dist.Bernoulli(logits=oneversus_logits[idx]), obs=data[idx])

def guide(*args):
    # trivial guide: there are no latent variables, everything is observed
    pass
K = 4
onehot_logits = torch.randn(K)
data = torch.tensor([1])  # single observed data point: class index y = 1
onehot_data = torch.nn.functional.one_hot(data, num_classes=K)
loss_fn = Trace_ELBO().loss
onehot_elbo = -loss_fn(onehot_model, guide, onehot_data, onehot_logits)
print("vanilla onehot elbo", onehot_elbo)
y = int(data)  # observed class index
# logits eta_y - eta_k for all k, then drop the k == y entry, leaving K - 1 one-vs-each logits
oneversus_logits = onehot_logits[y] - onehot_logits
oneversus_logits = torch.cat([oneversus_logits[:y], oneversus_logits[y + 1:]])
# here we demonstrate the one-vs-each lower bound without subsampling
oneversus_data = torch.ones(K - 1)
bernoulli_elbo = -loss_fn(bernoulli_model, guide, oneversus_data, oneversus_logits)
print("bernoulli_elbo", bernoulli_elbo)
# here we demonstrate the one-vs-each lower bound with subsampling;
# the plate rescales each subsampled term by (K - 1), so every single-term loss is an
# unbiased estimate of the full bound, and we average over several draws to reduce variance
num_samples = 20
subsample_elbo = np.mean([-loss_fn(subsample_model, guide, oneversus_data, oneversus_logits) for _ in range(num_samples)])
print("subsample_elbo", subsample_elbo)