Slow NumPyro inference with a custom factor

Hi all. I want to run Bayesian inference in NumPyro with a model that uses a custom factor. This factor is a bivariate Gaussian copula log-density, which I have implemented in two ways. The first:

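import jax.numpy as jnp

def gaussian_copula_lpdf(u, v, rho):
    # u, v: standard-normal scores Phi^{-1}(q); rho: correlation in (-1, 1)
    u_2 = jnp.square(u)
    v_2 = jnp.square(v)
    rho_2 = jnp.square(rho)
    return (
        # normalising term is -0.5 * log(1 - rho^2)
        -0.5 * jnp.log(1 - rho_2) - (
            rho_2 * (u_2 + v_2) - 2 * rho * u * v
        ) / (2 * (1 - rho_2))
    )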

works without any issues, and sampling is pleasantly quick (a couple of seconds).
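
One way to sanity-check a closed form like this: the Gaussian copula log-density equals the bivariate normal log-density (unit variances, correlation rho) minus the two standard-normal marginal log-densities. A minimal sketch using jax.scipy.stats; the function names gaussian_copula_lpdf_closed and copula_lpdf_reference and the test values are made up for illustration. Note the normalising term here is -0.5 * log(1 - rho^2):

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm

def gaussian_copula_lpdf_closed(u, v, rho):
    # Closed-form bivariate Gaussian copula log-density at standard-normal
    # scores u, v (hypothetical name, for illustration only).
    rho_2 = jnp.square(rho)
    return (
        -0.5 * jnp.log1p(-rho_2)  # -0.5 * log(1 - rho^2)
        - (rho_2 * (u**2 + v**2) - 2 * rho * u * v) / (2 * (1 - rho_2))
    )

def copula_lpdf_reference(u, v, rho):
    # Same quantity from the definition: joint normal minus the marginals.
    z = jnp.stack([u, v], axis=-1)               # shape (N, 2)
    cov = jnp.array([[1.0, rho], [rho, 1.0]])
    return (multivariate_normal.logpdf(z, jnp.zeros(2), cov)
            - norm.logpdf(u) - norm.logpdf(v))

u = jnp.array([0.3, -1.2, 0.8])
v = jnp.array([0.1, -0.9, -0.5])
print(bool(jnp.allclose(gaussian_copula_lpdf_closed(u, v, 0.6),
                        copula_lpdf_reference(u, v, 0.6), atol=1e-5)))  # True
```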

However, I wanted to implement the same thing via a multivariate likelihood (ideally one that generalises to more than two variables). The representation is as follows:

def multivar_gaussian_copula_lpdf(u, v, rho) -> float:
    std_gaussian_rvs = jnp.array([u, v])
    cov = jnp.array([[1., rho], [rho, 1]])
    llhood = jnp.log(
        jnp.linalg.det(cov) ** (-1/2) * jnp.exp(
            -0.5 * jnp.matmul(
                std_gaussian_rvs.T,
                jnp.matmul(
                    jnp.linalg.inv(cov) - jnp.identity(2),
                    std_gaussian_rvs
                )
            )
        )
    )
    return llhood
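
For reference, this version encodes the Gaussian copula density in terms of the standard-normal scores z and the correlation matrix R (cov in the code); for d = 2 it reduces to the standard bivariate closed form. Taking the logarithm analytically, rather than calling jnp.log on a jnp.exp(...), also avoids underflow to log(0) when the quadratic form is large:

```latex
\log c(\mathbf{z}; R) = -\frac{1}{2}\log\det R - \frac{1}{2}\,\mathbf{z}^{\top}\left(R^{-1} - I_d\right)\mathbf{z},
\qquad
R = \begin{pmatrix} 1 & \rho \\ \rho & 1 \end{pmatrix}
\;\Longrightarrow\;
\log c(z_1, z_2; \rho) = -\frac{1}{2}\log\left(1 - \rho^2\right) - \frac{\rho^2\left(z_1^2 + z_2^2\right) - 2\rho z_1 z_2}{2\left(1 - \rho^2\right)}
```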

However, this version is far slower (an estimated 30 minutes), which indicates something is clearly going wrong. My guess is that I've missed JAX-ifying one of the variables, but I can't quite find it.

Any help would be greatly appreciated!

====
PS: For reference, the section of my NumPyro model that calls these functions is:

    ...
    rho_LY_val = numpyro.sample('rho_ly',
        numpyro.distributions.ImproperUniform(
            numpyro.distributions.constraints.interval(-1., 1.),
            batch_shape=(), 
            event_shape=()
        )
    )
    
    std_normal_L = dist.Normal(0, 1).icdf(quantiles_L)
    std_normal_Y = dist.Normal(0, 1).icdf(quantiles_Y)

    # cop_log_prob = numpyro.factor('cop_log_prob', gaussian_copula_lpdf(std_normal_L, std_normal_Y, rho_LY_val))
    cop_log_prob = numpyro.factor(
        'cop_log_prob',
        multivar_gaussian_copula_lpdf(std_normal_L, std_normal_Y, rho_LY_val)
        # gaussian_copula_lpdf(std_normal_L, std_normal_Y, rho_LY_val)
    )
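
The post doesn't say whether quantiles_L and quantiles_Y are scalars or arrays. If they are 1-D arrays of N observations (an assumption), jnp.array([u, v]) has shape (2, N) and the matmul chain in multivar_gaussian_copula_lpdf returns an N x N matrix of cross terms z_i^T A z_j rather than N per-observation quadratic forms, so numpyro.factor would add up N^2 entries. That would cost O(N^2) work and change the log-density. A minimal shape check with made-up inputs:

```python
import jax.numpy as jnp

# Hypothetical inputs: N standard-normal scores per variable.
N = 4
std_normal_L = jnp.linspace(-1.0, 1.0, N)
std_normal_Y = jnp.linspace(0.5, -0.5, N)
rho = 0.3

z = jnp.array([std_normal_L, std_normal_Y])   # shape (2, N)
cov = jnp.array([[1.0, rho], [rho, 1.0]])
A = jnp.linalg.inv(cov) - jnp.identity(2)     # shape (2, 2)

# The original expression: (N, 2) @ (2, 2) @ (2, N) -> (N, N) cross terms.
full = jnp.matmul(z.T, jnp.matmul(A, z))
print(full.shape)  # (4, 4)

# One quadratic form z_n^T A z_n per observation -> shape (N,)
per_obs = jnp.einsum('in,ij,jn->n', z, A, z)
print(per_obs.shape)  # (4,)

# The per-observation terms are exactly the diagonal of the N x N matrix.
print(bool(jnp.allclose(per_obs, jnp.diag(full), atol=1e-5)))  # True
```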

Never use det and inv. Instead, compute a Cholesky factorization, compute the log-determinant as 2 * jnp.log(jnp.diag(cholesky_factor)).sum(), and use triangular_solve or similar to deal with the inverse. Of course, there may be additional issues with your implementation.
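
A minimal sketch of the Cholesky approach, using jnp.linalg.cholesky and jax.scipy.linalg.solve_triangular (my choice of triangular solver). It assumes the scores are stacked into an (N, d) array z with one row per observation; the names multivar_gaussian_copula_lpdf_chol and closed_form are hypothetical. Nothing in it depends on d = 2, so it also covers more than two variables:

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

def multivar_gaussian_copula_lpdf_chol(z, corr):
    """Gaussian copula log-density summed over observations (hypothetical name).

    z:    (N, d) standard-normal scores, one row per observation.
    corr: (d, d) correlation matrix (unit diagonal, positive definite).
    """
    chol = jnp.linalg.cholesky(corr)                 # lower L with corr = L L^T
    logdet = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))  # log det(corr)
    # Solve L w = z^T instead of inverting corr; ||w_n||^2 = z_n^T corr^{-1} z_n.
    w = solve_triangular(chol, z.T, lower=True)      # shape (d, N)
    mahalanobis = jnp.sum(w**2, axis=0)              # shape (N,)
    sq_norm = jnp.sum(z**2, axis=1)                  # z_n^T z_n, shape (N,)
    per_obs = -0.5 * logdet - 0.5 * (mahalanobis - sq_norm)
    return jnp.sum(per_obs)

def closed_form(u, v, rho):
    # Bivariate closed form, summed over observations, for comparison.
    rho_2 = rho**2
    return jnp.sum(-0.5 * jnp.log1p(-rho_2)
                   - (rho_2 * (u**2 + v**2) - 2 * rho * u * v) / (2 * (1 - rho_2)))

u = jnp.array([0.3, -1.2, 0.8])
v = jnp.array([0.1, -0.9, -0.5])
rho = 0.6
corr = jnp.array([[1.0, rho], [rho, 1.0]])
z = jnp.stack([u, v], axis=-1)  # shape (N, 2)
print(bool(jnp.allclose(multivar_gaussian_copula_lpdf_chol(z, corr),
                        closed_form(u, v, rho), atol=1e-4)))  # True
```

In the model, this would be passed to numpyro.factor with z = jnp.stack([std_normal_L, std_normal_Y], axis=-1) and corr built from rho_LY_val as before.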


Thanks for clarifying @martinjankowiak – this fixed the issue!