Issue with the WishartTriL wrapper

Hello,

I am having issues using the `WishartTriL` distribution from TFP (`tfd`) in NumPyro:

```python
import jax.numpy as jnp
import jax.random as jr
import numpyro
from numpyro.infer import MCMC, NUTS
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions
# from numpyro.contrib.tfp.distributions import TFPDistribution


def model():
    # 2x2 Wishart with 5 degrees of freedom and identity scale
    numpyro.sample("Omega", tfd.WishartTriL(df=5.0, scale_tril=jnp.eye(2)))

kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
mcmc.run(jr.PRNGKey(0))
```

I get this error:

The same error also occurs if I use the wrapper `TFPDistribution[tfd.WishartTriL]`.

Thanks for the support!

Versions:
jax: 0.6.2
numpyro: 0.19.0
tfp: 0.26.0-dev20250903

You can try `WishartCholesky` instead.


Thank you so much!