I believe I’m having a problem similar to the one in this post, @fehiepsi.
I have an iterative optimization procedure that acts as my functional model, mapping my input data `x` to my output data `y`. It’s basically just a for-loop around a few matmuls, dots, etc., and it is parametrized by `alpha` and `beta`. The body of the loop is its own function, `step()`, and to avoid the slow Python for-loop I use `jax.lax.fori_loop` to speed it up (I need ~1e5 iterations).
def step(x: jnp.ndarray, var_1: jnp.ndarray, var_2: jnp.ndarray, alpha: float, beta: float):
    """Perform one iteration of the iterative optimization procedure.

    This is the loop body that ``functional_model`` repeats via
    ``jax.lax.fori_loop``. Most of the math is elided in this post.

    Args:
        x: Input data.
        var_1: Current value of the first loop-state variable.
        var_2: Current value of the second loop-state variable.
        alpha: Model parameter. Inside ``model`` this receives the sample
            site ``a``, so it is a traced JAX value there, not a Python float.
        beta: Model parameter. Inside ``model`` this receives the sample
            site ``b``, so it is a traced JAX value there, not a Python float.

    Returns:
        The updated state ``(new_var_1, new_var_2)``.
    """
    ... math operations like matmuls, dots etc...
    # problem seems to stem from here: the traceback raises at this division.
    # This is the first point where values derived from two different sample
    # sites ('a' and 'b') are combined, which matches the provenance error
    # "frozenset({'a'}) != frozenset({'b'})".
    jnp.true_divide(alpha, beta) # problem seems to stem from here
    ... more math operations ...
    return new_var_1, new_var_2
@partial(jit, static_argnames="num_steps")
def functional_model(
    x: jnp.ndarray,
    alpha: float,
    beta: float,
    num_steps: int,
    var_1_init: jnp.ndarray,
    var_2_init: jnp.ndarray,
):
    """Run ``step`` ``num_steps`` times, starting from the given initial state.

    ``num_steps`` is a static (compile-time) argument. As a result, the trip
    count of ``jax.lax.fori_loop`` is known at trace time, and the loop can be
    reverse-mode differentiated with respect to ``alpha`` and ``beta``. A
    traced trip count would lower to a ``while_loop``, which does not support
    reverse-mode autodiff. Each distinct ``num_steps`` value triggers one
    recompilation.

    Args:
        x: Input data, passed unchanged to every ``step`` call.
        alpha: First model parameter, passed to ``step``.
        beta: Second model parameter, passed to ``step``.
        num_steps: Number of loop iterations. Must be a hashable Python int.
        var_1_init: Initial value of the first loop-state variable.
        var_2_init: Initial value of the second loop-state variable.

    Returns:
        The final ``(var_1, var_2)``. ``var_1`` is used as the prediction
        ``y_hat``.
    """
    # Bind the loop-invariant arguments once; only the state changes per step.
    f = partial(step, x=x, alpha=alpha, beta=beta)

    def body_fun(i, carry):
        # fori_loop calls body_fun(i, carry) and expects a new carry with the
        # same structure back. The iteration index i is not needed here.
        var_1, var_2 = carry
        var_1, var_2 = f(var_1=var_1, var_2=var_2)
        return var_1, var_2

    var_1, var_2 = jax.lax.fori_loop(0, num_steps, body_fun, (var_1_init, var_2_init))
    return var_1, var_2
These functions work when given valid inputs. However, when I define my model as follows and try to render it with `numpyro.render_model`, I get an error.
def model(x, y, var_1_init, var_2_init, num_steps=100_000):
    """Define a NumPyro model with Bernoulli observations driven by ``functional_model``.

    The model proceeds in three stages:

    1. Sample ``a`` and ``b`` and pass them as ``alpha``/``beta``.
    2. Run ``functional_model`` for ``num_steps`` iterations to obtain
       ``y_hat``.
    3. Score ``y`` under a Bernoulli likelihood with logits
       ``-steepness * (y_hat - midpoint)``.

    Args:
        x: 2-D input data. Its shape ``(num_samples, num_edges)`` sizes the two
            plates.
        y: Observed outcomes for the ``obs`` site. Presumably shaped
            ``(num_samples, num_edges)``; TODO confirm.
        var_1_init: Initial value of the first loop-state variable for
            ``functional_model``.
        var_2_init: Initial value of the second loop-state variable for
            ``functional_model``.
        num_steps: Number of loop iterations (default 1e5, as before).
            Exposed so a cheaper value can be used, e.g. when only rendering.

    Returns:
        The value returned by the ``obs`` sample site.

    NOTE(review): ``numpyro.render_model`` evaluates this model under NumPyro's
    provenance tracing (``numpyro/ops/provenance.py`` in the traceback). The
    reported error is raised inside the ``fori_loop`` body, where values from
    sites ``a`` and ``b`` meet. Presumably that tracer does not handle this
    JAX control flow in the installed NumPyro version. Verify this, and check
    whether a newer NumPyro release resolves it.
    """
    num_samples, num_edges = x.shape

    a = numpyro.sample('a', dist.Uniform(0, 1e8))
    b = numpyro.sample('b', dist.Uniform(0, 1e8))
    steepness = numpyro.sample('steepness', dist.Normal(1, 10))
    midpoint = numpyro.sample('midpoint', dist.Normal(0, 10))

    # Pass var_2_init under its own keyword. The previous call repeated the
    # var_1_init keyword, which is a SyntaxError.
    y_hat, _ = functional_model(
        x=x,
        var_1_init=var_1_init,
        var_2_init=var_2_init,
        alpha=a,
        beta=b,
        num_steps=num_steps,
    )

    logits = -steepness * (y_hat - midpoint)
    with numpyro.plate('num_samples', num_samples, dim=-2):
        with numpyro.plate('num_edges', num_edges, dim=-1):
            return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=y)
# ... load data x and y, initialise var_1_init and var_2_init, etc. (omitted).
# Build the model's graph, annotated with parameters and each site's
# distribution, then open the rendered graph. model_args is render_model's
# second positional parameter.
a = numpyro.render_model(
    model,
    (x, y, var_1_init, var_2_init),
    render_params=True,
    render_distributions=True,
)
a.view()
Running this code gives the following error (the first part of the traceback is omitted):
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/contextlib.py", line 137, in __exit__
self.gen.throw(typ, value, traceback)
File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/jax/core.py", line 991, in new_main
yield main
File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/numpyro/ops/provenance.py", line 109, in eval_provenance
out = partial_eval.trace_to_subjaxpr_dynamic(fun, main, avals)[1]
File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1998, in trace_to_subjaxpr_dynamic
ans = fun.call_wrapped(*in_tracers_)
File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/jax/linear_util.py", line 167, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/jax/linear_util.py", line 167, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/numpyro/infer/inspect.py", line 377, in get_log_probs
model(*model_args, **model_kwargs)
File "/Users/maxw/file.py", line 264, in model
y_hat, var_2 = functional_model(x=x,
File "/Users/maxw/file.py", line 179, in functional_model
var_1, var_2 = jax.lax.fori_loop(0, steps, body_fun, (var_1_init, var_2_init))
File "/Users/maxw/file.py", line 176, in body_fun
var_1, var_2 = f(var_1=var_1, var_2=var_2)
File "/Users/maxw/file.py", line 156, in functional_model
c = jnp.true_divide(alpha, beta)
File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/jax/_src/numpy/ufuncs.py", line 258, in true_divide
return lax.div(x1, x2)
TypeError: Axis name _provenance used with inconsistent sizes: frozenset({'a'}) != frozenset({'b'})
Any idea how to fix this?