Hello,
I am having issues using the `WishartTriL` distribution from TFP (`tfd`) in a NumPyro model. Here is a minimal example:
```python
import jax.numpy as jnp
import jax.random as jr
import numpyro
from numpyro.infer import MCMC, NUTS
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions
# from numpyro.contrib.tfp.distributions import TFPDistribution

def model():
    # Wishart prior on a 2x2 matrix: df=5, identity scale
    numpyro.sample("Omega", tfd.WishartTriL(df=5.0, scale_tril=jnp.eye(2)))

kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
mcmc.run(jr.PRNGKey(0))
```
I get this error:
The same happens if I use the wrapper `TFPDistribution[tfd.WishartTriL]`.
Thanks for the support!
Versions:
- jax: 0.6.2
- numpyro: 0.19.0
- tfp: 0.26.0-dev20250903
