# Slow NumPyro inference with custom factor

Hi all. I want to run Bayesian inference in NumPyro with a model that uses a custom factor. This factor is a bivariate Gaussian copula, which I have implemented in two ways. The first:

``````
import jax.numpy as jnp


def gaussian_copula_lpdf(u, v, rho):
    u_2 = jnp.square(u)
    v_2 = jnp.square(v)
    rho_2 = jnp.square(rho)
    # Log-determinant term: -0.5 * log(1 - rho^2)
    return (
        -0.5 * jnp.log(1 - rho_2) - (
            rho_2 * (u_2 + v_2) - 2 * rho * u * v
        ) / (2 * (1 - rho_2))
    )
``````

gives me no issues at all and sampling is pleasantly quick (a couple of seconds).

However, I wanted to implement the same thing via a multivariate likelihood (ideally generalising to more than two variables). The representation is as follows:

``````
def multivar_gaussian_copula_lpdf(u, v, rho) -> float:
    std_gaussian_rvs = jnp.array([u, v])
    cov = jnp.array([[1., rho], [rho, 1]])
    llhood = jnp.log(
        jnp.linalg.det(cov) ** (-1/2) * jnp.exp(
            -0.5 * jnp.matmul(
                std_gaussian_rvs.T,
                jnp.matmul(
                    jnp.linalg.inv(cov) - jnp.identity(2),
                    std_gaussian_rvs
                )
            )
        )
    )
    return llhood
``````

However, this version is far slower (estimated time of 30 minutes), which indicates something is clearly going wrong. My guess is that I've missed JAX-ifying one of the variables, but I can't quite find it.

Any help would be greatly appreciated!

====
PS: For reference the section in my numpyro model which calls these functions is:

``````
    ...
    rho_LY_val = numpyro.sample(
        'rho_ly',
        numpyro.distributions.ImproperUniform(
            numpyro.distributions.constraints.interval(-1., 1.),
            batch_shape=(),
            event_shape=()
        )
    )

    std_normal_L = dist.Normal(0, 1).icdf(quantiles_L)
    std_normal_Y = dist.Normal(0, 1).icdf(quantiles_Y)

    # cop_log_prob = numpyro.factor('cop_log_prob', gaussian_copula_lpdf(std_normal_L, std_normal_Y, rho_LY_val))
    cop_log_prob = numpyro.factor(
        'cop_log_prob',
        multivar_gaussian_copula_lpdf(std_normal_L, std_normal_Y, rho_LY_val)
        # gaussian_copula_lpdf(std_normal_L, std_normal_Y, rho_LY_val)
    )
``````

Never use `det` and `inv`. Instead, compute a Cholesky factorization, compute the log-determinant as `2 * cholesky_factor.diag().log().sum()`, and use `triangular_solve` or similar to deal with the inverse. Of course, there may be additional issues with your implementation.
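As a rough illustration of that advice in JAX, something like the sketch below might work. The function name `cholesky_gaussian_copula_lpdf` and its `(x, corr)` signature are assumptions made for this example, not part of the original code, and it assumes the standard-normal scores are stacked as rows of an `(n, d)` array rather than passed as separate `u` and `v`. It uses `jnp.linalg.cholesky` for the factorization and `jax.scipy.linalg.solve_triangular` in place of an explicit inverse.

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def cholesky_gaussian_copula_lpdf(x, corr):
    """Gaussian copula log-density, one value per row of x.

    x:    standard-normal scores, shape (d,) or (n, d)  (assumed layout)
    corr: correlation matrix, shape (d, d)
    """
    x = jnp.atleast_2d(x)                       # (n, d)
    chol = jnp.linalg.cholesky(corr)            # lower-triangular, corr = chol @ chol.T
    # log|corr| = 2 * sum(log(diag(chol))), no call to det
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))
    # Solve chol @ z = x.T, so that sum(z**2) equals x^T corr^{-1} x, no call to inv
    z = solve_triangular(chol, x.T, lower=True)  # (d, n)
    maha = jnp.sum(z ** 2, axis=0)               # (n,)
    # -0.5 * log|corr| - 0.5 * x^T (corr^{-1} - I) x
    return -0.5 * log_det - 0.5 * (maha - jnp.sum(x ** 2, axis=1))


# Example usage with a bivariate correlation matrix
rho = 0.5
corr = jnp.array([[1.0, rho], [rho, 1.0]])
x = jnp.array([[0.3, -1.2], [0.7, 0.1]])
print(cholesky_gaussian_copula_lpdf(x, corr))
```

Because this returns one log-density per observation, the result can be passed to `numpyro.factor`, which sums over its elements. The same code should carry over to more than two variables by supplying a larger correlation matrix.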


Thanks for clarifying @martinjankowiak – this fixed the issue!