I am building a model in NumPyro, and I would like a multivariate normal distribution truncated so that no component is ever less than some lower bound `a`. Is there a straightforward way to do this?
You might find this tutorial helpful:
https://num.pyro.ai/en/stable/tutorials/truncated_distributions.html
This is perfect, thanks!
This is a slight change of direction, but what are the advantages and disadvantages of truncating a distribution versus passing it through a deterministic transformation such as softmax/softplus?
Obviously, the transformation changes the shape of the distribution. On the plus side, the resulting density stays continuous and differentiable, and it takes less code. Is there anything you can say a priori about performance?
I think it's impossible to say in general; it comes down to which modeling assumptions you want to make. If you're happy with the distribution implicitly defined by the deterministic transformation you introduce, go for it.
Hi. How did you solve this? I don't see a multivariate truncated normal at that link.
Hello everyone, I have also tried to implement a truncated multivariate normal distribution using this `TruncatedDistribution` function, but whenever I run SVI it throws an `AssertionError` with no description. Sampling with `Predictive`, as in the tutorial provided by @martinjankowiak, works fine, though.
Can anyone help me out here?
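```python
import jax.numpy as jnp
import numpyro.distributions as dist
from numpyro.distributions import TruncatedDistribution


def TruncatedMultivariateNormal(
    loc=jnp.zeros(9), scale=jnp.ones(9), *, low=None, high=None, validate_args=None
):
    return TruncatedDistribution(
        # MultivariateNormal's second positional argument is covariance_matrix,
        # so a vector of per-dimension scales is passed as a diagonal scale_tril.
        base_dist=dist.MultivariateNormal(loc, scale_tril=jnp.diag(scale)),
        low=low,
        high=high,
        validate_args=validate_args,
    )
```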