# How to place bounds on a parameter which depend on other variables

Hi all, I’m very new to NumPyro and wanted to learn it by implementing a simple football model (based on http://web.math.ku.dk/~rolf/teaching/thesis/DixonColes.pdf).

My model code is:

```python
from typing import Iterable

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def _model(
    home_team: jnp.ndarray,
    away_team: jnp.ndarray,
    num_teams: int,
    home_goals: Iterable[int],
    away_goals: Iterable[int],
):
    # NOTE: `home_advantage` and `dixon_coles_correlation_term` are not shown in the post.
    with numpyro.plate("teams", num_teams):
        attack = numpyro.sample("attack", dist.Normal(0.0, 0.2))
        defence = numpyro.sample("defence", dist.Normal(0.0, 0.2))

    # Poisson rates per match: lambda (home) and mu (away)
    expected_home_goals = jnp.exp(home_advantage + attack[home_team] - defence[away_team])
    expected_away_goals = jnp.exp(attack[away_team] - defence[home_team])

    numpyro.sample(
        "home_goals", dist.Poisson(expected_home_goals).to_event(1), obs=home_goals
    )
    numpyro.sample(
        "away_goals", dist.Poisson(expected_away_goals).to_event(1), obs=away_goals
    )

    # impose bounds on the correlation coefficient
    corr_coef_raw = numpyro.sample("corr_coef_raw", dist.Uniform(low=0.0, high=1.0))
    upper_bound = min([min(1.0 / (expected_home_goals * expected_away_goals)), 1])
    lower_bound = max([max(-1.0 / expected_home_goals), max(-1.0 / expected_away_goals)])
    ul_range = upper_bound - lower_bound
    corr_coef = lower_bound + corr_coef_raw * ul_range
    corr_term = dixon_coles_correlation_term(
        home_goals, away_goals, expected_home_goals, expected_away_goals, corr_coef
    )
    numpyro.factor("correlation_term", corr_term.sum(axis=-1))
```
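
The post doesn’t show `dixon_coles_correlation_term`. For reference, here is a minimal sketch of what such a helper could look like, assuming it returns the per-match log of the Dixon-Coles adjustment factor τ (the name and argument order follow the call above; the body is a hypothetical reconstruction, not the poster’s code):

```python
import jax.numpy as jnp


def dixon_coles_correlation_term(home_goals, away_goals, lam, mu, rho):
    """Hypothetical helper: per-match log of the Dixon-Coles adjustment tau.

    tau(x, y) = 1 - lam*mu*rho  if x == 0 and y == 0
                1 + lam*rho     if x == 0 and y == 1
                1 + mu*rho      if x == 1 and y == 0
                1 - rho         if x == 1 and y == 1
                1               otherwise
    """
    x = jnp.asarray(home_goals)
    y = jnp.asarray(away_goals)
    # Start from tau = 1 and overwrite the four low-score cells.
    tau = jnp.ones_like(lam)
    tau = jnp.where((x == 0) & (y == 0), 1.0 - lam * mu * rho, tau)
    tau = jnp.where((x == 0) & (y == 1), 1.0 + lam * rho, tau)
    tau = jnp.where((x == 1) & (y == 0), 1.0 + mu * rho, tau)
    tau = jnp.where((x == 1) & (y == 1), 1.0 - rho, tau)
    # numpyro.factor adds this to the log joint, so return log(tau).
    return jnp.log(tau)


lam = jnp.array([1.5, 1.2, 2.0, 0.8, 1.0])
mu = jnp.array([1.0, 0.9, 1.1, 1.3, 0.7])
home_goals = jnp.array([0, 0, 1, 1, 2])
away_goals = jnp.array([0, 1, 0, 1, 3])
print(dixon_coles_correlation_term(home_goals, away_goals, lam, mu, 0.1))
```

The bounds on rho in the paper are exactly the conditions that keep every τ non-negative; for example, 1 - λμρ ≥ 0 gives ρ ≤ 1/(λμ), and 1 + λρ ≥ 0 gives ρ ≥ -1/λ.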

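In the bottom part of the code, I’m trying to compute the upper and lower bounds on rho, which are defined on page 270 of the paper above:

$$
\max\left(-\frac{1}{\lambda}, -\frac{1}{\mu}\right) \le \rho \le \min\left(\frac{1}{\lambda\mu}, 1\right)
$$
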
I want to put a uniform prior over this parameter rho, so I define a uniform `corr_coef_raw` on [0, 1] and transform it into `corr_coef`, which satisfies the bound. In my code, `expected_home_goals` corresponds to lambda, `expected_away_goals` to mu, and `corr_coef` to rho.
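
Because the constraint has to hold for every match at once, the usable interval is the largest per-match lower bound and the smallest per-match upper bound. A minimal sketch of that computation in pure `jax.numpy`, so it stays traceable under `jax.jit` (the names `rho_bounds` and `scale_to_bounds` are hypothetical):

```python
import jax
import jax.numpy as jnp


def rho_bounds(lam, mu):
    """Hypothetical helper: joint bounds on rho over all matches.

    Per match: max(-1/lam, -1/mu) <= rho <= min(1/(lam*mu), 1).
    Across matches we keep the tightest lower and upper bound.
    """
    lower = jnp.max(jnp.maximum(-1.0 / lam, -1.0 / mu))
    upper = jnp.minimum(jnp.min(1.0 / (lam * mu)), 1.0)
    return lower, upper


@jax.jit
def scale_to_bounds(raw, lam, mu):
    """Map a raw value in [0, 1] linearly onto [lower, upper]."""
    lower, upper = rho_bounds(lam, mu)
    return lower + raw * (upper - lower)


lam = jnp.array([1.5, 2.0])
mu = jnp.array([1.0, 0.5])
print(rho_bounds(lam, mu))
print(scale_to_bounds(0.5, lam, mu))
```

Here the second match gives the tightest lower bound (-1/2.0) and the first match the tightest upper bound (1/(1.5 * 1.0)).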

The problem is that JAX doesn’t accept this code and raises the following `ConcretizationTypeError`. Any help with this would be much appreciated!

```text
File /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/jax/core.py:1189, in concretization_function_error.<locals>.error(self, arg)
   1188 def error(self, arg):
-> 1189   raise ConcretizationTypeError(arg, fname_context)

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=2/1)>
The problem arose with the `bool` function.
The error occurred while tracing the function _body_fn at /Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/site-packages/numpyro/infer/hmc_util.py:1001 for while_loop. This concrete value was not available in Python because it depends on the values of the argument 'state'.
```
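
The `bool` in this message comes from Python’s built-in `min`: to pick the smallest element it compares elements with `<` and asks Python whether each result is true, which an abstract tracer cannot answer while the model is being traced. A minimal reproduction with plain JAX (the function names `bounds_builtin` and `bounds_jnp` are hypothetical); depending on the JAX version, the class actually raised may be `TracerBoolConversionError`, a subclass of `ConcretizationTypeError`:

```python
import jax
import jax.numpy as jnp


def bounds_builtin(lam):
    # Python's min() compares elements with `<` and calls bool() on the
    # result, which is impossible on an abstract tracer inside jit/HMC.
    return min(1.0 / lam)


def bounds_jnp(lam):
    # jnp.min reduces the traced array without Python-level branching.
    return jnp.min(1.0 / lam)


lam = jnp.array([1.5, 2.0, 0.8])

try:
    jax.jit(bounds_builtin)(lam)
except jax.errors.ConcretizationTypeError as err:
    print("built-in min failed:", type(err).__name__)

print("jnp.min works:", jax.jit(bounds_jnp)(lam))
```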

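Rather than Python’s built-in `min` and `max`, you need to use `jnp.min` and `jnp.max` (and `jnp.minimum`/`jnp.maximum` where you compare against the constant `1`), so that the bounds are computed on the traced arrays instead of by Python-level comparisons.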