Does anyone know why PyTorch's and NumPyro's Categorical distributions give different gradients with respect to `probs`? (For comparison, I have also included taking the log of the probs directly.)
import torch
import torch.distributions as torchdist
import jax.numpy as jnp
import numpyro.distributions as numdist
from jax import grad
# PyTorch: gradient of the Categorical log-prob w.r.t. probs.
# NOTE(review): the printed grad (-0.3333, 0.5000) is not simply 1/x —
# presumably torch renormalizes `probs` internally, so the gradient also
# flows through the normalization term; confirm against the
# torch.distributions.Categorical documentation.
print("TORCH Categorical")
support = torch.arange(2)
x = torch.tensor([0.6, 0.4], requires_grad=True)
total_logp = torchdist.Categorical(probs=x).log_prob(support).sum()
total_logp.backward()
print(f"x.grad = {x.grad}")
# Baseline in PyTorch: the gradient of sum(log(x)) is elementwise 1/x.
print("TORCH log")
x = torch.tensor([0.6, 0.4], requires_grad=True)
torch.log(x).sum().backward()
print(f"x.grad = {x.grad}")
# NumPyro: same Categorical log-prob gradient, via jax.grad.
# NOTE(review): the printed grad (1.6667, 2.5) equals 1/x — presumably
# NumPyro's Categorical does not renormalize `probs` inside log_prob;
# verify against the numpyro.distributions documentation.
print("NUMPYRO Categorical")
support = jnp.arange(2)
x = jnp.array([0.6, 0.4])

def categorical_logp(p):
    return jnp.sum(numdist.Categorical(probs=p).log_prob(support))

x_grad = grad(categorical_logp)(x)
print(f"x_grad = {x_grad}")
# Baseline in JAX: the gradient of sum(log(x)) is elementwise 1/x.
print("NUMPYRO log")
x = jnp.array([0.6, 0.4])

def log_sum(p):
    return jnp.log(p).sum()

x_grad = grad(log_sum)(x)
print(f"x_grad = {x_grad}")
Output
TORCH Categorical
x.grad = tensor([-0.3333, 0.5000])
TORCH log
x.grad = tensor([1.6667, 2.5000])
NUMPYRO Categorical
x_grad = [1.6667, 2.5]
NUMPYRO log
x_grad = [1.6667, 2.5]