# Provenance and JAX Control Flow

I believe I’m having problems similar to this post, @fehiepsi

I have an iterative optimization procedure that acts as my functional model, mapping my input data `x` to my output data `y`. It’s basically just a for-loop around a few matmuls, dots, etc., parametrized by `alpha` and `beta`. The inside of the loop is its own function `step()`, and to avoid the slow Python for-loop I use `jax.lax.fori_loop` to speed it up (I need ~1e5 iterations).

``````
import jax.numpy as jnp

def step(x: jnp.ndarray, var_1: jnp.ndarray, var_2: jnp.ndarray, alpha: float, beta: float):
    # ... math operations like matmuls, dots etc. ...
    c = jnp.true_divide(alpha, beta)  # problem seems to stem from here
    # ... more math operations ...
    return new_var_1, new_var_2
``````
``````
from functools import partial

import jax
import jax.numpy as jnp
from jax import jit

@jit
def functional_model(x: jnp.ndarray, alpha: float, beta: float, num_steps: int, var_1_init: jnp.ndarray, var_2_init: jnp.ndarray):
    # simply an optimized loop over step()
    f = partial(step, x=x, alpha=alpha, beta=beta)

    def body_fun(i, carry):
        # adapts step() to the (index, carry) signature that fori_loop expects
        var_1, var_2 = carry
        var_1, var_2 = f(var_1=var_1, var_2=var_2)
        return var_1, var_2

    # var_1 is our y_hat
    var_1, var_2 = jax.lax.fori_loop(0, num_steps, body_fun, (var_1_init, var_2_init))
    return var_1, var_2
``````
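For anyone who wants to reproduce the loop structure without the full model, here is a minimal, self-contained sketch. The update rule inside `step` is a made-up placeholder (the real math is elided above), and the toy shapes and values are assumptions too; only the `partial` + `body_fun` + `jax.lax.fori_loop` wiring mirrors the post.

```python
from functools import partial

import jax
import jax.numpy as jnp


def step(x, var_1, var_2, alpha, beta):
    # Hypothetical update rule, standing in for the elided math in the post.
    c = jnp.true_divide(alpha, beta)
    new_var_1 = var_1 + c * (x @ var_2 - var_1)
    new_var_2 = 0.5 * (var_2 + x.T @ new_var_1)
    return new_var_1, new_var_2


@jax.jit
def functional_model(x, alpha, beta, num_steps, var_1_init, var_2_init):
    # Bind the loop-invariant arguments once.
    f = partial(step, x=x, alpha=alpha, beta=beta)

    def body_fun(i, carry):
        # fori_loop passes (index, carry) and expects the new carry back.
        var_1, var_2 = carry
        return f(var_1=var_1, var_2=var_2)

    return jax.lax.fori_loop(0, num_steps, body_fun, (var_1_init, var_2_init))


# Toy inputs (assumed shapes: x is (3, 2), var_1 is (3,), var_2 is (2,)).
x = jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
var_1_init = jnp.zeros(3)
var_2_init = jnp.ones(2)

y_hat, _ = functional_model(x, 1.0, 2.0, 10, var_1_init, var_2_init)
```

Outside of NumPyro's rendering machinery this runs as expected, which matches the observation that the functions themselves are fine.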

These functions work when given valid inputs. But when I define my model as follows and try to render it, I get an error.

``````
import numpyro
import numpyro.distributions as dist

def model(x, y, var_1_init, var_2_init):
    num_samples, num_edges = x.shape

    a = numpyro.sample('a', dist.Uniform(0, 1e8))
    b = numpyro.sample('b', dist.Uniform(0, 1e8))
    steepness = numpyro.sample('steepness', dist.Normal(1, 10))
    midpoint = numpyro.sample('midpoint', dist.Normal(0, 10))

    y_hat, _ = functional_model(x=x,
                                var_1_init=var_1_init,
                                var_2_init=var_2_init,
                                alpha=a, beta=b,
                                num_steps=int(1e5))

    logits = -1 * steepness * (y_hat - midpoint)
    with numpyro.plate('num_samples', num_samples, dim=-2):
        with numpyro.plate('num_edges', num_edges, dim=-1):
            return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=y)

# ... load data x and y, init var_1 and var_2, etc.
graph = numpyro.render_model(model,
                             model_args=(x, y, var_1_init, var_2_init),
                             render_distributions=True,
                             render_params=True)
graph.view()
``````

Running this code gives this error:

``````
The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/contextlib.py", line 137, in __exit__
    self.gen.throw(typ, value, traceback)
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/jax/core.py", line 991, in new_main
    yield main
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/numpyro/ops/provenance.py", line 109, in eval_provenance
    out = partial_eval.trace_to_subjaxpr_dynamic(fun, main, avals)[1]
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1998, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/numpyro/infer/inspect.py", line 377, in get_log_probs
    model(*model_args, **model_kwargs)
  File "/Users/maxw/file.py", line 264, in model
    y_hat, var_2 = functional_model(x=x,
  File "/Users/maxw/file.py", line 179, in functional_model
    var_1, var_2 = jax.lax.fori_loop(0, steps, body_fun, (var_1_init, var_2_init))
  File "/Users/maxw/file.py", line 176, in body_fun
    var_1, var_2 = f(var_1=var_1, var_2=var_2)
  File "/Users/maxw/file.py", line 156, in functional_model
    c = jnp.true_divide(alpha, beta)
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/jax/_src/numpy/ufuncs.py", line 258, in true_divide
    return lax.div(x1, x2)
TypeError: Axis name _provenance used with inconsistent sizes: frozenset({'a'}) != frozenset({'b'})
``````

Any idea on how to fix this?

Assuming your problem is with model rendering and not something else (can you still do inference?), I would suggest the following.

Model rendering doesn’t peer into the details of deterministic functions like `functional_model` in any case, so I guess you should just use a dummy model with

`logits = -1 * steepness * (0 - midpoint)`

for model rendering. For example, you could define the following:

``````
def model(x, y, var_1_init, var_2_init, dummy=False):
    num_samples, num_edges = x.shape

    a = numpyro.sample('a', dist.Uniform(0, 1e8))
    b = numpyro.sample('b', dist.Uniform(0, 1e8))
    steepness = numpyro.sample('steepness', dist.Normal(1, 10))
    midpoint = numpyro.sample('midpoint', dist.Normal(0, 10))

    if not dummy:
        y_hat, _ = functional_model(x=x,
                                    var_1_init=var_1_init,
                                    var_2_init=var_2_init,
                                    alpha=a, beta=b,
                                    num_steps=int(1e5))

        logits = -1 * steepness * (y_hat - midpoint)
    else:
        # y_hat is replaced by 0, so -1 * steepness * (0 - midpoint)
        # simplifies to steepness * midpoint; used for rendering only
        logits = steepness * midpoint

    with numpyro.plate('num_samples', num_samples, dim=-2):
        with numpyro.plate('num_edges', num_edges, dim=-1):
            return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=y)
``````

Ah, thanks for the tip. The dummy functional model works.

Inference runs, but not with this exact code: `jax.lax.fori_loop` does not seem to support reverse-mode differentiation, so I must use a raw Python loop, which is currently very slow. I am trying to port to `jax.lax.scan`, but that’s been stumping me (any tips on that front would also be appreciated, but I recognize this is off topic :} ).
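On the `jax.lax.scan` question, here is a rough sketch of how the port could look, again with a made-up `step` standing in for the real one. `functional_model_scan` and `loss` are hypothetical names, and the toy inputs are assumptions. The main differences from `fori_loop` are that the body takes `(carry, x_t)` and returns `(new_carry, per_step_output)`, and that `length` has to be a concrete Python integer.

```python
import jax
import jax.numpy as jnp


def step(x, var_1, var_2, alpha, beta):
    # Hypothetical update rule; the real `step` from the post is elided.
    c = jnp.true_divide(alpha, beta)
    new_var_1 = var_1 + c * (x @ var_2 - var_1)
    new_var_2 = 0.5 * (var_2 + x.T @ new_var_1)
    return new_var_1, new_var_2


def functional_model_scan(x, alpha, beta, num_steps, var_1_init, var_2_init):
    def body_fun(carry, _):
        var_1, var_2 = carry
        # scan expects (new_carry, per_step_output); nothing is stacked here.
        return step(x, var_1, var_2, alpha, beta), None

    # `length` must be a concrete Python int, so num_steps cannot be traced.
    (var_1, var_2), _ = jax.lax.scan(
        body_fun, (var_1_init, var_2_init), xs=None, length=num_steps
    )
    return var_1, var_2


def loss(alpha, beta, x, var_1_init, var_2_init):
    # Toy scalar objective so that reverse-mode jax.grad can be applied.
    y_hat, _ = functional_model_scan(x, alpha, beta, 10, var_1_init, var_2_init)
    return jnp.sum(y_hat**2)


x = jnp.array([[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]])
var_1_init = jnp.zeros(3)
var_2_init = jnp.ones(2)

# Reverse-mode gradients with respect to alpha and beta.
grad_alpha, grad_beta = jax.grad(loss, argnums=(0, 1))(
    1.0, 2.0, x, var_1_init, var_2_init
)
```

The JAX documentation also states that `fori_loop` itself supports reverse-mode differentiation when its trip count is static, so marking `num_steps` as static (for example via `static_argnames` in `jax.jit`) may be an alternative to rewriting the loop.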

AFAIK you can use the `forward_mode_differentiation=True` argument that is available in some of the inference algorithms, e.g. here
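To illustrate why forward mode helps here, a small pure-JAX sketch (the `iterate` function is a toy stand-in, not the model from the post): when `num_steps` is traced under `jit`, `fori_loop` has a dynamic trip count, and forward-mode derivatives such as `jax.jacfwd` still go through it.

```python
import jax
import jax.numpy as jnp


@jax.jit
def iterate(alpha, beta, num_steps):
    # num_steps is a traced argument here, so the loop has a dynamic trip count.
    c = jnp.true_divide(alpha, beta)

    def body_fun(i, v):
        # Toy update: v moves a fraction c of the way towards 1 on each step.
        return v + c * (1.0 - v)

    return jax.lax.fori_loop(0, num_steps, body_fun, jnp.zeros(()))


# Forward-mode derivative of the loop output with respect to alpha.
d_alpha = jax.jacfwd(iterate, argnums=0)(1.0, 2.0, 3)
```

In NumPyro the flag is passed to the kernel constructor, e.g. `NUTS(model, forward_mode_differentiation=True)`. Forward mode needs roughly one pass per latent dimension, so it tends to suit models with few parameters, like the four scalars sampled here.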

Got it, thanks!

I actually found that SA is significantly faster than any gradient-based MCMC method, likely due to the large overhead of computing gradients here (JAX needs to keep track of ~1e5 `step` calls).

Yes, SA can be a pretty good option for low-dimensional problems.