Primitive sample and dist sample

Hello, I am trying to sample using numpyro.sample from a multivariable distribution so it samples two values m1 and m2, but it returns one number. My first thought was that my code was incorrect but I have checked it my calling the dist.sample it is returning the expected two masses.

Distribution class,

#  Copyright 2023 The GWKokab Authors
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from __future__ import annotations

from jax import lax
from jax import numpy as jnp
from jax.random import uniform
from jaxtyping import Array
from numpyro.distributions import Distribution, constraints
from numpyro.distributions.util import promote_shapes, validate_sample
from typing_extensions import Optional

from ..utils.misc import get_key


class Wysocki2019MassModel(Distribution):
    r"""It is a double side truncated power law distribution, as
    described in equation 7 of the `paper <https://arxiv.org/abs/1805.06442>`__.

    .. math::
        p(m_1,m_2\mid\alpha,k,m_{\text{min}},m_{\text{max}},M_{\text{max}})\propto\frac{m_1^{-\alpha-k}m_2^k}{m_1-m_{\text{min}}}
    """

    arg_constraints = {
        "alpha_m": constraints.real,
        "k": constraints.nonnegative_integer,
        "mmin": constraints.positive,
        "mmax": constraints.positive,
        "Mmax": constraints.positive,
    }

    support = constraints.real_vector

    def __init__(self, alpha_m: float, k: int, mmin: float, mmax: float, Mmax: float, *, valid_args=None) -> None:
        r"""Initialize the power law distribution with a lower and upper mass limit.

        :param alpha_m: index of the power law distribution
        :param k: mass ratio power law index
        :param mmin: lower mass limit
        :param mmax: upper mass limit
        :param Mmax: maximum mass
        :param valid_args: _description_, defaults to `None`
        """
        self.alpha_m, self.k, self.mmin, self.mmax, self.Mmax = promote_shapes(alpha_m, k, mmin, mmax, Mmax)
        batch_shape = lax.broadcast_shapes(
            jnp.shape(alpha_m),
            jnp.shape(k),
            jnp.shape(mmin),
            jnp.shape(mmax),
            jnp.shape(Mmax),
        )
        super(Wysocki2019MassModel, self).__init__(batch_shape=batch_shape, validate_args=valid_args)

    @validate_sample
    def log_prob(self, value):
        return -(self.alpha_m + self.k) * jnp.log(value[0]) + self.k * jnp.log(value[1]) - jnp.log(value[0] - self.mmin)

    def sample(self, key: Optional[Array | int], sample_shape: tuple = ()) -> Array:
        if key is None or isinstance(key, int):
            key = get_key(key)
        m2 = uniform(
            key=key,
            minval=self.mmin,
            maxval=self.mmax,
            shape=sample_shape + self.batch_shape + self.event_shape,
        )
        U = uniform(
            key=get_key(key),
            minval=0.0,
            maxval=1.0,
            shape=sample_shape + self.batch_shape + self.event_shape,
        )
        beta = 1 - (self.k + self.alpha_m)
        conditions = [beta == 0.0, beta != 0.0]
        choices = [
            jnp.exp(U * jnp.log(self.mmax) + (1.0 - U) * jnp.log(m2)),
            jnp.exp(jnp.power(beta, -1.0) * jnp.log(U * jnp.power(self.mmax, beta) + (1.0 - U) * jnp.power(m2, beta))),
        ]
        m1 = jnp.select(conditions, choices)
        return jnp.stack([m1, m2])

def expval_mc(alpha, m_min, m_max, rate):
    model = Wysocki2019MassModel(
        alpha_m=alpha,
        k=0,
        mmin=m_min,
        mmax=m_max,
        Mmax=2 * m_max,
    )
    m1m2 = numpyro.sample("m1m2", model)
    print(m1m2)
    m1m2 = model.sample(jax.random.PRNGKey(0))
    print(m1m2)

expval_mc(0.8, 5.0, 50.0, 1.0)

I tested your code and saw the result

[48.262074 26.5584  ]
[48.97362  23.830566]

So I guess things are working as expected. Note that you might need to modify your code if you want to evaluate sample/log_prob with many samples. Lines like jnp.stack([m1, m2]) need to have axis specified, like jnp.stack([m1, m2], axis=-1).

1 Like