Creating a custom distribution in NumPyro

Hi all,

I’ve recently made the switch from PyMC3 to NumPyro for speed reasons, and am loving it so far. A project I was previously working on in PyMC3 required a custom prior distribution on a latent parameter. In PyMC3, I would have set up this model as follows:

import pymc3 as pm

def bailerjones_lpdf(L):
    def lpdf(r):
        return 2*pm.math.log(r) -  3*pm.math.log(L) - (r/L)
    return lpdf

L_ = 800 # in parsec

with pm.Model() as model:
    r = pm.DensityDist('r', bailerjones_lpdf(L_), 
                      transform = pm.distributions.transforms.Log())
    trace = pm.sample()
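For reference, the log-density in `bailerjones_lpdf` is, up to the additive constant $-\log 2$, the log of a normalized density on $r > 0$. The normalizing integral and the mode follow directly:

```latex
p(r \mid L) = \frac{r^{2} e^{-r/L}}{2L^{3}}, \qquad r > 0,
\qquad\text{since}\qquad \int_{0}^{\infty} r^{2} e^{-r/L}\,dr = \Gamma(3)\,L^{3} = 2L^{3}

\log p(r \mid L) = 2\log r - 3\log L - \frac{r}{L} - \log 2

\frac{d}{dr}\left(2\log r - \frac{r}{L}\right) = \frac{2}{r} - \frac{1}{L} = 0
\;\Longrightarrow\; r_{\text{mode}} = 2L
```

In other words, r follows a Gamma distribution with shape 3 and scale L (mean 3L).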

I’m at a loss on how to get something like this to work in NumPyro, and can’t find any tutorials. Can anybody point me in the right direction?

Thanks so much!


I think the clearest way is to mimic other distribution implementations:

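import jax.numpy as jnp
import numpyro
from numpyro.distributions import Distribution, constraints

def bailerjones_lpdf(L):
    # JAX port of the PyMC3 log-density above (pm.math.log -> jnp.log)
    def lpdf(r):
        return 2 * jnp.log(r) - 3 * jnp.log(L) - r / L
    return lpdf

class CustomDistribution(Distribution):
    support = constraints.positive
    def __init__(self, L):
        self.L = L
        super().__init__(batch_shape=jnp.shape(L), event_shape=())

    def sample(self, key, sample_shape=()):
        raise NotImplementedError  # not called by NUTS with its default init strategy

    def log_prob(self, value):
        return bailerjones_lpdf(self.L)(value)  # bailerjones_lpdf returns a function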

Alternatively, you can use this shorter version inside your model function:

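import numpyro.distributions as dist
r = numpyro.sample("r", dist.ImproperUniform(constraints.positive, (), ()))  # flat prior on r > 0
numpyro.factor("r_log_prob", 2 * jnp.log(r) - 3 * jnp.log(L_) - r / L_)  # adds the Bailer-Jones log-density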

Hi @fehiepsi! Thanks so much for your input; this really helped me understand how to set this distribution up (I had no idea I could just leave the sample function unimplemented, for example). My custom distribution now works perfectly. I’ve included the working code below for others working on similar projects. The distribution behaves as expected (with a mode at 2*L).

from numpyro.distributions import constraints
import numpyro as npy
from jax import numpy as jnp

class BJ19_Prior(npy.distributions.Distribution):
    # Positive support: NUTS samples r on the log scale, like PyMC3's Log transform
    support = constraints.positive

    def __init__(self, L):
        self.L = L
        super().__init__(batch_shape=jnp.shape(L), event_shape=())

    def sample(self, key, sample_shape=()):
        # Not needed for NUTS with its default initialization
        raise NotImplementedError

    def log_prob(self, value):
        # Bailer-Jones log-density, up to the constant -log 2
        return 2 * jnp.log(value) - 3 * jnp.log(self.L) - value / self.L

def model():
    L_ = 800  # length scale, in parsec
    r = npy.sample('r', BJ19_Prior(L_))