Hi there,
I am implementing a variational approximation algorithm in PyTorch, and I would like to use Pyro/NumPyro first to test aspects of the algorithm and, second, to extract quantities such as the model log density.
Firstly, I wrote a toy logistic linear mixed model in NumPyro as
```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def linear_mixed_model(y, X, Z):
    zeta = numpyro.sample('zeta', dist.Normal(0, 10))
    scale = numpyro.deterministic('scale', jnp.exp(-zeta))
    with numpyro.plate('b_plate', Z.shape[1]):
        b = numpyro.sample('b', dist.Normal(0, scale))
    with numpyro.plate('beta_plate', X.shape[1]):
        beta = numpyro.sample('beta', dist.Normal(0, 10))
    eta = numpyro.deterministic('eta', X @ beta.reshape(-1, 1) + Z @ b.reshape(-1, 1))
    with numpyro.plate('y_plate', y.shape[0]):
        numpyro.sample('y', dist.Bernoulli(logits=eta.flatten()), obs=y)
```
The corresponding function that I have written to evaluate the log density of the model using `torch.distributions` is
```python
import torch
from torch.distributions import Normal, Bernoulli
from torch.distributions.transforms import ExpTransform
from torch.distributions.transformed_distribution import TransformedDistribution

def logp(params, data, logp=0):
    # unpack model parameters and data
    beta, zeta, b = params['beta'], params['zeta'], params['b']
    y, X, Z = data["y"], data["X"], data["Z"]
    # transform
    scale = torch.exp(-zeta)
    zeta_dist = Normal(0, 10)
    scale_dist = TransformedDistribution(zeta_dist, [ExpTransform()])
    # priors
    logp += scale_dist.log_prob(scale).sum()
    logp += Normal(0, 10).log_prob(beta).sum()
    logp += Normal(0, scale).log_prob(b).sum()
    # likelihood
    linear_predictor = X @ beta.reshape(-1, 1) + Z @ b.reshape(-1, 1)
    logp += Bernoulli(logits=linear_predictor.flatten()).log_prob(y).sum()
    return logp
```
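For completeness, here is a self-contained sketch of how I call this function; the data shapes and parameter values are made up for illustration:

```python
# Self-contained usage sketch for the logp function above.
# Data shapes and parameter values here are arbitrary illustrations.
import torch
from torch.distributions import Normal, Bernoulli
from torch.distributions.transforms import ExpTransform
from torch.distributions.transformed_distribution import TransformedDistribution

def logp(params, data, logp=0):
    beta, zeta, b = params['beta'], params['zeta'], params['b']
    y, X, Z = data["y"], data["X"], data["Z"]
    scale = torch.exp(-zeta)
    scale_dist = TransformedDistribution(Normal(0, 10), [ExpTransform()])
    logp += scale_dist.log_prob(scale).sum()
    logp += Normal(0, 10).log_prob(beta).sum()
    logp += Normal(0, scale).log_prob(b).sum()
    linear_predictor = X @ beta.reshape(-1, 1) + Z @ b.reshape(-1, 1)
    logp += Bernoulli(logits=linear_predictor.flatten()).log_prob(y).sum()
    return logp

torch.manual_seed(0)
n, p, q = 40, 2, 3
data = {
    'y': torch.bernoulli(torch.full((n,), 0.5)),  # binary responses
    'X': torch.randn(n, p),                       # fixed-effects design
    'Z': torch.randn(n, q),                       # random-effects design
}
params = {
    'beta': torch.randn(p, requires_grad=True),
    'zeta': torch.tensor(0.1, requires_grad=True),
    'b': torch.randn(q, requires_grad=True),
}
val = logp(params, data)
val.backward()  # autograd gives gradients of the log density for free
print(float(val), params['beta'].grad)
```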
When running MCMC on the NumPyro model, I return the `potential_energy` value using the `extra_fields` argument. The negative of the `potential_energy` value should equal the output of the `logp` function above given the same parameter values as input, but this is not the case. I am guessing that there is an error in how I am dealing with the parameter transformations using `torch.distributions`. Does anyone have any pointers?
Secondly, I love using the `numpyro.infer.util.log_density` function and was wondering whether there is a corresponding function that returns the gradients of the log density with respect to the model parameters, and whether Pyro has related functionality.
Thanks!