Seeking help writing the likelihood function of a non-homogeneous Poisson process!

Hello,

I am working with astrophysical models for parameter estimation of binary black holes. The setup requires a non-homogeneous Poisson process, with the likelihood defined as,

\displaystyle\mathcal{L}=\exp{\left(-\mu(\mathcal{R}, \Lambda)\right)}\prod_{n=1}^{N}\int p(\lambda | d_n)\mathcal{R}p(\lambda | \Lambda)\frac{1}{p(\lambda)}d\lambda\approx \exp{\left(-\mu(\mathcal{R}, \Lambda)\right)}\prod_{n=1}^{N}\frac{\mathcal{R}}{N_i}\sum_{i=1}^{N_i}\frac{p(\lambda_i|\Lambda)}{p(\lambda_i)}

The integral can be approximated by Monte Carlo integration, where the \lambda samples are drawn from p(\lambda | d_n); they are pre-computed and stored in files for each n. I am having trouble writing all of this in numpyro.
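
For concreteness, the per-event Monte Carlo term I have in mind looks roughly like this (a minimal sketch in log space; the function and array names are just my own conventions, not from any library):

from jax import numpy as jnp
from jax.scipy.special import logsumexp

def per_event_log_mc(log_pop, log_pe_prior, log_rate):
    # log of (R / N_i) * sum_i p(lambda_i | Lambda) / p(lambda_i), computed in log space.
    # log_pop and log_pe_prior have shape (N_i,): the population density p(lambda_i | Lambda)
    # and the sampling prior p(lambda_i) evaluated at one event's pre-computed samples.
    n_i = log_pop.shape[-1]
    return log_rate + logsumexp(log_pop - log_pe_prior, axis=-1) - jnp.log(n_i)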

Here \lambda=(m_1,m_2) and \Lambda = (\alpha, k, m_\min, m_\max, \mathcal{R}) and,

\displaystyle p(\lambda|\Lambda)\propto\frac{m_1^{-\alpha-k}m_2^k}{m_1-m_\min}

We have taken k=0. This is the code I am currently using:

import numpyro
from jax import numpy as jnp
from numpyro import distributions as dist

def model(lambda_n):  # , raw_interpolator):
    r"""
    .. math::
        \mathcal{L}(\mathcal{R},\Lambda)\propto\prod_{n=1}^{N}\frac{\mathcal{R}}{N_i}\sum_{i=1}^{N_i}\frac{p(\lambda_i|\Lambda)}{\pi(\lambda_i)}
    """
    event, _, _ = lambda_n.shape
    alpha_prior = dist.Uniform(-5.0, 5.0)
    mmin_prior = dist.Uniform(5.0, 15.0)
    mmax_prior = dist.Uniform(30.0, 190.0)
    rate_prior = dist.LogUniform(10**-1, 10**6)
    alpha = numpyro.sample("alpha", alpha_prior)
    mmin = numpyro.sample("mmin", mmin_prior)
    mmax = numpyro.sample("mmax", mmax_prior)
    rate = numpyro.sample("rate", rate_prior)

    ll = 0.0

    with numpyro.plate("data", event) as n:
        mass_model = Wysocki2019MassModel(
            alpha_m=alpha,
            k=0,
            mmin=mmin,
            mmax=mmax,
        )
        p = jnp.exp(mass_model.log_prob(lambda_n[n, :, :]))
        ll_i = jnp.mean(p, axis=1)
        ll_i *= rate
        ll += jnp.log(ll_i)

    # mean = expval_mc(alpha, mmin, mmax, rate, raw_interpolator)
    # ll -= mean

    return jnp.exp(ll)

Right now I am leaving out the exponential term, but advice on how to handle it would also be welcome.

Thank you very much!

FYI, here is the mass model implementation:


from __future__ import annotations

from typing_extensions import Optional

from jax import lax, numpy as jnp
from jax.random import uniform
from jaxtyping import Array
from numpyro.distributions import constraints, Distribution
from numpyro.distributions.util import promote_shapes, validate_sample

from ..utils.misc import get_key


class Wysocki2019MassModel(Distribution):
    r"""It is a double side truncated power law distribution, as
    described in equation 7 of the `paper <https://arxiv.org/abs/1805.06442>`__.

    .. math::
        p(m_1,m_2\mid\alpha,k,m_{\text{min}},m_{\text{max}},M_{\text{max}})\propto\frac{m_1^{-\alpha-k}m_2^k}{m_1-m_{\text{min}}}
    """

    arg_constraints = {
        "alpha_m": constraints.real,
        "k": constraints.nonnegative_integer,
        "mmin": constraints.positive,
        "mmax": constraints.positive,
    }

    support = constraints.real_vector

    def __init__(self, alpha_m: float, k: int, mmin: float, mmax: float, *, valid_args=None) -> None:
        r"""Initialize the power law distribution with a lower and upper mass limit.

        :param alpha_m: index of the power law distribution
        :param k: mass ratio power law index
        :param mmin: lower mass limit
        :param mmax: upper mass limit
        :param valid_args: If `True`, validate the input arguments.
        """
        self.alpha_m, self.k, self.mmin, self.mmax = promote_shapes(alpha_m, k, mmin, mmax)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(alpha_m),
            jnp.shape(k),
            jnp.shape(mmin),
            jnp.shape(mmax),
        )
        super(Wysocki2019MassModel, self).__init__(batch_shape=batch_shape, validate_args=valid_args)

    @validate_sample
    def log_prob(self, value):
        return -(self.alpha_m + self.k) * jnp.log(value[..., 0]) + self.k * jnp.log(value[..., 1]) - jnp.log(value[..., 0] - self.mmin)

    def sample(self, key: Optional[Array | int], sample_shape: tuple = ()) -> Array:
        if key is None or isinstance(key, int):
            key = get_key(key)
        m2 = uniform(key=key, minval=self.mmin, maxval=self.mmax, shape=sample_shape + self.batch_shape)
        U = uniform(key=get_key(key), minval=0.0, maxval=1.0, shape=sample_shape + self.batch_shape)
        beta = 1 - (self.k + self.alpha_m)
        conditions = [beta == 0.0, beta != 0.0]
        choices = [
            jnp.exp(U * jnp.log(self.mmax) + (1.0 - U) * jnp.log(m2)),
            jnp.exp(jnp.power(beta, -1.0) * jnp.log(U * jnp.power(self.mmax, beta) + (1.0 - U) * jnp.power(m2, beta))),
        ]
        m1 = jnp.select(conditions, choices)
        return jnp.stack([m1, m2], axis=-1)

If all you need is a likelihood, you can use numpyro.factor to add a log-probability term to the joint density of the model. There is no need for a full Distribution with a sampler, etc.

I feel like you are giving very valuable advice, @martinjankowiak, but I'm not sure I follow. Is there a tutorial where I can see your suggestion implemented?

See the docs, or search the codebase (e.g. the tests or examples) for the string "numpyro.factor".
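
For reference, an untested sketch of that pattern, built from the code in the original post (it assumes Wysocki2019MassModel.log_prob is evaluated over the last axis of lambda_n and, like the posted model, omits the 1/p(\lambda_i) weights and the \exp(-\mu(\mathcal{R},\Lambda)) term, which would enter as a second factor):

import numpyro
from jax import numpy as jnp
from jax.scipy.special import logsumexp
from numpyro import distributions as dist

def model(lambda_n):
    # lambda_n: pre-computed (m1, m2) posterior samples, shape (n_events, N_i, 2)
    alpha = numpyro.sample("alpha", dist.Uniform(-5.0, 5.0))
    mmin = numpyro.sample("mmin", dist.Uniform(5.0, 15.0))
    mmax = numpyro.sample("mmax", dist.Uniform(30.0, 190.0))
    rate = numpyro.sample("rate", dist.LogUniform(1e-1, 1e6))

    mass_model = Wysocki2019MassModel(alpha_m=alpha, k=0, mmin=mmin, mmax=mmax)

    # per-event Monte Carlo average over the pre-computed samples, done in log space
    log_p = mass_model.log_prob(lambda_n)  # shape (n_events, N_i)
    log_l_n = jnp.log(rate) + logsumexp(log_p, axis=-1) - jnp.log(lambda_n.shape[1])

    # add sum_n log L_n to the joint log-density; the exp(-mu) term would be another
    # factor along the same lines, e.g. numpyro.factor("expected_count", -mu)
    numpyro.factor("log_likelihood", jnp.sum(log_l_n))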