Different gradients for PyTorch and NumPyro Categorical distributions

Does anyone know why PyTorch's and NumPyro's Categorical distributions give different gradients w.r.t. probs? (For comparison, I also included taking the log of the probs directly.)

import torch
import torch.distributions as torchdist
import jax.numpy as jnp
import numpyro.distributions as numdist
from jax import grad

print("TORCH Categorical")
z = torch.arange(2)
x = torch.tensor([0.6, 0.4], requires_grad=True)
loss = torch.sum(torchdist.Categorical(probs=x).log_prob(z))
loss.backward()
print(f"x.grad = {x.grad}")

print("TORCH log")
x = torch.tensor([0.6, 0.4], requires_grad=True)
loss = torch.sum(torch.log(x))
loss.backward()
print(f"x.grad = {x.grad}")

print("NUMPYRO Categorical")
z = jnp.arange(2)
x = jnp.array([0.6, 0.4])
loss_fn = lambda x: jnp.sum(numdist.Categorical(probs=x).log_prob(z))
x_grad = grad(loss_fn)(x)
print(f"x_grad = {x_grad}")

print("NUMPYRO log")
x = jnp.array([0.6, 0.4])
loss_fn = lambda x: jnp.sum(jnp.log(x))
x_grad = grad(loss_fn)(x)
print(f"x_grad = {x_grad}")

Output

TORCH Categorical
x.grad = tensor([-0.3333,  0.5000])
TORCH log
x.grad = tensor([1.6667, 2.5000])
NUMPYRO Categorical
x_grad = [1.6667, 2.5]
NUMPYRO log
x_grad = [1.6667, 2.5]

I think Torch includes a

self.probs = probs / probs.sum(-1, keepdim=True)

normalization step that NumPyro does not.
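
To sanity-check that, here is a minimal sketch (mine, not taken from either library) that applies the same normalization by hand before building NumPyro's Categorical; with it, the JAX gradient should come out near [-0.3333, 0.5], matching Torch. The math agrees: with the normalization the loss becomes sum_j log(p_j) - 2*log(sum_k p_k), whose gradient is 1/p_i - 2/sum(p), i.e. [1.6667 - 2, 2.5 - 2].

import jax.numpy as jnp
import numpyro.distributions as numdist
from jax import grad

z = jnp.arange(2)
x = jnp.array([0.6, 0.4])

def loss_fn(x):
    # normalize explicitly, mimicking what Torch's Categorical does internally
    probs = x / x.sum(-1, keepdims=True)
    return jnp.sum(numdist.Categorical(probs=probs).log_prob(z))

print(grad(loss_fn)(x))  # should roughly match Torch's [-0.3333, 0.5]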

That sounds right. Thanks!

Interesting. I feel that taking the grad on a non-real (constrained) domain like this is dangerous.