Comparing log probability between Pyro and torch.distributions implementation

Hi there,

I am trying to implement a variational approximation algorithm using PyTorch. I am using Pyro/NumPyro, first, to help test aspects of the algorithm and, second, possibly to extract quantities such as the model log density.

Firstly, I wrote a toy logistic linear mixed model in NumPyro:

```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def linear_mixed_model(y, X, Z):
    zeta = numpyro.sample('zeta', dist.Normal(0, 10))
    scale = numpyro.deterministic('scale', jnp.exp(-zeta))
    with numpyro.plate('b_plate', Z.shape[1]):
        b = numpyro.sample('b', dist.Normal(0, scale))
    with numpyro.plate('beta_plate', X.shape[1]):
        beta = numpyro.sample('beta', dist.Normal(0, 10))
    eta = numpyro.deterministic('eta', X @ beta.reshape(-1, 1) + Z @ b.reshape(-1, 1))
    with numpyro.plate('y_plate', y.shape[0]):
        numpyro.sample('y', dist.Bernoulli(logits=eta.flatten()), obs=y)
```

The corresponding function I wrote to evaluate the model's log density with torch.distributions is:

```python
import torch
from torch.distributions import Normal, Bernoulli
from torch.distributions.transforms import ExpTransform
from torch.distributions.transformed_distribution import TransformedDistribution

def logp(params, data):
    # unpack model parameters and data (y must be a float tensor of 0s and 1s)
    beta, zeta, b = params['beta'], params['zeta'], params['b']
    y, X, Z = data["y"], data["X"], data["Z"]
    # transform
    scale = torch.exp(-zeta)
    zeta_dist = Normal(0., 10.)
    # distribution of exp(Normal(0, 10)); by symmetry this is also the distribution of exp(-zeta)
    scale_dist = TransformedDistribution(zeta_dist, [ExpTransform()])
    # priors
    log_prob = scale_dist.log_prob(scale).sum()
    log_prob += Normal(0., 10.).log_prob(beta).sum()
    log_prob += Normal(0., scale).log_prob(b).sum()
    # likelihood
    linear_predictor = X @ beta.reshape(-1, 1) + Z @ b.reshape(-1, 1)
    log_prob += Bernoulli(logits=linear_predictor.flatten()).log_prob(y).sum()

    return log_prob
```

When running MCMC on the NumPyro model, I retrieve the potential_energy values via the extra_fields argument. The negative of the potential energy should equal the output of the logp function above when both are evaluated at the same parameter values, but this is not the case. I suspect there is an error in how I handle the parameter transformation with torch.distributions. Does anyone have any pointers?

Secondly, I love using the numpyro.infer.util.log_density function, and I was wondering whether there is a corresponding function that gives the gradients of the log density with respect to the model parameters, and whether Pyro has related functionality.


haven’t done a line-by-line comparison, but e.g. the numpyro model is defined w.r.t. zeta:

```python
zeta = numpyro.sample('zeta', dist.Normal(0, 10))
```

but your torch model is effectively defined w.r.t. scale = torch.exp(-zeta), and as such the log densities differ by a log jacobian term: since log|d scale / d zeta| = -zeta, the log density of scale exceeds the log density of zeta by zeta. there may be other such differences, but if you want an apples-to-apples comparison you need to make sure you’re comparing densities defined w.r.t. the same variables.
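A minimal PyTorch sketch of that suggestion, assuming params and data are dicts of float tensors as in the question: logp_wrt_zeta (a hypothetical name) places the Normal(0, 10) prior on zeta itself, as the NumPyro model does, while log_prior_scale reproduces the original prior term on scale. The gap between the two prior terms is the log Jacobian term described above.

```python
import torch
from torch.distributions import Bernoulli, Normal, TransformedDistribution
from torch.distributions.transforms import ExpTransform


def logp_wrt_zeta(params, data):
    """Log joint density with the prior placed on zeta, mirroring the NumPyro model."""
    beta, zeta, b = params["beta"], params["zeta"], params["b"]
    y, X, Z = data["y"], data["X"], data["Z"]  # y: float tensor of 0s and 1s
    scale = torch.exp(-zeta)
    # priors: score zeta directly (not scale), so no Jacobian term appears
    log_prob = Normal(0.0, 10.0).log_prob(zeta).sum()
    log_prob = log_prob + Normal(0.0, 10.0).log_prob(beta).sum()
    log_prob = log_prob + Normal(0.0, scale).log_prob(b).sum()
    # likelihood
    eta = X @ beta.reshape(-1, 1) + Z @ b.reshape(-1, 1)
    log_prob = log_prob + Bernoulli(logits=eta.flatten()).log_prob(y).sum()
    return log_prob


def log_prior_scale(zeta):
    """Prior term from the original logp: the density of scale = exp(-zeta)."""
    scale = torch.exp(-zeta)
    scale_dist = TransformedDistribution(Normal(0.0, 10.0), [ExpTransform()])
    return scale_dist.log_prob(scale)


zeta = torch.tensor(0.3)
# Change of variables: log p_scale(s) = log p_zeta(zeta) - log|ds/dzeta| = log p_zeta(zeta) + zeta
jacobian_gap = log_prior_scale(zeta) - Normal(0.0, 10.0).log_prob(zeta)
print(jacobian_gap.item())  # approximately 0.3, i.e. zeta
```

Subtracting zeta from the original logp output should therefore have the same effect as switching to logp_wrt_zeta.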

unfortunately i don’t think pyro currently has an equivalent of numpyro.infer.util.log_density. contributions welcome!


Awesome, thanks for the response, Martin. I’m busy in January, but I want to set aside some time in February to try to contribute this.