validate_args=True. What does this do?

I noticed that if I run

import jax.numpy as jnp
import numpyro.distributions as dist

x2_grid = jnp.linspace(-20, 70, 1000)
sigma = jnp.exp(dist.Uniform(low=0, high=50).log_prob(x2_grid))

I obtain this plot

[Plot: exp(log_prob) over x2_grid is a constant value across the whole grid, including outside [0, 50]]

while if I run

x2_grid = jnp.linspace(-20, 70, 1000)
sigma = jnp.exp(dist.Uniform(low=0, high=50, validate_args=True).log_prob(x2_grid))

I obtain what I expected, namely:

[Plot: the expected density, 1/50 on [0, 50] and 0 outside it]

The documentation describes this parameter as "Whether to enable validation of distribution parameters and arguments to .log_prob method." But it is not clear to me what that means, or why log_prob returns a constant value when I run it without validation.

With validate_args=True, the parameters of the distribution and the sample values passed to log_prob are checked to see whether they satisfy the required constraints. For example, the scale of a Normal distribution must be positive, and a sample value passed to LogNormal.log_prob must be positive.

In the case of the uniform distribution, log_prob is log(1/(high - low)) inside the support and -inf outside it. With validate_args=True, extra checks are performed so that you correctly get -inf outside the support. With validate_args=False these checks are skipped (for additional speed), under the assumption that you are responsible for passing in sample values that are actually in the support. If you pass in values outside the support, the results are not guaranteed to be correct, or even to be free of NaNs.
