How to place bounds on a parameter when the bounds depend on other variables

Hi all, I’m very new to NumPyro and wanted to learn it by implementing a simple football model (based on http://web.math.ku.dk/~rolf/teaching/thesis/DixonColes.pdf).

My model code is:

```python
def _model(
    home_team: jnp.array,
    away_team: jnp.array,
    num_teams: int,
    home_goals: Iterable[int],
    away_goals: Iterable[int]
):
    home_advantage = numpyro.sample(
        "home_advantage", dist.Normal(0.0, 0.2)
    )
    with numpyro.plate("teams", num_teams):
        attack = numpyro.sample(
            "attack", dist.Normal(0.0, 0.2)
        )
        defence = numpyro.sample(
            "defence", dist.Normal(0.0, 0.2)
        )

    expected_home_goals = jnp.exp(home_advantage + attack[home_team] - defence[away_team])
    expected_away_goals = jnp.exp(attack[away_team] - defence[home_team])

    numpyro.sample(
        "home_goals", dist.Poisson(expected_home_goals).to_event(1), obs=home_goals
    )
    numpyro.sample(
        "away_goals", dist.Poisson(expected_away_goals).to_event(1), obs=away_goals
    )

    # impose bounds on the correlation coefficient
    corr_coef_raw = numpyro.sample("corr_coef_raw", dist.Uniform(low=0.0, high=1.0))
    upper_bound = min([min(1.0/(expected_home_goals*expected_away_goals)), 1])
    lower_bound = max([max(-1.0/expected_home_goals), max(-1.0/expected_away_goals)])
    ul_range = upper_bound - lower_bound
    corr_coef = lower_bound + corr_coef_raw * ul_range
    corr_term = dixon_coles_correlation_term(
        home_goals, away_goals, expected_home_goals, expected_away_goals, corr_coef
    )
    numpyro.factor("correlation_term", corr_term.sum(axis=-1))
```

In the bottom part of the code, I’m trying to compute the upper and lower bounds on the correlation parameter (the constraint is given on page 270 of the paper above):

$$\max\left(-\frac{1}{\lambda}, -\frac{1}{\mu}\right) \le \rho \le \min\left(\frac{1}{\lambda\mu}, 1\right)$$

I want to put a uniform prior over this parameter ρ, so I’ve defined a uniform for `corr_coef_raw` and then transformed it into `corr_coef`, which satisfies these bounds. In my code, `expected_home_goals` is λ, `expected_away_goals` is μ, and `corr_coef` is ρ.
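For reference, here is a sketch of the transformation the code is intended to implement, writing λᵢ and μᵢ for the expected goals in match i (this notation is introduced here, not taken from the paper). Because the constraint must hold for every match at once, the code takes the tightest bound over all matches:

$$
\begin{aligned}
a &= \max_{i}\ \max\left(-\frac{1}{\lambda_i}, -\frac{1}{\mu_i}\right), \qquad
b = \min\left(\min_{i} \frac{1}{\lambda_i \mu_i},\ 1\right), \\
u &\sim \mathrm{Uniform}(0, 1), \qquad \rho = a + u\,(b - a).
\end{aligned}
$$

Conditional on λ and μ, this gives ρ ~ Uniform(a, b). Since every λᵢ and μᵢ is positive, a < 0 < b, so the interval is never empty.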

The problem is that running this code raises the following `ConcretizationTypeError`; JAX doesn’t seem to like the way I compute the bounds. Any help with this would be much appreciated!

```
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/core.py:1189, in concretization_function_error.<locals>.error(self, arg)
   1188 def error(self, arg):
-> 1189   raise ConcretizationTypeError(arg, fname_context)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=2/1)>
The problem arose with the `bool` function. 
The error occurred while tracing the function _body_fn at /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:1001 for while_loop. This concrete value was not available in Python because it depends on the values of the argument 'state'.

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError
```
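A minimal sketch that reproduces this error without NumPyro, assuming `jax.jit` traces a function the same way NUTS traces the model (the function names `python_max` and `jax_max` are hypothetical). Python's built-in `max` iterates over the array, compares elements with `>`, and converts each result to a Python `bool`, which is impossible for a traced value; `jnp.max` performs the reduction inside JAX instead:

```python
import jax
import jax.numpy as jnp


def python_max(x):
    # Built-in max iterates over x and calls bool() on traced comparisons
    return max(-1.0 / x)


def jax_max(x):
    # jnp.max is a traceable reduction, so no concrete value is needed
    return jnp.max(-1.0 / x)


x = jnp.array([1.2, 0.8, 2.0])

try:
    jax.jit(python_max)(x)
    failed = False
except jax.errors.ConcretizationTypeError:
    failed = True

print(failed)                        # True
print(float(jax.jit(jax_max)(x)))    # -0.5
```

Called outside `jit` on concrete arrays, both versions return the same value; the difference only shows up once the inputs become tracers.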

As an aside, I drew inspiration from this Stan implementation (soccerstan/dixon-coles.stan at master · Torvaney/soccerstan · GitHub) and am trying to translate it into NumPyro code.

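Rather than using Python’s built-in `min` and `max`, you need to use `jnp.min` and `jnp.max` (and `jnp.minimum`/`jnp.maximum` for the outer comparison of two values). The built-ins compare values with `<`/`>` and convert the result to a Python `bool`, which isn’t possible while NUTS is tracing the model, because `expected_home_goals` and `expected_away_goals` are abstract tracers rather than concrete numbers.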

Thanks very much for the quick response - looks like this works :smile: