Convert initial parameters for use in `init_to_value`

I’m working with a model that uses numpyro’s reparameterization. I have a set of initial values that I’d like to use to initialize the chains, but I can’t find a way to transform these constrained initial values into what init_to_value expects.

I have attempted to use unconstrain_fn:

start_vals = {'a': 123, 'b': 0.56, 'c': -95.3, ...} # example starting values
tran_vals = unconstrain_fn(model, (x, y, y_err), {}, start_vals)

but this gives me an error:

Traceback (most recent call last):
  File "/Users/nmearl/code/feadme/.venv/bin/feadme", line 10, in <module>
    sys.exit(run())
             ^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/click/core.py", line 1161, in __call__
    return self.main(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/click/core.py", line 1082, in main
    rv = self.invoke(ctx)
         ^^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/click/core.py", line 1443, in invoke
    return ctx.invoke(self.callback, **ctx.params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/click/core.py", line 788, in invoke
    return __callback(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/src/feadme/cli.py", line 189, in run
    unconstrained_init = unconstrain_fn(part_disk_model, (template, wave, flux, flux_err), {}, starting_values)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/infer/util.py", line 265, in unconstrain_fn
    transforms = get_transforms(model, model_args, model_kwargs, params)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/infer/util.py", line 246, in get_transforms
    transforms, _, _, _ = _get_model_transforms(
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/infer/util.py", line 482, in _get_model_transforms
    model_trace = trace(model).get_trace(*model_args, **model_kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/handlers.py", line 191, in get_trace
    self(*args, **kwargs)
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
    return self.fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
    return self.fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/src/feadme/compose.py", line 186, in disk_model
    param_mods[samp_name] = numpyro.sample(
                            ^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/primitives.py", line 250, in sample
    msg = apply_stack(initial_msg)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/primitives.py", line 55, in apply_stack
    handler.process_message(msg)
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/handlers.py", line 645, in process_message
    new_fn, value = reparam(msg["name"], msg["fn"], msg["value"])
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/infer/reparam.py", line 175, in __call__
    x = numpyro.sample(
        ^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/primitives.py", line 250, in sample
    msg = apply_stack(initial_msg)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/primitives.py", line 61, in apply_stack
    default_process_message(msg)
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/primitives.py", line 32, in default_process_message
    msg["value"], msg["intermediates"] = msg["fn"](
                                         ^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/distributions/distribution.py", line 393, in __call__
    return self.sample_with_intermediates(key, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/distributions/distribution.py", line 351, in sample_with_intermediates
    return self.sample(key, sample_shape=sample_shape), []
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/distributions/truncated.py", line 241, in sample
    assert is_prng_key(key)
           ^^^^^^^^^^^^^^^^
AssertionError

The documentation has several examples of converting unconstrained values to constrained ones, but not the other way around. What is the suggested approach?

Thanks for any help.

Maybe you need to seed your model?

Thanks @fehiepsi, it seems that seeding does avoid the error. However, I’m not getting the expected output. I’ve constructed a simple example:

import numpyro
from numpyro.infer import MCMC, NUTS
import jax
import numpyro.distributions as dist
from numpyro.infer.reparam import TransformReparam
from numpyro.infer import init_to_value
from numpyro.infer.util import unconstrain_fn, constrain_fn

def model():
    x_lo, x_hi = 1000, 2000  # bounds of the uniform prior on x
    with numpyro.handlers.seed(rng_seed=1):
        with numpyro.handlers.reparam(config={'x': TransformReparam()}):
            base_dist = dist.Uniform(0, 1)
            transforms = [dist.transforms.AffineTransform(x_lo, x_hi - x_lo)]
            unif_param_dist = dist.TransformedDistribution(base_dist, transforms)
            x = numpyro.sample("x", unif_param_dist)

unconstrained_starters = unconstrain_fn(
    model, (), {}, {'x': 1500},
)

print("Unconstrained values:", unconstrained_starters)

constrained_starters = constrain_fn(
    model, (), {}, unconstrained_starters
)

print("Constrained values:", constrained_starters)

kernel = NUTS(model, init_strategy=init_to_value(values=unconstrained_starters))
mcmc = MCMC(kernel, num_warmup=10, num_samples=10)
mcmc.run(jax.random.PRNGKey(0))

But this just prints:

Unconstrained values: {'x': 1500}
Constrained values: {'x': 1500}

It seems like init_to_value wants the values at x_base, the reparameterized site. Given some initial starting values, how should I convert them into something init_to_value understands?

I don’t understand. You are using Normal(0, 1), so the unconstrained values should be the same as the constrained values.

You’re right, sorry. I’ve updated the code with a more explicit example. init_to_value fails because it looks for x_base among the site names in values, which presumably means I need to get the transformed x value and provide that to init_to_value.

I see. Thanks! This seems to be a known issue: "init_strategy does not work with CircularReparam" (Issue #1614, pyro-ppl/numpyro on GitHub).

Apologies, I think I might be missing something. It’s not clear to me how that issue relates to this particular problem.

I guess the issue here is that init strategies are not compatible with reparam. There is more discussion in "Make poutine.reparam compatible with InitMessenger, poutine.condition, etc." (Issue #2878, pyro-ppl/pyro on GitHub).