Nested inference and site names in numpyro

I’m interested in using numpyro to estimate models with “nested” inference, similar to the nesting structure in the RSA examples for Pyro. I’ve run into some trouble trying to implement a simplified version of this. The error I eventually get is: AssertionError: all sites must have unique names but got `mu` duplicated.

So I’m wondering, how do I manage the naming of sample sites when inference is nested inside inference?

Minimal example

For a sort of minimal example, say we have a “meter” that takes inputs x and gives an estimate of the population mean based on those inputs. The meter might be biased, which is to say it might have a prior about what mean values are likely. We want to do BDA (Bayesian data analysis) to infer the prior parameters of the meter.

So if we give it x = [-1, 0, 1, .5] and it outputs mu_est = -.50, then it has some negative bias (an unbiased estimate would be 0.125).
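(Quick sanity check of that unbiased estimate:)

(-1 + 0 + 1 + .5) / 4  # = 0.125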

import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoLaplaceApproximation

import jax.numpy as jnp
from jax.random import PRNGKey

# generate a vector of observations to be fed into the meter
x = dist.Normal(jnp.ones(10), 1).sample(PRNGKey(123))

# define the meter and allow it to be passed a prior parameter
def meterModel(prior_mu, x=None):
    mu = numpyro.sample("mu", dist.Normal(prior_mu, .5))
    sigma = 1.
    with numpyro.plate("obs", x.shape[0]):
        numpyro.sample("obs_x", dist.Normal(mu, sigma), obs=x)    

# do MCMC with this model with mu_prior = -1 and see it produces a biased estimate
mu_prior = -1.
kernel = NUTS(meterModel, target_accept_prob=.80)
posterior = MCMC(kernel, num_warmup=1_000, num_samples=1_000, num_chains=1)
posterior.run(PRNGKey(0), mu_prior, x)
posterior.print_summary()

Now I want to do nested inference to infer the prior that this meter must be using to generate the biased outputs that it does.

The approach I’ve tried is to use a Laplace approximation for the nested (inner) model and MCMC on the outer model.

def laplace_approx(model, *args):
    guide = AutoLaplaceApproximation(model)
    optimizer = numpyro.optim.Minimize()
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    init_state = svi.init(PRNGKey(1), *args)
    # with the Minimize optimizer, a single update call runs the full optimization
    optimal_state, loss = svi.update(init_state, *args)
    return guide.get_posterior(svi.get_params(optimal_state))

def outerModel(x, meter_output=None):
    prior_mu = numpyro.sample("prior_mu", dist.Normal(0,3))
    meter_dist = laplace_approx(meterModel, prior_mu, x)
    
    numpyro.sample("meter_samp", meter_dist, obs = meter_output)

kernel = NUTS(outerModel, target_accept_prob=.80)
posterior = MCMC(kernel, num_warmup=250, num_samples=250, num_chains=1)

posterior.run(PRNGKey(0), x, jnp.array(.22))  

But when I try this I get the error:

AssertionError: all sites must have unique names but got `mu` duplicated

I wouldn’t be surprised if there was more wrong with this code, but I am stuck. How do I separate/manage the naming of the sample sites between the inner and outer models?

@dmpowell Welcome to the effect handlers! You can use handlers.block for this purpose. :wink:

with handlers.block():
    meter_dist = laplace_approx(meterModel, prior_mu, x)
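As a standalone illustration (not part of the original model): any sites sampled inside a block context are hidden from the handlers outside it, so the outer trace never records the inner model’s sites. The inner seed below just stands in for the rng handling that SVI does internally.

def toy():
    with handlers.block():
        # the nested inference lives here; its sites are invisible to the outside
        with handlers.seed(rng_seed=0):
            numpyro.sample("hidden", dist.Normal(0., 1.))
    numpyro.sample("visible", dist.Normal(0., 1.))

exec_trace = handlers.trace(handlers.seed(toy, PRNGKey(0))).get_trace()
print(list(exec_trace.keys()))  # ['visible'] -- "hidden" never reaches the outer trace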

Aha! Awesome and thanks for your fast response. Adding that line of code fixes the error, but unfortunately inference still doesn’t seem to be working correctly.

If I write the next sample() statement inside the handlers.block() context, then it doesn’t seem to affect inference about prior_mu. If I put it outside the handlers.block() context, then I get another error:

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.

I may have addressed my own question by swapping the optimizer and actually running SVI for the inner inference.

I replaced the laplace_approx() function with the function below:

def laplace_approx_svi(model, *args):
    guide = AutoLaplaceApproximation(model)
    optimizer = numpyro.optim.Adam(step_size=1)
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    result = svi.run(PRNGKey(0), 2000, *args, progress_bar=False)
    
    return guide.get_posterior(result.params)

This eliminated the previous error about reverse-mode differentiation. And then the outer inference model is:

def outerModel(x, meter_output=None):
    prior_mu = numpyro.sample("prior_mu", dist.Normal(0,3))

    with handlers.block():
        meter_dist = laplace_approx_svi(meterModel, prior_mu, x)
    numpyro.sample("meter_samp", meter_dist, obs = meter_output)
        
kernel = NUTS(outerModel)
posterior = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=1)

posterior.run(PRNGKey(0), x, jnp.array([.22]))  

This runs and produces results that are clearly influenced by the inner model and the data that are passed. Unfortunately, they don’t seem at all right. The posterior is a spike at a not-quite-right value (it should be -1), and n_eff = 0.50, so it’s not sampling well, to say the least.

                mean       std    median      5.0%     95.0%     n_eff     r_hat
  prior_mu     -0.68      0.00     -0.68     -0.68     -0.68      0.50      1.00

It appears I was too quick to post. I made some further progress by replacing the Adam optimizer with Adagrad, and things are working quite sensibly now! For posterity (hopefully), I modified laplace_approx_svi() to:

def laplace_approx_svi(model, *args):
    guide = AutoLaplaceApproximation(model)
    optimizer = numpyro.optim.Adagrad(step_size=.1)
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    result = svi.run(PRNGKey(0), 1000, *args, progress_bar=False)
    
    return guide.get_posterior(result.params)

So I guess that raises another question: any advice on how to choose optimizers and other inference details for these kinds of purposes? Outside the MCMC, all of these estimation approaches seemed to work just fine, but they produced wildly different behavior inside the inference loop.

Thanks again!

Do you want to take the grad through the SVI optimization? If not, you might want to apply jax.lax.stop_gradient to the SVI output result.params. But if you do that, the likelihood meter_dist won’t depend on prior_mu, and MCMC will just return the prior dist.Normal(0, 3) for you. On the other hand, taking the grad through all of the SVI steps does not make sense to me, so I’m a bit confused. You probably want some sort of gradient-free kernel like SampleAdaptive in MCMC, rather than NUTS:

kernel = SA(outerModel)
posterior = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=1)
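For completeness, the stop_gradient option mentioned above could look like the sketch below (laplace_approx_svi_stopgrad is a hypothetical variant of the earlier function, not something from the thread). As noted, this makes meter_dist constant with respect to prior_mu as far as autodiff is concerned:

import jax

def laplace_approx_svi_stopgrad(model, *args):
    guide = AutoLaplaceApproximation(model)
    svi = SVI(model, guide, numpyro.optim.Adam(step_size=1), Trace_ELBO())
    result = svi.run(PRNGKey(0), 2000, *args, progress_bar=False)
    # cut the gradient flowing back through the SVI optimization
    params = jax.lax.stop_gradient(result.params)
    return guide.get_posterior(params)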

I’m not 100% sure what I want, but I do think I want to take the grad through the inference in the inner model. I definitely want meter_dist to depend on prior_mu; the whole goal is to have the outer inference depend on the inner inference. Following your suggestion, I tried swapping NUTS for SampleAdaptive and got identical results (though I had to draw more samples for SA, which was slower).

I guess HMCGibbs might be helpful. There we can perform the Laplace approximation inside gibbs_fn, without having to worry about the gradient issue above. However, I don’t know how to revise the logic to turn this into HMC-within-Gibbs…

Re SA: this could be more stable because no grad needs to pass through the SVI inference. Typically you will want to increase the number of samples and set progress_bar=False; it is really fast.
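Concretely, something like the sketch below (the sample counts here are just placeholders, not recommendations from the thread):

from numpyro.infer import SA

kernel = SA(outerModel)
posterior = MCMC(kernel, num_warmup=5_000, num_samples=5_000, num_chains=1, progress_bar=False)
posterior.run(PRNGKey(0), x, jnp.array([.22]))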

Sorry, I think I was unclear earlier. Things actually appear to be more or less working now: I get pretty reasonable (though somewhat biased) estimates on my little minimal-example problem. At the very least, the gradient errors are resolved, unless you see something else I’m missing that I should be concerned about with this approach. From what you’re saying, it sounds like maybe this shouldn’t be working? So maybe there is a problem that isn’t throwing an error but that I should still be worried about?

Here is the full code that is now working (as far as I can tell):


import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoLaplaceApproximation

import numpy as np
import jax.numpy as jnp
from jax.random import PRNGKey

# generate a vector of observations to be fed into the meter
x = dist.Normal(jnp.ones(10),1).sample(PRNGKey(123))

# define the meter and allow it to be passed a prior parameter
def meterModel(prior_mu, x=None):
    mu = numpyro.sample("mu", dist.Normal(prior_mu, .5))
    sigma = 1.
    with numpyro.plate("obs", x.shape[0]):
        numpyro.sample("obs_x", dist.Normal(mu, sigma), obs=x)    

def laplace_approx_svi(model, *args):
    guide = AutoLaplaceApproximation(model)
    optimizer = numpyro.optim.Adagrad(step_size=.1)
    svi = SVI(model, guide, optimizer, Trace_ELBO())
    result = svi.run(PRNGKey(0), 1000, *args, progress_bar=False)
    
    return guide.get_posterior(result.params)

def outerModel(x, meter_output):
    prior_mu = numpyro.sample("prior_mu", dist.Normal(0,100))

    with handlers.block():
        meter_dist = laplace_approx_svi(meterModel, prior_mu, x)
    numpyro.sample("meter_samp", meter_dist, obs = meter_output)
        

kernel = NUTS(outerModel)
posterior = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=1)
posterior.run(PRNGKey(0), x, jnp.array(-.37))  # <-- estimate if prior_mu = -3
posterior.print_summary()

This returns:

                mean       std    median      5.0%     95.0%     n_eff     r_hat
  prior_mu     -2.87      0.89     -2.86     -4.36     -1.51    208.51      1.00

“this shouldn’t be working”

It should run if that is what you want (under the hood, svi.run uses lax.scan, which supports reverse-mode differentiation). I was not sure whether taking grad(do 1000 Adam update steps) was what you wanted, so I proposed SA, which does not use grad to sample. If grad(SVI) is part of your logic, then I guess the question of “how to choose optimizers” is complicated to answer (the potential-energy geometry, the optimizer, and the gradient are each already hard to deal with). I guess you can fix some inputs, check whether the LaplaceApproximation loss converges with a tuned step_size and num_steps, and hope that those tuned hyperparameters will also let SVI converge for many prior_mu values. Finding a good set of hyperparameters that works across multiple datasets (values of prior_mu in particular) might be an issue for more complicated models.
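For example, a quick convergence check for one fixed prior_mu might look like this (just a sketch, reusing the Adagrad settings from laplace_approx_svi above, with prior_mu fixed at -1):

guide = AutoLaplaceApproximation(meterModel)
svi = SVI(meterModel, guide, numpyro.optim.Adagrad(step_size=.1), Trace_ELBO())
result = svi.run(PRNGKey(0), 1000, -1., x, progress_bar=False)
# result.losses holds the ELBO loss per step; it should flatten out if SVI has converged
print(result.losses[-10:])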

I would prefer Minimize over the SGD-style optimizers for this small model. If you need the grad, you can use NUTS(outerModel, forward_mode_differentiation=True); otherwise, you can use SA(outerModel). For Minimize, it is better to switch to double-precision mode with numpyro.enable_x64().
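Putting those suggestions together, a sketch of what that setup might look like (outerModelMin is just a hypothetical copy of outerModel that calls the Minimize-based laplace_approx defined earlier; the x64 switch ideally happens before any arrays are created):

numpyro.enable_x64()  # double precision helps Minimize

def outerModelMin(x, meter_output):
    prior_mu = numpyro.sample("prior_mu", dist.Normal(0, 100))
    with handlers.block():
        # Minimize-based inner inference from earlier in the thread
        meter_dist = laplace_approx(meterModel, prior_mu, x)
    numpyro.sample("meter_samp", meter_dist, obs=meter_output)

kernel = NUTS(outerModelMin, forward_mode_differentiation=True)
posterior = MCMC(kernel, num_warmup=500, num_samples=500, num_chains=1)
posterior.run(PRNGKey(0), x, jnp.array(-.37))
posterior.print_summary()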

“maybe there is a problem that isn’t throwing an error but that I should still be worried about?”

I can’t think of such a problem offhand. Sorry for the confusion.


Thanks, this is a really helpful summary. I tried using NUTS(..., forward_mode_differentiation=True) with Minimize and this appears to be the most accurate. But when I tried with SA() I ran into the same error I had seen before:

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.

So far my takeaway for choosing optimizers is that Minimize with the NUTS settings you recommended is the safest bet. For more complex models where it is desirable to use SGD, it sounds like (as I feared when I asked) this is an area where there be dragons.

Really appreciate your help and expertise!


Oops, the problem with SA looks like a bug. We shouldn’t use grad in SA. The issue probably happens because we reuse the same HMC code to find initial parameters in SA. Let me isolate the issue and make a fix.


Hi @dmpowell, it is indeed a bug and has been fixed in this PR. Thanks for your feedback! It is embarrassing for me to have suggested a gradient-free algorithm that is not gradient-free. >___<

It was nice discussing this with you. I hope to use your setup for nested inference in the future. :slight_smile:


Glad my bumbling helped!