Hi,
I have a distribution whose normalization constant is intractable and has to be estimated by numerical integration:
f_{X}(x\mid\lambda)=\frac{\tilde{f}_{X}(x\mid\lambda)}{Z(\lambda)},\qquad Z(\lambda)=\int \tilde{f}_{X}(x\mid\lambda)\,dx,
where \tilde{f}_{X} is the unnormalized density, Z is the normalization constant, and \lambda is a single parameter or a set of parameters.
When I take the gradient of its log_prob, it is off to some degree from the gradient calculated through forward differences. My understanding is that this is because I stop the gradient through the normalization constant; the interpolation was otherwise producing jnp.nan. Is there a way to calculate the gradients accurately?
probably impossible to answer this question without further details.
Do you want to look at the code? It is a little messy!
not really, but you could describe your approach, or what the numerical issue is, without lots of code
Suppose \tilde{f} is an unnormalized density over the interval [a,b], and its normalization constant is Z:
Z(\lambda)=\int_{a}^{b}\tilde{f}(x\mid\lambda)\,dx.
\tilde{f} is a little problematic. Taking some c\in[a,b], we can split the integral into
Z(\lambda)=\int_{a}^{c}\tilde{f}(x\mid\lambda)\,dx+\int_{c}^{b}\tilde{f}(x\mid\lambda)\,dx,
and after some simplification the integral over [c,b] reduces to a closed form. The integral over [a,c], on the other hand, has no closed form, so we have to approximate it numerically; sometimes the constant also depends on one input of a multivariate distribution, so we have to interpolate it as well. In experiments we observe jnp.nan coming from the interpolated and numerically approximated parts, so to avoid the NaNs I wrapped all interpolation and numerical integration in jax.lax.stop_gradient.
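Here is a stripped-down sketch of what I mean, with a toy \tilde{f} and made-up interval (the real model is messier):

```python
import jax
import jax.numpy as jnp

a, b, c = 0.0, 1.0, 0.5  # hypothetical interval and split point

def f_tilde(x, lam):
    # toy unnormalized density; the real \tilde{f} is messier
    return jnp.exp(-lam * x)

def Z(lam):
    # [a, c]: pretend there is no closed form -> fixed-grid trapezoid rule
    xs = jnp.linspace(a, c, 201)
    ys = f_tilde(xs, lam)
    numeric = (xs[1] - xs[0]) * (jnp.sum(ys) - 0.5 * (ys[0] + ys[-1]))
    # [c, b]: the piece that reduces to a closed form
    closed = (jnp.exp(-lam * c) - jnp.exp(-lam * b)) / lam
    # wrapping the numerical piece to avoid the NaNs also kills its dZ/dlam
    return jax.lax.stop_gradient(numeric) + closed

def log_prob(x, lam):
    return jnp.log(f_tilde(x, lam)) - jnp.log(Z(lam))

g_ad = jax.grad(log_prob, argnums=1)(0.3, 2.0)
eps = 1e-5
g_fd = (log_prob(0.3, 2.0 + eps) - log_prob(0.3, 2.0)) / eps
print(g_ad, g_fd)  # disagree: AD is missing the gradient of the numeric part
```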
well, either you need to improve the stability of your numerical approximation, or you need to directly provide some approximation scheme for \frac{\partial Z}{\partial \lambda}
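for the second option, one way is to differentiate under the integral sign, \frac{\partial Z}{\partial\lambda}=\int_{a}^{c}\frac{\partial\tilde{f}}{\partial\lambda}(x\mid\lambda)\,dx, and evaluate that with the same quadrature rule, e.g. via jax.custom_jvp. rough sketch, reusing the toy \tilde{f} from above and assuming a scalar \lambda:

```python
import jax
import jax.numpy as jnp

a, c = 0.0, 0.5  # limits of the numerically approximated piece (toy values)

def f_tilde(x, lam):
    return jnp.exp(-lam * x)  # same toy density as above

def trapezoid(f, lam):
    # fixed-grid trapezoid rule over [a, c]; f maps (xs, lam) -> values
    xs = jnp.linspace(a, c, 201)
    ys = f(xs, lam)
    return (xs[1] - xs[0]) * (jnp.sum(ys) - 0.5 * (ys[0] + ys[-1]))

@jax.custom_jvp
def Z_numeric(lam):
    # primal: quadrature (or interpolation) exactly as before
    return trapezoid(f_tilde, lam)

@Z_numeric.defjvp
def Z_numeric_jvp(primals, tangents):
    (lam,), (lam_dot,) = primals, tangents
    # Leibniz rule: dZ/dlam is the integral of df~/dlam over [a, c],
    # computed with the same quadrature rule, so AD never differentiates
    # through the quadrature/interpolation machinery itself
    df_dlam = jax.vmap(jax.grad(f_tilde, argnums=1), in_axes=(0, None))
    return Z_numeric(lam), trapezoid(df_dlam, lam) * lam_dot

# gradients now flow through Z_numeric without stop_gradient:
print(jax.grad(lambda lam: jnp.log(Z_numeric(lam)))(2.0))
```

the same idea works for the interpolated part if you can interpolate \partial\tilde{f}/\partial\lambda on the same grid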