Variable number of parameters to a distribution

Hi,

Disclaimer: I’m a total NumPyro newbie; I started using it yesterday, so bear with me.

My use case is the following. Imagine a series of items i that are offered for sale. Each item is offered to buyers one after another, and buyer j buys item i with probability p_{ij}. The number of times item i is offered should then follow a geometric-like random variable d_i, in which each trial has its own success probability p_{ij}:

P(d_i = n) = p_{in} \prod_{j=1}^{n-1} (1 - p_{ij})

where n, the number of times item i was offered for sale, is observed.
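As a sanity check on the formula: for n = 3 it is the probability that the first two offers fail and the third succeeds, and if every offer had the same probability p it would reduce to the ordinary geometric distribution on {1, 2, ...}:

```latex
P(d_i = 3) = (1 - p_{i1})\,(1 - p_{i2})\,p_{i3},
\qquad
p_{ij} \equiv p \;\Rightarrow\; P(d_i = n) = (1 - p)^{n-1}\,p .
```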

Now, the probability p_{ij} depends on several factors, such as the item’s value and the buyer’s historical “buying rate”. This buying rate, b_j, can be modeled by a Beta distribution and is also observed. For simplicity, we can approximate p_{ij} by:

b_j \sim \text{Beta}(a, b)
p_{ij} = \text{sigmoid}(\text{intercept}_i + \text{weight}_i \cdot b_j)

The goal is to fit the parameters of the Beta distribution, a and b, as well as the intercepts and weights. After that, I want to change a and b and infer how the distribution of d_i would change, given the fitted intercepts and weights. Essentially, I want to know how the historical buying rates of my pool of buyers affect the number of times items are offered for sale, by modifying a and b after fitting the model.
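One way to get at the effect of changing a and b is to simulate forward: draw a fresh sequence of buyers with rates from Beta(a, b), turn them into sale probabilities with the fitted intercepts and weights, and record the first offer that sells. The sketch below is plain JAX and rests on assumptions not in the post: each offer meets an independently drawn buyer, the sequence is capped at a hypothetical `max_offers` (items unsold within the cap are reported as `max_offers + 1`, i.e. censored), and the function name and item parameters are illustrative.

```python
import jax
import jax.numpy as jnp


def simulate_offers(key, a, b, intercept, weight, max_offers=200):
    """Simulate d_i for each item when buying rates follow Beta(a, b).

    intercept, weight: arrays of shape (n_items,), e.g. one posterior draw.
    Returns an int array of shape (n_items,); max_offers + 1 means "not sold within the cap".
    """
    key_rate, key_sale = jax.random.split(key)
    n_items = intercept.shape[0]
    # Assumption: every offer goes to a new, independently drawn buyer.
    rates = jax.random.beta(key_rate, a, b, shape=(n_items, max_offers))
    p = jax.nn.sigmoid(intercept[:, None] + weight[:, None] * rates)
    sold = jax.random.bernoulli(key_sale, p)                    # (n_items, max_offers) booleans
    first = jnp.argmax(sold.astype(jnp.int32), axis=-1) + 1     # 1-based index of the first sale
    return jnp.where(sold.any(axis=-1), first, max_offers + 1)


# Hypothetical item parameters; compare the original (a, b) with an alternative one.
intercept = jnp.array([-1.0, 0.0, 1.0])
weight = jnp.array([2.0, 2.0, 2.0])
keys = jax.random.split(jax.random.PRNGKey(0), 1000)
base = jax.vmap(lambda k: simulate_offers(k, 0.4, 1.6, intercept, weight))(keys)
alt = jax.vmap(lambda k: simulate_offers(k, 2.0, 1.0, intercept, weight))(keys)
print(base.mean(axis=0), alt.mean(axis=0))   # average number of offers per item
```

With the `mcmc` object from the code below, `mcmc.get_samples()["intercept"]` has shape (num_samples, n_items, 1) because of the `dim=-2` plate, so it would need `[..., 0]` before vmapping over posterior draws instead of over keys, which propagates parameter uncertainty into the distribution of d_i.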

My struggle comes from this geometric-like distribution on d_i. There is nothing like it built in, so I took a stab at implementing my own distribution. The code below seems to work for fitting the model, but not for using it to make inferences with different a and b.

Some questions:

  1. Is the code below sound? The fact that I get a warning about a missing plate statement suggests it’s not.
  2. Is there another way to model this problem such that I would not need to implement a custom distribution?
  3. How can I modify the model so that I can explore the effect of alternative values of a and b, given the fitted intercepts and weights?

Any information is welcome. Thanks in advance.

import numpy as np
import pandas as pd
import numpyro
import numpyro.distributions as dist
from numpyro.distributions import constraints
from numpyro.infer import MCMC, NUTS
from jax import numpy as jnp
from jax import scipy as jsp
from jax import random
from numpy.random import default_rng

# Simulate some data for n_items.
n_items = 10
a = 0.4
b = 1.6

rng = default_rng()

probabilities = []
for i in range(n_items):
    ps = []
    while True:
        p = rng.beta(a, b)
        ps.append(p)
        if rng.binomial(1, p):
            break
    probabilities.append(ps)

print(f"number of trials for each RV: {[len(rv) for rv in probabilities]}")


class VarGeometric(numpyro.distributions.Distribution):
    arg_constraints = {"probs": dist.constraints.unit_interval}
    support = dist.constraints.nonnegative_integer

    def __init__(self, probs, validate_args=None):
        self.probs = probs
        super(VarGeometric, self).__init__(
            batch_shape=(jnp.shape(self.probs)[0], ), validate_args=validate_args
        )

    def sample(self, key, sample_shape=()):
        raise NotImplementedError

    def log_prob(self, value):
        # I could not make the following work with jax numpy.
        def logp(probs):
            probs = probs[~np.isnan(probs)]
            return np.log(probs[-1]) + (np.log(1 - probs[:-1])).sum()

        return np.apply_along_axis(logp, 1, self.probs)


def model(probabilities, mask, n_offers):
    a = numpyro.sample("a", dist.HalfNormal(1))
    b = numpyro.sample("b", dist.HalfNormal(1))

    with numpyro.plate('items', len(n_offers), dim=-2):
        # Using a mask works for fitting this model, but it won't work when trying to do inference
        # as the number of buyers that will be offered an item is then unknown.
        probs = numpyro.sample('historical_buying_rate', dist.Beta(a, b).mask(mask), obs=probabilities)

        intercept = numpyro.sample('intercept', dist.Normal(0, 1))
        weight = numpyro.sample('weight', dist.Normal(0, 1))
        modified_probs = jsp.special.expit(intercept + probs * weight)

        numpyro.sample("n_offers", VarGeometric(modified_probs), obs=n_offers)

# Prepare data
n_offers = np.array([len(p) for p in probabilities]).reshape(len(probabilities), 1)
probs_matrix = pd.DataFrame(probabilities).values
mask = jnp.where(~jnp.isnan(probs_matrix), True, False)
print(probs_matrix)
print(n_offers)

# Fit model
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
rng_key = random.PRNGKey(0)
mcmc.run(rng_key, probs_matrix, mask, n_offers)
mcmc.print_summary()

Update: it turns out I was passing the wrong input to VarGeometric: probs instead of modified_probs. With that fixed, the model errors out with

TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ConcreteArray