Provenance and JAX Control Flow

I believe I’m having problems similar to this post, @fehiepsi

I have an iterative optimization procedure that serves as my functional model, mapping my input data x to my output data y. It’s basically just a for-loop around a few matmuls, dots, etc., parametrized by alpha and beta. The body of the loop is its own function, step(), and to avoid the slow Python for-loop I use jax.lax.fori_loop (I need ~1e5 iterations).

from functools import partial

import jax
import jax.numpy as jnp
from jax import jit

def step(x: jnp.ndarray, var_1: jnp.ndarray, var_2: jnp.ndarray, alpha: float, beta: float):
    # ... math operations like matmuls, dots etc...
    c = jnp.true_divide(alpha, beta) # problem seems to stem from here
    # ... more math operations ...
    return new_var_1, new_var_2

@jit
def functional_model(x: jnp.ndarray, alpha: float, beta: float, num_steps: int, var_1_init: jnp.ndarray, var_2_init: jnp.ndarray):
    # simply an optimized loop over step()
    f = partial(step, x=x, alpha=alpha, beta=beta)

    def body_fun(i, carry):
        # adapts step() to the (index, carry) signature that fori_loop expects
        var_1, var_2 = carry
        var_1, var_2 = f(var_1=var_1, var_2=var_2)
        return var_1, var_2

    # var_1 is our y_hat
    var_1, var_2 = jax.lax.fori_loop(0, num_steps, body_fun, (var_1_init, var_2_init))
    return var_1, var_2

These functions work when given valid inputs. But when I define my model as follows and try to render it, I get an error.

def model(x, y, var_1_init, var_2_init):
    num_samples, num_edges = x.shape

    a = numpyro.sample('a', dist.Uniform(0, 1e8))
    b = numpyro.sample('b', dist.Uniform(0, 1e8))
    steepness = numpyro.sample('steepness', dist.Normal(1, 10))
    midpoint = numpyro.sample('midpoint', dist.Normal(0, 10))

    y_hat, _ = functional_model(x=x,
                                var_1_init=var_1_init,
                                var_2_init=var_2_init,
                                alpha=a, beta=b,
                                num_steps=int(1e5)
                                )

    logits = -1 * steepness * (y_hat - midpoint)
    with numpyro.plate('num_samples', num_samples, dim=-2):
        with numpyro.plate('num_edges', num_edges, dim=-1):
            return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=y)

#... load data x and y, init var1 and var2, etc
a = numpyro.render_model(model,
                         model_args=(x, y, var_1_init, var_2_init),
                         render_distributions=True,
                         render_params=True)
a.view()

Running this code gives this error:

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/contextlib.py", line 137, in __exit__
    self.gen.throw(typ, value, traceback)
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/jax/core.py", line 991, in new_main
    yield main
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/numpyro/ops/provenance.py", line 109, in eval_provenance
    out = partial_eval.trace_to_subjaxpr_dynamic(fun, main, avals)[1]
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/jax/interpreters/partial_eval.py", line 1998, in trace_to_subjaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers_)
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/numpyro/infer/inspect.py", line 377, in get_log_probs
    model(*model_args, **model_kwargs)
  File "/Users/maxw/file.py", line 264, in model
    y_hat, var_2 = functional_model(x=x,
  File "/Users/maxw/file.py", line 179, in functional_model
    var_1, var_2 = jax.lax.fori_loop(0, steps, body_fun, (var_1_init, var_2_init))
  File "/Users/maxw/file.py", line 176, in body_fun
    var_1, var_2 = f(var_1=var_1, var_2=var_2)
  File "/Users/maxw/file.py", line 156, in functional_model
    c = jnp.true_divide(alpha, beta)
  File "/Users/maxw/opt/miniconda3/envs/prob-gsl/lib/python3.9/site-packages/jax/_src/numpy/ufuncs.py", line 258, in true_divide
    return lax.div(x1, x2)
TypeError: Axis name _provenance used with inconsistent sizes: frozenset({'a'}) != frozenset({'b'})

Any idea on how to fix this?

assuming your problem is with model rendering and not something else (can you still do inference?) i would suggest as follows:

model rendering in any case doesn’t peer into the details of deterministic functions like functional_model so i guess you should just use a dummy model with

logits = -1 * steepness * (0 - midpoint)

for model rendering. e.g. you could define the following and pass model_kwargs={'dummy': True} to numpyro.render_model:

def model(x, y, var_1_init, var_2_init, dummy=False):
    num_samples, num_edges = x.shape

    a = numpyro.sample('a', dist.Uniform(0, 1e8))
    b = numpyro.sample('b', dist.Uniform(0, 1e8))
    steepness = numpyro.sample('steepness', dist.Normal(1, 10))
    midpoint = numpyro.sample('midpoint', dist.Normal(0, 10))

    if not dummy:
        y_hat, _ = functional_model(x=x,
                                    var_1_init=var_1_init,
                                    var_2_init=var_2_init,
                                    alpha=a, beta=b,
                                    num_steps=int(1e5)
                                    )

        logits = -1 * steepness * (y_hat - midpoint)
    else:
        # rendering only needs the dependency structure, not the real y_hat
        logits = steepness * midpoint

    with numpyro.plate('num_samples', num_samples, dim=-2):
        with numpyro.plate('num_edges', num_edges, dim=-1):
            return numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=y)

Ah, thanks for the tip. The dummy functional model works.

Inference runs, but not with this exact code: jax.lax.fori_loop does not seem to support reverse-mode differentiation, so I have to use a raw Python loop, which is currently very slow. I am trying to port to jax.lax.scan, but that’s been stumping me (any tips on that front would also be appreciated, but I recognize this is off topic :} ).
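
One possible shape for the jax.lax.scan port, as a minimal sketch only. The step() body below is a hypothetical stand-in (the real update is elided above), the toy inputs x, v1 and v2 are made up, and num_steps is marked static here so that scan gets a concrete length. scan threads (var_1, var_2) through the loop much like fori_loop does, and it supports reverse-mode differentiation.

```python
from functools import partial

import jax
import jax.numpy as jnp


def step(x, var_1, var_2, alpha, beta):
    # Hypothetical stand-in for the elided update; any pure JAX function works.
    c = jnp.true_divide(alpha, beta)
    new_var_1 = var_1 - c * (x @ var_1 - var_2)
    new_var_2 = 0.9 * var_2
    return new_var_1, new_var_2


@partial(jax.jit, static_argnames="num_steps")
def functional_model_scan(x, alpha, beta, num_steps, var_1_init, var_2_init):
    def body_fun(carry, _):
        # scan expects (carry, x_t) -> (new_carry, y_t); there is no per-step
        # input or output here, so both of those slots are None.
        var_1, var_2 = carry
        return step(x, var_1, var_2, alpha, beta), None

    (var_1, var_2), _ = jax.lax.scan(
        body_fun, (var_1_init, var_2_init), xs=None, length=num_steps
    )
    return var_1, var_2


# Toy inputs (assumptions, not the original data).
x = jnp.eye(3)
v1 = jnp.zeros(3)
v2 = jnp.ones(3)


def loss(alpha):
    var_1, _ = functional_model_scan(
        x=x, alpha=alpha, beta=2.0, num_steps=50, var_1_init=v1, var_2_init=v2
    )
    return jnp.sum(var_1**2)


# Reverse-mode gradient with respect to alpha, through all 50 steps.
grad_alpha = jax.grad(loss)(0.5)
```

When the trip count is known at trace time, jax.lax.fori_loop is itself implemented in terms of scan and becomes reverse-differentiable, so marking num_steps as static in the original jit may already be enough. Either way, reverse mode has to keep intermediates for every step, which can get expensive at ~1e5 iterations.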

afaik you can use the forward_mode_differentiation=True argument that is available in some of the inference algorithms such as HMC and NUTS, e.g. here
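
A small plain-JAX sketch of why forward mode helps here. The toy function run() and its update rule are assumptions made up for illustration. Because num_steps is traced under jit, fori_loop falls back to a while loop, which JAX can differentiate in forward mode (jax.jacfwd) but not in reverse mode (jax.grad).

```python
import jax
import jax.numpy as jnp


@jax.jit
def run(alpha, num_steps):
    # num_steps is a traced value here, so the trip count is dynamic.
    def body_fun(i, v):
        # Toy update (hypothetical): v moves toward 1.0 at rate alpha.
        return v - alpha * (v - 1.0)

    return jax.lax.fori_loop(0, num_steps, body_fun, jnp.float32(0.0))


# Forward-mode derivative of the final value with respect to alpha.
d_alpha = jax.jacfwd(run)(0.1, 10)
```

Forward mode costs roughly one extra pass per input dimension, so it tends to suit models like this one with only a handful of latent parameters.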

Got it, thanks!

I actually found using SA is significantly faster than any gradient-based MCMC method, likely due to the large overhead of computing gradients here (jax needs to keep track of ~1e5 step calls).

yes SA can be a pretty good option for low-dimensional problems