NUTS with 'discrete' latent site very slow

I’m trying to use NUTS to marginalize out unknown inputs in a GP model. This works OK when the inputs are continuous. I then tried, very naively, to handle a discrete input space:

def ma_gp_hs(X_obs, n_hyp, X_acc):
    alpha = torch.tensor([1.0])
    sigma = pyro.sample('sigma', dist.Uniform(torch.Tensor([1e-3]), torch.Tensor([1])))
    rho = pyro.sample('rho', dist.Uniform(torch.Tensor([1e-3]), torch.Tensor([100])))
    probs = pyro.sample('probs', dist.Uniform(torch.zeros(n_hyp), torch.ones(n_hyp)))
    probs += 1
    X_hyp = X_acc[probs.long()]
    X_train = torch.cat([X_obs, X_hyp])
    # Covariance
    n = len(X_train)
    cov = Matern52(X_train, X_train, alpha, rho) + torch.eye(n) * (sigma + 1e-5)
    L = torch.potrf(cov, upper=False)
    # Likelihood
    return pyro.sample('f', dist.MultivariateNormal(torch.zeros(n), scale_tril=L))
    
def conditioned_gp_regressor(gp_regressor, X_obs, y_train, n_hyp, X_acc):
    return pyro.poutine.condition(gp_regressor, data={'f': y_train})(X_obs, n_hyp, X_acc)

nuts_kernel = NUTS(conditioned_gp_regressor, step_size=1, adapt_step_size=True)
%time hpost = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500).run(ma_gp_hs, X_obs, y_train, len(X_hyp), X_hyp)

This takes a very, very long time to run (24 hours in, it hasn’t gotten to step 75). Thinking about how HMC/NUTS works, I suspect this is because discretizing ‘probs’ and then using it as indices means that no gradients flow back to that latent site, and so NUTS does a bunch of evaluations trying to find a gradient instead of getting anywhere.
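A minimal check of that suspicion (the names here are just for illustration): casting a float tensor to integer indices detaches it from the autograd graph, so nothing downstream can produce a gradient for probs.

import torch

p = torch.tensor([0.7], requires_grad=True)
idx = (p + 1).long()        # the cast to long detaches idx from the graph
X_acc = torch.tensor([[0.0], [1.0], [2.0]])
out = X_acc[idx].sum()
print(out.requires_grad)    # False: no gradient path back to p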

So instead, I tried to use a RelaxedOneHotCategorical, since that is a continuous relaxation of the categorical distribution:

def ma_gp_hs(X_obs, n_hyp, X_acc):
    alpha = torch.tensor([1.0])
    sigma = pyro.sample('sigma', dist.Uniform(torch.Tensor([1e-3]), torch.Tensor([1])))
    rho = pyro.sample('rho', dist.Uniform(torch.Tensor([1e-3]), torch.Tensor([100])))
    prior = torch.stack([torch.ones(len(X_acc)) for _ in range(n_hyp)])
    probs = pyro.sample('probs', dist.RelaxedOneHotCategorical(0.001, probs=prior))
    X_hyp = probs @ X_acc
    X_train = torch.cat([X_obs, X_hyp])
    # Covariance
    n = len(X_train)
    cov = Matern52(X_train, X_train, alpha, rho) + torch.eye(n) * (sigma + 1e-5)
    L = torch.potrf(cov, upper=False)
    # Likelihood
    return pyro.sample('f', dist.MultivariateNormal(torch.zeros(n), scale_tril=L))
    
def conditioned_gp_regressor(gp_regressor, X_obs, y_train, n_hyp, X_acc):
    return pyro.poutine.condition(gp_regressor, data={'f': y_train})(X_obs, n_hyp, X_acc)

nuts_kernel = NUTS(conditioned_gp_regressor, step_size=1, adapt_step_size=True)
%time hpost = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500).run(ma_gp_hs, X_obs, y_train, len(X_hyp), X_hyp)

But this results in the following error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/anaconda/envs/pyro/lib/python3.6/site-packages/pyro/poutine/trace_struct.py in log_prob_sum(self, site_filter)
    228                 try:
--> 229                     site_log_p = site["log_prob_sum"]
    230                 except KeyError:

KeyError: 'log_prob_sum'

During handling of the above exception, another exception occurred:

AttributeError                            Traceback (most recent call last)
<timed exec> in <module>()

~/anaconda/envs/pyro/lib/python3.6/site-packages/pyro/infer/abstract_infer.py in run(self, *args, **kwargs)
     81         """
     82         self._init()
---> 83         for tr, logit in poutine.block(self._traces)(*args, **kwargs):
     84             self.exec_traces.append(tr)
     85             self.log_weights.append(logit)

~/anaconda/envs/pyro/lib/python3.6/site-packages/pyro/infer/mcmc/mcmc.py in _traces(self, *args, **kwargs)
     30 
     31     def _traces(self, *args, **kwargs):
---> 32         self.kernel.setup(*args, **kwargs)
     33         trace = self.kernel.initial_trace()
     34         self.logger.info("Starting MCMC using kernel - {} ...".format(self.kernel.__class__.__name__))

~/anaconda/envs/pyro/lib/python3.6/site-packages/pyro/infer/mcmc/hmc.py in setup(self, *args, **kwargs)
    173             if node["fn"].support is not constraints.real and self._automatic_transform_enabled:
    174                 self.transforms[name] = biject_to(node["fn"].support).inv
--> 175         self._validate_trace(trace)
    176 
    177         if self.adapt_step_size:

~/anaconda/envs/pyro/lib/python3.6/site-packages/pyro/infer/mcmc/hmc.py in _validate_trace(self, trace)
    151 
    152     def _validate_trace(self, trace):
--> 153         trace_log_prob_sum = trace.log_prob_sum()
    154         if torch_isnan(trace_log_prob_sum) or torch_isinf(trace_log_prob_sum):
    155             raise ValueError("Model specification incorrect - trace log pdf is NaN or Inf.")

~/anaconda/envs/pyro/lib/python3.6/site-packages/pyro/poutine/trace_struct.py in log_prob_sum(self, site_filter)
    230                 except KeyError:
    231                     args, kwargs = site["args"], site["kwargs"]
--> 232                     site_log_p = site["fn"].log_prob(site["value"], *args, **kwargs)
    233                     site_log_p = scale_tensor(site_log_p, site["scale"]).sum()
    234                     site["log_prob_sum"] = site_log_p

~/anaconda/envs/pyro/lib/python3.6/site-packages/torch/distributions/transformed_distribution.py in log_prob(self, value)
     84             y = x
     85 
---> 86         log_prob += _sum_rightmost(self.base_dist.log_prob(y),
     87                                    event_dim - len(self.base_dist.event_shape))
     88         return log_prob

~/anaconda/envs/pyro/lib/python3.6/site-packages/torch/distributions/relaxed_categorical.py in log_prob(self, value)
     66             self._validate_sample(value)
     67         logits, value = broadcast_all(self.logits, value)
---> 68         log_scale = (self.temperature.new(self.temperature.shape).fill_(K).lgamma() -
     69                      self.temperature.log().mul(-(K - 1)))
     70         score = logits - value.mul(self.temperature)

AttributeError: 'float' object has no attribute 'new'

I don’t think RelaxedCategorical will work with HMC, but to fix your immediate error, I believe you can make temperature a tensor:

- dist.RelaxedOneHotCategorical(0.001, probs=prior)
+ dist.RelaxedOneHotCategorical(torch.tensor(0.001), probs=prior)

Now I get ValueError: Model specification incorrect - trace log pdf is NaN or Inf. instead, which is at least a more sensible error.
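My suspicion (unverified) is that the NaN/Inf comes from the tiny temperature: at 0.001 the relaxed samples are numerically one-hot, so most components underflow to exactly zero and log_prob hits log(0). A quick way to see this, using only the stock torch.distributions API:

import torch
from torch.distributions import RelaxedOneHotCategorical

d = RelaxedOneHotCategorical(torch.tensor(0.001), probs=torch.ones(4))
x = d.sample()
print(x)                 # numerically one-hot: all but one entry underflow to 0
print(d.log_prob(x))     # typically inf or nan because of the log(0) terms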

> I suspect this is because discretizing ‘probs’ and then using it as indices means that no gradients flow back to that latent site.

As you rightly deduced, this will certainly not work, because you would be trying to differentiate the joint log probability w.r.t. some arbitrarily sampled probs, which isn’t defined once you discretize like this. What surprises me is that this even ran without any exceptions.

Update: The following code from #1121, which uses a Dirichlet latent instead, runs fine on a toy dataset, but it is apparently slow with @kkyang’s original data.

@kkyang - What is the size of your actual data?

# Runs fine on the toy dataset here;
# but slow on the original.

import torch

import pyro
import pyro.distributions as dist

from pyro.infer.mcmc import MCMC, HMC, NUTS
import logging

logging.basicConfig(format='%(levelname)s %(message)s')
logger = logging.getLogger('pyro')
logger.setLevel(logging.INFO)

X1 = torch.tensor([[0.5, -0.5, 0.1],
                   [0.1, 0.6, -0.4]])
X2 = torch.tensor([[0.5, -0.3, 0.6],
                   [-0.1, -0.6, -0.2]])
y12 = torch.tensor([0.4, -0.1, -0.5, 0.2])


def Matern52(X1, X2, alpha, rho):
    # Squared pairwise distances via the ||a||^2 + ||b||^2 - 2*a.b expansion
    L2 = torch.sum(X1 ** 2, dim=1).unsqueeze(1) + torch.sum(X2 ** 2, dim=1).unsqueeze(0) - 2 * torch.mm(X1, X2.t())
    L2 = torch.clamp(L2, min=0.0)  # round-off can make entries slightly negative
    L = torch.sqrt(L2 + 1e-12)  # small jitter keeps sqrt differentiable at zero distance
    sqrt5_r = 5 ** 0.5 * L
    # Matern 5/2 kernel: alpha * (1 + sqrt(5)*r/rho + 5*r^2/(3*rho^2)) * exp(-sqrt(5)*r/rho)
    return alpha * (1 + sqrt5_r / rho + (5 / 3) * L2 / rho ** 2) * torch.exp(-sqrt5_r / rho)


def ma_gp_hs(X_obs, n_hyp, X_acc):
    alpha = torch.tensor([1.0])
    sigma = pyro.sample('sigma', dist.Uniform(torch.Tensor([1e-2]), torch.Tensor([1])))
    rho = pyro.sample('rho', dist.Uniform(torch.Tensor([1e-3]), torch.Tensor([100])))
    prior = torch.stack([torch.ones(len(X_acc)) * 0.1 for _ in range(n_hyp)])
    probs = pyro.sample('probs', dist.Dirichlet(prior))  # each row lies on the simplex
    X_hyp = torch.mm(probs, X_acc)  # convex combination of the candidate inputs
    X_train = torch.cat([X_obs, X_hyp])
    # Covariance
    n = len(X_train)
    cov = Matern52(X_train, X_train, alpha, rho) + torch.eye(n) * (sigma + 1e-3)
    L = torch.potrf(cov, upper=False)
    # Likelihood
    return pyro.sample('f', dist.MultivariateNormal(torch.zeros(n), scale_tril=L))


def conditioned_gp_regressor(gp_regressor, X_obs, y_train, n_hyp, X_acc):
    return pyro.poutine.condition(gp_regressor, data={'f': y_train})(X_obs, n_hyp, X_acc)


nuts_kernel = NUTS(conditioned_gp_regressor, step_size=1, adapt_step_size=True)
hpost = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500).run(ma_gp_hs, X1, y12, len(X2), X2)
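Once that finishes, you should be able to pull posterior samples for the latent inputs out of the returned posterior object; a minimal sketch, assuming the EmpiricalMarginal API from this version of Pyro:

from pyro.infer import EmpiricalMarginal

# Empirical posterior over the 'probs' site, built from the MCMC traces
marginal = EmpiricalMarginal(hpost, sites='probs')
print(marginal.mean)      # posterior mean of the simplex weights
print(marginal.sample())  # a single posterior draw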

X_obs is 20 by 9000, X_acc is 135 by 9000, and y_train is 135. Eventually I’d like to run it with as many as 20,000 rows in X_acc.
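For scale, here is a rough back-of-envelope (my own estimate, based only on those shapes) for a single log-density evaluation, which NUTS may repeat up to 2**10 = 1024 times per sample at its default maximum tree depth:

# Shapes quoted above; n_hyp is inferred so that len(y_train) == n_obs + n_hyp
n_obs, n_acc, d = 20, 135, 9000
n_hyp = 115
n = n_obs + n_hyp
pairwise = n * n * d        # kernel distance matrix: ~1.6e8 multiply-adds
mixing = n_hyp * n_acc * d  # probs @ X_acc: ~1.4e8 multiply-adds
cholesky = n ** 3           # ~2.5e6 flops, negligible by comparison
print(pairwise, mixing, cholesky)

On top of the per-step cost, the 'probs' site alone is an n_hyp-by-135 simplex-valued latent (over 15,000 dimensions), and it would grow to n_hyp-by-20,000 with the larger candidate set, which I would not expect NUTS to mix over in any reasonable time.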