Hi all, I’m very new to using numpyro and wanted to learn it by implementing a simple football model (based on http://web.math.ku.dk/~rolf/teaching/thesis/DixonColes.pdf).
My model code is:
    def _model(
        home_team: jnp.array,
        away_team: jnp.array,
        num_teams: int,
        home_goals: Iterable[int],
        away_goals: Iterable[int],
    ):
        home_advantage = numpyro.sample(
            "home_advantage", dist.Normal(0.0, 0.2)
        )

        with numpyro.plate("teams", num_teams):
            attack = numpyro.sample(
                "attack", dist.Normal(0.0, 0.2)
            )
            defence = numpyro.sample(
                "defence", dist.Normal(0.0, 0.2)
            )

        expected_home_goals = jnp.exp(home_advantage + attack[home_team] - defence[away_team])
        expected_away_goals = jnp.exp(attack[away_team] - defence[home_team])

        numpyro.sample(
            "home_goals", dist.Poisson(expected_home_goals).to_event(1), obs=home_goals
        )
        numpyro.sample(
            "away_goals", dist.Poisson(expected_away_goals).to_event(1), obs=away_goals
        )

        # impose bounds on the correlation coefficient
        corr_coef_raw = numpyro.sample("corr_coef_raw", dist.Uniform(low=0.0, high=1.0))
        upper_bound = min([min(1.0/(expected_home_goals*expected_away_goals)), 1])
        lower_bound = max([max(-1.0/expected_home_goals), max(-1.0/expected_away_goals)])
        ul_range = upper_bound - lower_bound
        corr_coef = lower_bound + corr_coef_raw * ul_range

        corr_term = dixon_coles_correlation_term(
            home_goals, away_goals, expected_home_goals, expected_away_goals, corr_coef
        )
        numpyro.factor("correlation_term", corr_term.sum(axis=-1))
In the bottom part of the code, I'm trying to compute the upper and lower bounds on the correlation coefficient, as defined on page 270 of the paper above: max(-1/lambda, -1/mu) <= rho <= min(1/(lambda*mu), 1). I want to put a uniform prior over this parameter rho, so I've defined a uniform for `corr_coef_raw`, and `corr_coef` is the transformed value that satisfies this bound. In my code, `expected_home_goals` refers to lambda, `expected_away_goals` refers to mu, and `corr_coef` is meant to refer to rho.
The problem is that JAX doesn't seem to like this code: running it raises the `ConcretizationTypeError` below. Any help with this would be much appreciated!
    File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/core.py:1189, in concretization_function_error.<locals>.error(self, arg)
       1188 def error(self, arg):
    -> 1189     raise ConcretizationTypeError(arg, fname_context)

    ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool)>with<DynamicJaxprTrace(level=2/1)>
    The problem arose with the `bool` function.
    The error occurred while tracing the function _body_fn at /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:1001 for while_loop. This concrete value was not available in Python because it depends on the values of the argument 'state'.

    See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
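A possible explanation, offered as a sketch rather than a confirmed fix: Python's builtin `min` and `max` compare elements one at a time and call `bool` on each comparison, which fails once the expected goals are traced arrays. The bound described above can instead be computed with `jnp.min`, `jnp.max`, `jnp.minimum` and `jnp.maximum`, which work on traced values. The helper name `corr_coef_from_raw` and the example scoring rates are made up for illustration and are not part of the model above.

```python
import jax
import jax.numpy as jnp


def corr_coef_from_raw(corr_coef_raw, expected_home_goals, expected_away_goals):
    """Map a value in [0, 1] onto the admissible range for rho (hypothetical helper)."""
    # jnp.min / jnp.max reduce over all matches and stay traceable,
    # unlike the builtin min / max, which call bool() on traced comparisons.
    upper_bound = jnp.minimum(
        jnp.min(1.0 / (expected_home_goals * expected_away_goals)), 1.0
    )
    lower_bound = jnp.maximum(
        jnp.max(-1.0 / expected_home_goals), jnp.max(-1.0 / expected_away_goals)
    )
    return lower_bound + corr_coef_raw * (upper_bound - lower_bound)


# Made-up scoring rates for three matches, for illustration only.
lam = jnp.array([1.5, 2.0, 0.8])  # plays the role of expected_home_goals
mu = jnp.array([1.0, 0.5, 1.2])   # plays the role of expected_away_goals

# Running under jax.jit shows that the function traces without a concretization error.
rho = jax.jit(corr_coef_from_raw)(0.5, lam, mu)
print(rho)
```

Inside the model, the same expressions would replace the `upper_bound` and `lower_bound` lines, with `corr_coef_raw` still drawn from the uniform prior.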
As an aside: I drew inspiration from this Stan implementation (soccerstan/dixon-coles.stan at master · Torvaney/soccerstan · GitHub) and am trying to translate it into numpyro code.