I’m trying to set up a multivariate log-normal distribution in NumPyro, but I’m getting a type error. As far as I know, all I need to do is subclass the Distribution class and then define how the distribution computes the log probability (def log_prob) and how it samples (def sample). Can anyone help? Has this maybe been done already?
Thanks
If you mean the distribution that you get when you exponentiate a multivariate normal random variable, you’re probably better off using TransformedDistribution and ExpTransform:
https://num.pyro.ai/en/stable/distributions.html#numpyro.distributions.transforms.ExpTransform
This works perfectly! Thanks