TransformReparam + substitute

I have a model with transformed parameters, and I want to fix some of those parameters using the substitute effect handler. However, passing the transformed parameters to substitute seems to have no effect. I noticed that I can substitute the base parameters (the sites with the _base suffix) easily enough, but I'm wondering if there's a way to substitute the transformed parameters directly, without having to calculate what the underlying base parameters must be to produce the desired transformed values.

Here’s a simplified example to demonstrate what I mean:

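import numpyro
import numpyro.distributions as dist
from numpyro.distributions.transforms import AffineTransform
from numpyro.handlers import substitute
from numpyro.infer.reparam import TransformReparam

def model():
    mu_a = numpyro.sample("mu_a", dist.Normal(0., 5.))
    with numpyro.handlers.reparam(config={"a": TransformReparam()}):
        a = numpyro.sample(
            "a", dist.TransformedDistribution(dist.Normal(0., 1.), AffineTransform(mu_a, 1.)))
    return mu_a, a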

# what I would like to do
subbed_model = substitute(model, data={"a": 0.0})

# what works
subbed_model = substitute(model, data={"mu_a": 0.0, "a_base": 0.0})

In this example, if I want to substitute a, I have to fix the location parameter mu_a and then also calculate the corresponding a_base. Ideally I would like to just set the desired value of a. Is there a neat way to do that?

It is tricky with the reparam API to find the latent values that generate a given transformed value (for TransformReparam it is easy, but not for other reparameterizations). To fix those deterministic values, I guess you can add a flag to your model to disable reparam; then you can substitute them like other latent variables.

Nice, that makes sense, thanks for the suggestion!