Conditional statement in numpyro

from jax._src.api import T
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import random, grad, lax
import jax
from numpy.core.numeric import Inf
import numpyro
import numpyro.distributions as dist
import math
import numpyro.optim as optim
from numpyro.infer import MCMC, NUTS, Predictive
import numpy as np
import torch
import pandas as pd
import datetime
import scipy.stats as st

rng_key = random.PRNGKey(101)
rng_key, rng_key_ = random.split(rng_key)
begin_time = datetime.datetime.now()

# DATA = np.array([
#     # user, item, order, rating
# ])

DATA = torch.load("full_data_fixed.pt")
DATA = DATA.numpy()
DATA = jax.numpy.asarray(DATA)
K = len(list(set([int(row[0]) for row in DATA])))
N = len(list(set([int(row[1]) for row in DATA])))
R = len(DATA)


def model():
    a = numpyro.sample("competenceMean", dist.Exponential(0.4))
    b = numpyro.sample("competenceVariance", dist.Exponential(0.4))

    # Sample a cultural competence for each user
    coin_flip = numpyro.sample('coin_flip', dist.Uniform(0, 1))    
    if coin_flip < 0.50:
        with numpyro.plate('compet', K):
            competence = numpyro.sample("competence", dist.Gamma(a, b))
    else:
        competence1 = numpyro.sample("competence1", dist.Gamma(a, b))
        
    # Sample a consensus for each item
    with numpyro.plate('consens', N):
        consensus = numpyro.sample("consensus", dist.Beta(1,1))
    
    # Sample a learning rate for each user
    c = numpyro.sample("learningMean", dist.Exponential(1))
    d = numpyro.sample("learningVariance", dist.Exponential(0.001))

    with numpyro.plate('learningR', K):
        learningRate = numpyro.sample("learningRate", dist.Gamma(c, d))
    
    asymptotic_c = numpyro.sample("asymptotic_c", dist.Gamma(a, b))
    
    with numpyro.plate('biasTerm', K):
        bias = numpyro.sample("bias", dist.Normal(0, 0.01))

   
    with numpyro.plate('data_loop', R) as i:
        mu = consensus[DATA[i, 1].astype(int)] + bias[DATA[i, 0].astype(int)]
        if coin_flip < 0.50:
            precision = competence[DATA[i, 0].astype(int)] + (asymptotic_c - competence[DATA[i, 0].astype(int)]) * (1 - jnp.exp(-learningRate[DATA[i, 0].astype(int)]*DATA[i, 2].astype(int)))
        else:
            precision = competence1 + (asymptotic_c - competence1) * (1 - jnp.exp(-learningRate[DATA[i, 0].astype(int)]*DATA[i, 2].astype(int)))
        rating = numpyro.sample("rating", dist.Normal(mu, 1/precision), obs=DATA[i, 3].astype(float)/10)

I am trying to flip a coin and if it is head, use competence variable for every user. If it is tail, use individual-level competence variable. The reason I am doing this is I would like to see whether we need individual level parameterization. So when a coin lands more tail, it means that individual level parameterization is important. After reading several posts, jax has a different style of implementing if statement. I could not figure it out. I would appreciate your help./

Hi @ngurkan, in JAX this assertion will raise errors because coin_flip will have invalid values. I guess you can write your model as

    with numpyro.plate('compet', K):
        competence = numpyro.sample("competence", dist.Gamma(a, b))
    competence1 = numpyro.sample("competence1", dist.Gamma(a, b))
    competence = jnp.where(coin_flip < 0.5, competence, competence1)

Due to if/else/jnp.where, the potential function will not be continuous and HMC/NUTS algorithms won’t like it. Hopefully, it will work for your problem. Otherwise, rather than drawing a uniform value for coin_flip, you can use a Bernoulli(0.5) prior and perform DiscreteHMCGibbs or enumerated HMC/NUTS for that model with discrete latent variable coin_flip.

1 Like

Thank you very much for a quick return!