Hi all. I want to run Bayesian inference in NumPyro with a model that uses a custom factor. This factor is the log-density of a bivariate Gaussian copula, which I have implemented in two ways. The first:
def gaussian_copula_lpdf(u, v, rho):
    u_2 = jnp.square(u)
    v_2 = jnp.square(v)
    rho_2 = jnp.square(rho)
    return (
        # log of the normalising term (1 - rho^2) ** (-1/2)
        -0.5 * jnp.log(1 - rho_2) - (
            rho_2 * (u_2 + v_2) - 2 * rho * u * v
        ) / (2 * (1 - rho_2))
    )
gives me no issues at all and sampling is pleasantly quick (a couple of seconds).
However, I wanted to implement the same thing via a multivariate likelihood (ideally generalising to more than two variables). The representation is as follows:
def multivar_gaussian_copula_lpdf(u, v, rho) -> float:
    std_gaussian_rvs = jnp.array([u, v])
    cov = jnp.array([[1., rho], [rho, 1]])
    llhood = jnp.log(
        jnp.linalg.det(cov) ** (-1/2) * jnp.exp(
            -0.5 * jnp.matmul(
                std_gaussian_rvs.T,
                jnp.matmul(
                    jnp.linalg.inv(cov) - jnp.identity(2),
                    std_gaussian_rvs
                )
            )
        )
    )
    return llhood
However, this version is far slower (the estimated sampling time is about 30 minutes), which suggests something is going wrong. My guess is that I’ve missed JAX-ifying one of the variables, but I can’t quite find it.
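One thing that may be worth checking, as a sketch rather than a confirmed diagnosis: if u and v are length-N arrays (an assumption, since the shapes of quantiles_L and quantiles_Y are not shown), then jnp.array([u, v]) has shape (2, N) and the nested matmul returns an N x N matrix instead of N per-observation values. The snippet below reproduces that shape with made-up inputs and sets it beside a hypothetical vectorised variant, gaussian_copula_lpdf_nd, which is not part of the original code and swaps det and inv for slogdet and solve.

```python
import jax.numpy as jnp


def original_quadratic_form(u, v, rho):
    # Same matmul structure as multivar_gaussian_copula_lpdf,
    # kept only to inspect the shape of the result.
    z = jnp.array([u, v])  # shape (2, N) for length-N inputs
    cov = jnp.array([[1.0, rho], [rho, 1.0]])
    return jnp.matmul(z.T, jnp.matmul(jnp.linalg.inv(cov) - jnp.identity(2), z))


def gaussian_copula_lpdf_nd(z, corr):
    """Hypothetical helper: per-observation log-density.

    z has shape (N, d) (standard-normal scores), corr has shape (d, d).
    """
    _, logdet = jnp.linalg.slogdet(corr)
    solved = jnp.linalg.solve(corr, z.T).T  # row i is corr^{-1} z_i
    # z_i^T (corr^{-1} - I) z_i, computed row by row
    quad = jnp.sum(z * solved, axis=-1) - jnp.sum(z * z, axis=-1)
    return -0.5 * logdet - 0.5 * quad


# Made-up inputs, for illustration only.
u = jnp.array([0.3, -1.2, 0.7, 2.0])
v = jnp.array([0.1, -0.8, 1.5, 1.1])
rho = 0.6

print(original_quadratic_form(u, v, rho).shape)  # (4, 4)

corr = jnp.array([[1.0, rho], [rho, 1.0]])
lp = gaussian_copula_lpdf_nd(jnp.stack([u, v], axis=-1), corr)
print(lp.shape)  # (4,)
```

If the inputs really are vectors, the N x N result would account for both the extra cost and a different total log-probability, since every entry of the matrix would enter the factor rather than only the N diagonal terms.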
Any help would be greatly appreciated!
====
PS: For reference, the section of my NumPyro model that calls these functions is:
...
rho_LY_val = numpyro.sample('rho_ly',
    numpyro.distributions.ImproperUniform(
        numpyro.distributions.constraints.interval(-1., 1.),
        batch_shape=(),
        event_shape=()
    )
)
std_normal_L = dist.Normal(0, 1).icdf(quantiles_L)
std_normal_Y = dist.Normal(0, 1).icdf(quantiles_Y)
# cop_log_prob = numpyro.factor('cop_log_prob', gaussian_copula_lpdf(std_normal_L, std_normal_Y, rho_LY_val))
cop_log_prob = numpyro.factor(
    'cop_log_prob',
    multivar_gaussian_copula_lpdf(std_normal_L, std_normal_Y, rho_LY_val)
    # gaussian_copula_lpdf(std_normal_L, std_normal_Y, rho_LY_val)
)
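Since the model already works with standard-normal scores, another option might be to build the value passed to numpyro.factor from library log-densities: the Gaussian copula log-density equals the joint multivariate-normal log-density minus the sum of the standard-normal marginal log-densities. The sketch below uses jax.scipy.stats for this; copula_log_factor and the sample inputs are hypothetical and it has not been run inside this model.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm


def copula_log_factor(z, corr):
    """Hypothetical helper: Gaussian copula log-density per row of z.

    z has shape (N, d) (standard-normal scores), corr has shape (d, d).
    """
    d = corr.shape[0]
    joint = multivariate_normal.logpdf(z, jnp.zeros(d), corr)  # shape (N,)
    marginals = norm.logpdf(z).sum(axis=-1)                    # shape (N,)
    return joint - marginals


# Made-up stand-ins for std_normal_L, std_normal_Y and rho_LY_val.
std_normal_L = jnp.array([0.3, -1.2, 0.7, 2.0])
std_normal_Y = jnp.array([0.1, -0.8, 1.5, 1.1])
rho_LY_val = 0.6

z = jnp.stack([std_normal_L, std_normal_Y], axis=-1)  # shape (N, 2)
corr = jnp.array([[1.0, rho_LY_val], [rho_LY_val, 1.0]])
log_factor = copula_log_factor(z, corr)  # one value per observation
```

In the model, log_factor would take the place of the multivar_gaussian_copula_lpdf(...) argument. For more than two variables only z and corr would need to change.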