Hi all, I’m very new to using numpyro and wanted to learn it by implementing a simple football model (based on http://web.math.ku.dk/~rolf/teaching/thesis/DixonColes.pdf).
My model code is:
```python
from typing import Iterable

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def _model(
    home_team: jnp.ndarray,
    away_team: jnp.ndarray,
    num_teams: int,
    home_goals: Iterable[int],
    away_goals: Iterable[int],
):
    home_advantage = numpyro.sample("home_advantage", dist.Normal(0.0, 0.2))

    # one attack and one defence rating per team
    with numpyro.plate("teams", num_teams):
        attack = numpyro.sample("attack", dist.Normal(0.0, 0.2))
        defence = numpyro.sample("defence", dist.Normal(0.0, 0.2))

    # Poisson rates per match: lambda (home) and mu (away)
    expected_home_goals = jnp.exp(home_advantage + attack[home_team] - defence[away_team])
    expected_away_goals = jnp.exp(attack[away_team] - defence[home_team])

    numpyro.sample(
        "home_goals", dist.Poisson(expected_home_goals).to_event(1), obs=home_goals
    )
    numpyro.sample(
        "away_goals", dist.Poisson(expected_away_goals).to_event(1), obs=away_goals
    )

    # impose bounds on the correlation coefficient
    corr_coef_raw = numpyro.sample("corr_coef_raw", dist.Uniform(low=0.0, high=1.0))
    upper_bound = min([min(1.0/(expected_home_goals*expected_away_goals)), 1])
    lower_bound = max([max(-1.0/expected_home_goals), max(-1.0/expected_away_goals)])
    ul_range = upper_bound - lower_bound
    corr_coef = lower_bound + corr_coef_raw * ul_range

    corr_term = dixon_coles_correlation_term(
        home_goals, away_goals, expected_home_goals, expected_away_goals, corr_coef
    )
    numpyro.factor("correlation_term", corr_term.sum(axis=-1))
```
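The post doesn't show `dixon_coles_correlation_term`. For context, here is a minimal sketch of what such a helper could compute, assuming it returns the log of the Dixon-Coles adjustment τ(x, y) for each match, which the `numpyro.factor` call then adds to the log density. τ differs from 1 only for the 0-0, 0-1, 1-0 and 1-1 scorelines, and the bounds on ρ are what keep every τ positive so the log stays finite. The body is a hypothetical illustration, not the author's code.

```python
import jax.numpy as jnp


def dixon_coles_correlation_term(home_goals, away_goals, lam, mu, rho):
    """Hypothetical sketch: log tau(x, y) for each match.

    lam, mu: expected home/away goals, one entry per match.
    rho: scalar dependence parameter.
    """
    x = jnp.asarray(home_goals)
    y = jnp.asarray(away_goals)
    lam = jnp.asarray(lam)
    mu = jnp.asarray(mu)

    # tau is 1 for every scoreline except the four low-scoring ones
    tau = jnp.ones_like(lam)
    tau = jnp.where((x == 0) & (y == 0), 1.0 - lam * mu * rho, tau)
    tau = jnp.where((x == 0) & (y == 1), 1.0 + lam * rho, tau)
    tau = jnp.where((x == 1) & (y == 0), 1.0 + mu * rho, tau)
    tau = jnp.where((x == 1) & (y == 1), 1.0 - rho, tau)
    return jnp.log(tau)
```

Summing the returned array over matches gives the scalar passed to `numpyro.factor`, as in `corr_term.sum(axis=-1)` in the model.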
In the bottom part of the code, I'm trying to compute the upper and lower bounds on the correlation parameter ρ, as defined on page 270 of the paper above:

$$
\max\left(-\frac{1}{\lambda}, -\frac{1}{\mu}\right) \le \rho \le \min\left(\frac{1}{\lambda\mu}, 1\right)
$$

I want to put a uniform prior over ρ, so I've defined a uniform `corr_coef_raw`, and `corr_coef` is the transformed value that satisfies this bound. In my code, `expected_home_goals` refers to λ, `expected_away_goals` refers to μ, and `corr_coef` is meant to refer to ρ.
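As a sketch of the same transform written only with `jax.numpy` operations (the name `bounded_rho` is hypothetical): `jnp.maximum`/`jnp.minimum` take the elementwise bound for each match, and `jnp.max`/`jnp.min` then reduce over all matches so that a single ρ is admissible for every game. Because these are JAX operations rather than Python's built-in `min`/`max`, the function can be traced, e.g. under `jax.jit`.

```python
import jax
import jax.numpy as jnp


def bounded_rho(corr_coef_raw, lam, mu):
    """Map a Uniform(0, 1) draw onto the admissible interval for rho.

    lam, mu: 1-D arrays of expected home/away goals, one entry per match.
    """
    # lower bound: max over matches of max(-1/lam, -1/mu)
    lower = jnp.max(jnp.maximum(-1.0 / lam, -1.0 / mu))
    # upper bound: min over matches of min(1/(lam*mu), 1)
    upper = jnp.min(jnp.minimum(1.0 / (lam * mu), 1.0))
    return lower + corr_coef_raw * (upper - lower)


lam = jnp.array([1.5, 2.0])
mu = jnp.array([1.0, 0.5])
# bounds are [-0.5, 2/3], so a raw value of 0.5 maps to -0.5 + 0.5 * 7/6 = 1/12
rho = jax.jit(bounded_rho)(0.5, lam, mu)
```

Inside the model this would take the place of the `upper_bound`/`lower_bound` lines, e.g. `corr_coef = bounded_rho(corr_coef_raw, expected_home_goals, expected_away_goals)`, assuming both rates are 1-D arrays with one entry per match.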
The problem is that running this raises the following `ConcretizationTypeError`; it seems JAX doesn't like this code. Any help with this would be much appreciated!
```
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/core.py:1189, in concretization_function_error.<locals>.error(self, arg)
   1188 def error(self, arg):
-> 1189   raise ConcretizationTypeError(arg, fname_context)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=2/1)>
The problem arose with the `bool` function.
The error occurred while tracing the function _body_fn at /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:1001 for while_loop. This concrete value was not available in Python because it depends on the values of the argument 'state'.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```
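The message points at `bool`, and the traceback shows the model being traced inside NumPyro's `while_loop`. A minimal, NumPyro-free sketch of that situation (the function names are made up for illustration), assuming the built-in `min`/`max` calls in `_model` are what trigger it: Python's `min` iterates over the array and calls `bool()` on each traced comparison, which raises a `ConcretizationTypeError`, while `jnp.min` is a JAX reduction and traces fine.

```python
import jax
import jax.numpy as jnp


def builtin_min(x):
    # Python's min() compares elements and calls bool() on each result;
    # under tracing each comparison is an abstract tracer, not a Python bool
    return min(x)


def traced_min(x):
    # jnp.min records a reduction in the trace instead of branching in Python
    return jnp.min(x)


x = jnp.array([3.0, 1.0, 2.0])

try:
    jax.jit(builtin_min)(x)
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True

print(raised)  # True
print(jax.jit(traced_min)(x))
```

Recent JAX versions raise `TracerBoolConversionError` here, which is a subclass of `ConcretizationTypeError`, so the `except` clause catches it either way.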
As an aside: I drew inspiration from this Stan implementation (soccerstan/dixon-coles.stan at master · Torvaney/soccerstan · GitHub) and am trying to translate it into NumPyro code.