```
from jax._src.api import T
import matplotlib.pyplot as plt
import jax.numpy as jnp
from jax import random, grad, lax
import jax
from numpy.core.numeric import Inf
import numpyro
import numpyro.distributions as dist
import math
import numpyro.optim as optim
from numpyro.infer import MCMC, NUTS, Predictive
import numpy as np
import torch
import pandas as pd
import datetime
import scipy.stats as st
rng_key = random.PRNGKey(101)
rng_key, rng_key_ = random.split(rng_key)
begin_time = datetime.datetime.now()
# DATA = np.array([
# # user, item, order, rating
# ])
DATA = torch.load("full_data_fixed.pt")
DATA = DATA.numpy()
DATA = jax.numpy.asarray(DATA)
K = len(list(set([int(row[0]) for row in DATA])))
N = len(list(set([int(row[1]) for row in DATA])))
R = len(DATA)
def model():
a = numpyro.sample("competenceMean", dist.Exponential(0.4))
b = numpyro.sample("competenceVariance", dist.Exponential(0.4))
# Sample a cultural competence for each user
coin_flip = numpyro.sample('coin_flip', dist.Uniform(0, 1))
if coin_flip < 0.50:
with numpyro.plate('compet', K):
competence = numpyro.sample("competence", dist.Gamma(a, b))
else:
competence1 = numpyro.sample("competence1", dist.Gamma(a, b))
# Sample a consensus for each item
with numpyro.plate('consens', N):
consensus = numpyro.sample("consensus", dist.Beta(1,1))
# Sample a learning rate for each user
c = numpyro.sample("learningMean", dist.Exponential(1))
d = numpyro.sample("learningVariance", dist.Exponential(0.001))
with numpyro.plate('learningR', K):
learningRate = numpyro.sample("learningRate", dist.Gamma(c, d))
asymptotic_c = numpyro.sample("asymptotic_c", dist.Gamma(a, b))
with numpyro.plate('biasTerm', K):
bias = numpyro.sample("bias", dist.Normal(0, 0.01))
with numpyro.plate('data_loop', R) as i:
mu = consensus[DATA[i, 1].astype(int)] + bias[DATA[i, 0].astype(int)]
if coin_flip < 0.50:
precision = competence[DATA[i, 0].astype(int)] + (asymptotic_c - competence[DATA[i, 0].astype(int)]) * (1 - jnp.exp(-learningRate[DATA[i, 0].astype(int)]*DATA[i, 2].astype(int)))
else:
precision = competence1 + (asymptotic_c - competence1) * (1 - jnp.exp(-learningRate[DATA[i, 0].astype(int)]*DATA[i, 2].astype(int)))
rating = numpyro.sample("rating", dist.Normal(mu, 1/precision), obs=DATA[i, 3].astype(float)/10)
```

I am trying to flip a coin and if it is head, use competence variable for every user. If it is tail, use individual-level competence variable. The reason I am doing this is I would like to see whether we need individual level parameterization. So when a coin lands more tail, it means that individual level parameterization is important. After reading several posts, jax has a different style of implementing if statement. I could not figure it out. I would appreciate your help./