Is there a clean way to reparameterize a continuous model in NumPyro using handlers for arbitrary transformations? That is, I have some model $f(x,y)$ and another set of parameters $u(x,y)$, $v(x,y)$, and I want to transform $f(x,y)$ into a new probability density in terms of $u$ and $v$. I know the mathematical transformation is simply a correction by $|\det J|$:

$$g(u,v) = f\big(x(u,v),\, y(u,v)\big)\,\left|\det J(u,v)\right|$$

where $J(u,v)$ is the Jacobian of the transformation $(u,v) \rightarrow (x,y)$. Obviously, I can do this using `log_density` and some messy workarounds. For example, with $u=(x+2y)^{-3}$ and $v=(x-2y)^{-1}$ I can do what is shown below, but this hardly seems like the cleanest way to do things.
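Because `log_density` and `numpyro.factor` work with log densities, the $|\det J|$ correction enters additively there rather than as a multiplier. For the example map, the inverse and its Jacobian determinant can be worked out by hand (reading $u^{-1/3}$ as the real cube root, so negative $u$ is allowed):

$$
\begin{aligned}
\log g(u,v) &= \log f\big(x(u,v),\,y(u,v)\big) + \log\left|\det J(u,v)\right| \\
x &= \tfrac{1}{2}\left(u^{-1/3} + v^{-1}\right), \qquad y = \tfrac{1}{4}\left(u^{-1/3} - v^{-1}\right) \\
J(u,v) &= \begin{pmatrix} -\tfrac{1}{6}|u|^{-4/3} & -\tfrac{1}{2}v^{-2} \\ -\tfrac{1}{12}|u|^{-4/3} & \tfrac{1}{4}v^{-2} \end{pmatrix},
\qquad \det J(u,v) = -\tfrac{1}{12}\,|u|^{-4/3}\,v^{-2}
\end{aligned}
$$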
What is the recommended procedure for constructing and applying such transforms in NumPyro?
```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

# Forwards Transformation: x,y -> u,v
def uf(x, y):
    return (x + 2 * y) ** -3   # u = (x + 2y)^{-3}

def vf(x, y):
    return 1 / (x - 2 * y)     # v = (x - 2y)^{-1}

def Jxy(x, y):
    # det of d(u, v)/d(x, y): rows are the gradients of uf and vf
    out = jnp.array([jax.grad(uf, argnums=(0, 1))(x, y),
                     jax.grad(vf, argnums=(0, 1))(x, y)])
    return jnp.linalg.det(out)

#-----------------
# Backwards Transformation: u,v -> x,y
# jnp.cbrt is the real cube root, so u < 0 (i.e. x + 2y < 0) does not give NaN
def xf(u, v):
    return 1 / 2 * (1 / jnp.cbrt(u) + 1 / v)

def yf(u, v):
    return 1 / 4 * (1 / jnp.cbrt(u) - 1 / v)

def Juv(u, v):
    # det of d(x, y)/d(u, v): rows are the gradients of xf and yf
    out = jnp.array([jax.grad(xf, argnums=(0, 1))(u, v),
                     jax.grad(yf, argnums=(0, 1))(u, v)])
    return jnp.linalg.det(out)

#----------------
# Models
def model():
    '''Base Model, f(x,y)'''
    numpyro.sample('x', dist.Normal(0, 1))
    numpyro.sample('y', dist.HalfNormal(2))

@jax.jit
def like_uv(u, v):
    x, y = xf(u, v), yf(u, v)
    detJ = Juv(u, v)
    logd_xy = numpyro.infer.util.log_density(model, (), {}, {'x': x, 'y': y})[0]
    # In log space the Jacobian correction is additive: log f + log|det J|
    return logd_xy + jnp.log(jnp.abs(detJ))

def model_uv():
    '''Transformed Model, g(u,v)'''
    u = numpyro.sample('u', dist.ImproperUniform(dist.constraints.real, (), ()))
    v = numpyro.sample('v', dist.ImproperUniform(dist.constraints.real, (), ()))
    numpyro.factor('transformed_density', like_uv(u, v))

@jax.jit
def like_uv2(u, v):
    return numpyro.infer.util.log_density(model_uv, (), {}, {'u': u, 'v': v})[0]
```
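The hand-derived determinant $-\tfrac{1}{12}|u|^{-4/3}v^{-2}$ can be cross-checked against autodiff. This sketch restates the example map as two vector-valued functions (`forward`, `backward` and `det_backward_analytic` are hypothetical names, not from the post) so that `jax.jacobian` returns the full 2×2 matrix in one call instead of stacking two `jax.grad` results. It also confirms that the backward map really inverts $u=(x+2y)^{-3}$, $v=(x-2y)^{-1}$.

```python
import jax
import jax.numpy as jnp

def forward(xy):
    # (x, y) -> (u, v) with u = (x + 2y)^-3 and v = (x - 2y)^-1
    x, y = xy[0], xy[1]
    return jnp.stack([(x + 2 * y) ** -3, 1 / (x - 2 * y)])

def backward(uv):
    # (u, v) -> (x, y); jnp.cbrt is the real cube root, so u < 0 is allowed
    a = 1 / jnp.cbrt(uv[0])  # recovers x + 2y
    b = 1 / uv[1]            # recovers x - 2y
    return jnp.stack([(a + b) / 2, (a - b) / 4])

def det_backward_analytic(uv):
    # Hand-derived det of d(x, y)/d(u, v) = -(1/12) |u|^(-4/3) v^(-2)
    u, v = uv[0], uv[1]
    return -(jnp.abs(u) ** (-4 / 3)) / (12 * v ** 2)

xy = jnp.array([0.5, 1.0])
uv = forward(xy)
det_autodiff = jnp.linalg.det(jax.jacobian(backward)(uv))

print(backward(uv))  # recovers xy up to float32 rounding
print(det_autodiff, det_backward_analytic(uv))
```

Using a single vector argument is a design choice: the Jacobian of any bijection between two 2-vectors then comes from one `jax.jacobian` call, with no per-variable bookkeeping.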
Note: I've tested the above and confirmed that `like_uv` and `like_uv2` produce the same log density.