Gradient of numerically approximated normalization constants

Hi,

I have a distribution with an intractable normalization constant estimated using numerical integration,

f_{X}(x\mid\lambda)=\frac{\tilde{f_{X}}(x\mid\lambda)}{Z(\lambda)}\implies Z(\lambda)=\int \tilde{f_{X}}(x\mid\lambda) dx

where \tilde{f_{X}} is the unnormalized density, Z is the normalization constant, and \lambda can be a single parameter or a set of parameters.

When I take the gradient of its log_prob, it is off to some degree from the gradient calculated through forward differences. My understanding is that this is because I have stopped the gradient through the normalization constant (I had to, because of the interpolation; otherwise it was coming out as jnp.nan). Is there a way we can calculate the gradients accurately?
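The comparison is roughly of the following form (log_prob here is just a placeholder for the actual log-density, and eps is an arbitrary step size):

```python
import jax

# Hypothetical sketch of the check described above: compare jax.grad of the
# log-density against a simple forward-difference estimate in lambda.
def compare_grads(log_prob, x, lam, eps=1e-5):
    ad = jax.grad(log_prob, argnums=1)(x, lam)               # autodiff gradient
    fd = (log_prob(x, lam + eps) - log_prob(x, lam)) / eps   # forward difference
    return ad, fd
```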

probably impossible to answer this question without further details.

Do you want to look at the code? It is a little messy!

not really but you could describe your approach or what the numerical issue is without lots of code

Suppose \tilde{f} is an unnormalized density over the interval [a,b] and its normalization constant is Z. This is expressed as,

Z(\lambda)=\int_{a}^{b}\tilde{f}(x\mid\lambda)dx.

\tilde{f} is a little problematic. Choosing some c\in[a,b], we can split the integral into,

Z(\lambda)=\int_{a}^{c}\tilde{f}(x\mid\lambda)dx+\int_{c}^{b}\tilde{f}(x\mid\lambda)dx,

and after some simplification the integral over [c,b] reduces to a closed form. The integral over [a,c], on the other hand, has no closed form, so we have to approximate it numerically; sometimes the constant also depends on one input of a multivariate distribution, so we have to interpolate it as well. In experiments we observe that we get jnp.nan from the interpolation and from the numerically approximated part. To avoid the NaNs I wrapped all interpolation and numerical integration in jax.lax.stop_gradient.
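To make this concrete, here is a minimal self-contained sketch of this kind of setup; the density, split point, and closed-form tail below are all made up for illustration and are not the actual model:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import erf

# Toy setup: over [C, B] the tail is Gaussian-like and integrates in closed
# form via erf; over [A, C] there is no closed form, so a trapezoidal rule
# stands in for the numerical integration. All choices here are placeholders.
A, C, B, N = -1.0, 0.0, 1.0, 256

def f_tilde(x, lam):
    # hypothetical piecewise unnormalized density
    return jnp.where(x < C,
                     jnp.exp(-lam * jnp.abs(x) ** 1.5),  # no closed-form integral
                     jnp.exp(-lam * x ** 2))             # closed form via erf

def Z_closed(lam):
    # exact integral of exp(-lam x^2) over [C, B] with C = 0
    return 0.5 * jnp.sqrt(jnp.pi / lam) * erf(jnp.sqrt(lam) * B)

def Z_numeric(lam):
    # trapezoidal approximation of the integral over [A, C]
    xs = jnp.linspace(A, C, N)
    fx = jnp.exp(-lam * jnp.abs(xs) ** 1.5)
    dx = xs[1] - xs[0]
    return dx * (jnp.sum(fx) - 0.5 * (fx[0] + fx[-1]))

def log_prob(x, lam):
    # Wrapping the numerical piece in stop_gradient avoids the NaNs but also
    # drops its contribution to d(log Z)/d(lambda), which is why jax.grad
    # disagrees with the forward-difference estimate.
    Z = Z_closed(lam) + jax.lax.stop_gradient(Z_numeric(lam))
    return jnp.log(f_tilde(x, lam)) - jnp.log(Z)
```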

well, either you need to improve the stability of your numerical approximation, or you need to directly provide some approximation scheme for \frac{\partial Z}{\partial \lambda}.
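For the second option, one way to do this in JAX is to differentiate under the integral sign and supply \frac{\partial Z}{\partial \lambda} through jax.custom_jvp, using the same quadrature rule, so that autodiff never touches the quadrature or interpolation internals. A sketch, with the same placeholder density and limits as the toy example above (redefined here so the snippet is self-contained):

```python
import jax
import jax.numpy as jnp

A, C, N = -1.0, 0.0, 256  # placeholder limits and node count

def f_tilde_num(x, lam):
    return jnp.exp(-lam * jnp.abs(x) ** 1.5)  # the piece without a closed form

def _trapz(fx, dx):
    return dx * (jnp.sum(fx) - 0.5 * (fx[0] + fx[-1]))

@jax.custom_jvp
def Z_numeric(lam):
    xs = jnp.linspace(A, C, N)
    return _trapz(f_tilde_num(xs, lam), xs[1] - xs[0])

@Z_numeric.defjvp
def Z_numeric_jvp(primals, tangents):
    (lam,), (dlam,) = primals, tangents
    xs = jnp.linspace(A, C, N)
    dx = xs[1] - xs[0]
    # dZ/dlambda = int_A^C df_tilde/dlambda dx, approximated with the same
    # trapezoidal rule; the integrand is differentiated pointwise in lambda.
    df = jax.vmap(jax.grad(f_tilde_num, argnums=1), in_axes=(0, None))(xs, lam)
    return _trapz(f_tilde_num(xs, lam), dx), _trapz(df, dx) * dlam
```

With something like this, jax.grad of the log_prob picks up the normalization term without differentiating through the quadrature nodes, and the interpolated part of the constant could be given its own custom rule in the same way.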