I'm working with a model that uses NumPyro's reparameterization. I have a set of initial values that I'd like to use to initialize the chains, but I cannot find a way to transform these constrained initial values into the form expected by `init_to_value`.
I have tried using `unconstrain_fn`:
start_vals = {'a': 123, 'b': 0.56, 'c': -95.3, ...} # example starting values
tran_vals = unconstrain_fn(model, (x, y, y_err), {}, start_vals)
but this gives me an error:
Traceback (most recent call last):
File "/Users/nmearl/code/feadme/.venv/bin/feadme", line 10, in <module>
sys.exit(run())
^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/click/core.py", line 1161, in __call__
return self.main(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/click/core.py", line 1082, in main
rv = self.invoke(ctx)
^^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/click/core.py", line 1443, in invoke
return ctx.invoke(self.callback, **ctx.params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/click/core.py", line 788, in invoke
return __callback(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/src/feadme/cli.py", line 189, in run
unconstrained_init = unconstrain_fn(part_disk_model, (template, wave, flux, flux_err), {}, starting_values)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/infer/util.py", line 265, in unconstrain_fn
transforms = get_transforms(model, model_args, model_kwargs, params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/infer/util.py", line 246, in get_transforms
transforms, _, _, _ = _get_model_transforms(
^^^^^^^^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/infer/util.py", line 482, in _get_model_transforms
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/handlers.py", line 191, in get_trace
self(*args, **kwargs)
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/primitives.py", line 121, in __call__
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/src/feadme/compose.py", line 186, in disk_model
param_mods[samp_name] = numpyro.sample(
^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/primitives.py", line 250, in sample
msg = apply_stack(initial_msg)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/primitives.py", line 55, in apply_stack
handler.process_message(msg)
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/handlers.py", line 645, in process_message
new_fn, value = reparam(msg["name"], msg["fn"], msg["value"])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/infer/reparam.py", line 175, in __call__
x = numpyro.sample(
^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/primitives.py", line 250, in sample
msg = apply_stack(initial_msg)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/primitives.py", line 61, in apply_stack
default_process_message(msg)
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/primitives.py", line 32, in default_process_message
msg["value"], msg["intermediates"] = msg["fn"](
^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/distributions/distribution.py", line 393, in __call__
return self.sample_with_intermediates(key, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/distributions/distribution.py", line 351, in sample_with_intermediates
return self.sample(key, sample_shape=sample_shape), []
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/nmearl/code/feadme/.venv/lib/python3.12/site-packages/numpyro/distributions/truncated.py", line 241, in sample
assert is_prng_key(key)
^^^^^^^^^^^^^^^^
AssertionError
The documentation has several examples of converting unconstrained values to constrained values, but none going the other way. What is the suggested approach for this?
Thanks for any help.