I’m trying to use NUTS to marginalize out unknown inputs in a GP model. This works ok when the input is continuous. I tried to do it very naively for a discrete input space:
def ma_gp_hs(X_obs, n_hyp, X_acc):
alpha = torch.tensor([1.0])
sigma = pyro.sample('sigma', dist.Uniform(torch.Tensor([1e-3]), torch.Tensor([1])))
rho = pyro.sample('rho', dist.Uniform(torch.Tensor([1e-3]), torch.Tensor([100])))
probs = pyro.sample('probs', dist.Uniform(torch.zeros(n_hyp), torch.ones(n_hyp)))
probs += 1
X_hyp = X_acc[probs.long()]
X_train = torch.cat([X_obs, X_hyp])
# Covariance
n = len(X_train)
cov = Matern52(X_train, X_train, alpha, rho) + torch.eye(n) * (sigma + 1e-5)
L = torch.potrf(cov, upper=False)
# Likelihood
return pyro.sample('f', dist.MultivariateNormal(torch.zeros(n), scale_tril=L))
def conditioned_gp_regressor(gp_regressor, X_obs, y_train, n_hyp, X_acc):
return pyro.poutine.condition(gp_regressor, data={'f': y_train})(X_obs, n_hyp, X_acc)
nuts_kernel = NUTS(conditioned_gp_regressor, step_size=1, adapt_step_size=True)
%time hpost = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500).run(ma_gp_hs, X_obs, y_train, len(X_hyp), X_hyp)
This takes a very very long time to run (24 hours in, it hasn’t gotten to step 75). Thinking about how MHC/NUTS works, I think this is because discretizing ‘probs’ and then using it as an indices means that there are no gradients flowing back to that latent site, and so NUTS does a bunch of evaluations trying to get a gradient instead of getting anywhere.
So instead, I tried to use a RelaxedCategorical, since this is a continuous distribution:
def ma_gp_hs(X_obs, n_hyp, X_acc):
alpha = torch.tensor([1.0])
sigma = pyro.sample('sigma', dist.Uniform(torch.Tensor([1e-3]), torch.Tensor([1])))
rho = pyro.sample('rho', dist.Uniform(torch.Tensor([1e-3]), torch.Tensor([100])))
prior = torch.stack([torch.ones(len(X_acc)) for _ in range(n_hyp)])
probs = pyro.sample('probs', dist.RelaxedOneHotCategorical(0.001, probs=prior))
X_hyp = probs @ X_acc
X_train = torch.cat([X_obs, X_hyp])
# Covariance
n = len(X_train)
cov = Matern52(X_train, X_train, alpha, rho) + torch.eye(n) * (sigma + 1e-5)
L = torch.potrf(cov, upper=False)
# Likelihood
return pyro.sample('f', dist.MultivariateNormal(torch.zeros(n), scale_tril=L))
def conditioned_gp_regressor(gp_regressor, X_obs, y_train, n_hyp, X_acc):
return pyro.poutine.condition(gp_regressor, data={'f': y_train})(X_obs, n_hyp, X_acc)
nuts_kernel = NUTS(conditioned_gp_regressor, step_size=1, adapt_step_size=True)
%time hpost = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500).run(ma_gp_hs, X_obs, y_train, len(X_hyp), X_hyp)
But this results in the following error:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
~/anaconda/envs/pyro/lib/python3.6/site-packages/pyro/poutine/trace_struct.py in log_prob_sum(self, site_filter)
228 try:
--> 229 site_log_p = site["log_prob_sum"]
230 except KeyError:
KeyError: 'log_prob_sum'
During handling of the above exception, another exception occurred:
AttributeError Traceback (most recent call last)
<timed exec> in <module>()
~/anaconda/envs/pyro/lib/python3.6/site-packages/pyro/infer/abstract_infer.py in run(self, *args, **kwargs)
81 """
82 self._init()
---> 83 for tr, logit in poutine.block(self._traces)(*args, **kwargs):
84 self.exec_traces.append(tr)
85 self.log_weights.append(logit)
~/anaconda/envs/pyro/lib/python3.6/site-packages/pyro/infer/mcmc/mcmc.py in _traces(self, *args, **kwargs)
30
31 def _traces(self, *args, **kwargs):
---> 32 self.kernel.setup(*args, **kwargs)
33 trace = self.kernel.initial_trace()
34 self.logger.info("Starting MCMC using kernel - {} ...".format(self.kernel.__class__.__name__))
~/anaconda/envs/pyro/lib/python3.6/site-packages/pyro/infer/mcmc/hmc.py in setup(self, *args, **kwargs)
173 if node["fn"].support is not constraints.real and self._automatic_transform_enabled:
174 self.transforms[name] = biject_to(node["fn"].support).inv
--> 175 self._validate_trace(trace)
176
177 if self.adapt_step_size:
~/anaconda/envs/pyro/lib/python3.6/site-packages/pyro/infer/mcmc/hmc.py in _validate_trace(self, trace)
151
152 def _validate_trace(self, trace):
--> 153 trace_log_prob_sum = trace.log_prob_sum()
154 if torch_isnan(trace_log_prob_sum) or torch_isinf(trace_log_prob_sum):
155 raise ValueError("Model specification incorrect - trace log pdf is NaN or Inf.")
~/anaconda/envs/pyro/lib/python3.6/site-packages/pyro/poutine/trace_struct.py in log_prob_sum(self, site_filter)
230 except KeyError:
231 args, kwargs = site["args"], site["kwargs"]
--> 232 site_log_p = site["fn"].log_prob(site["value"], *args, **kwargs)
233 site_log_p = scale_tensor(site_log_p, site["scale"]).sum()
234 site["log_prob_sum"] = site_log_p
~/anaconda/envs/pyro/lib/python3.6/site-packages/torch/distributions/transformed_distribution.py in log_prob(self, value)
84 y = x
85
---> 86 log_prob += _sum_rightmost(self.base_dist.log_prob(y),
87 event_dim - len(self.base_dist.event_shape))
88 return log_prob
~/anaconda/envs/pyro/lib/python3.6/site-packages/torch/distributions/relaxed_categorical.py in log_prob(self, value)
66 self._validate_sample(value)
67 logits, value = broadcast_all(self.logits, value)
---> 68 log_scale = (self.temperature.new(self.temperature.shape).fill_(K).lgamma() -
69 self.temperature.log().mul(-(K - 1)))
70 score = logits - value.mul(self.temperature)
AttributeError: 'float' object has no attribute 'new'