Failed to verify mean field restriction on the guide

Hi,

I have a model with the following structure

# compute the covariance and precision matrices before everything else

def model_spl(data=None):
    """Sketch of a model with a multivariate-normal likelihood around ``signal``.

    ``signal`` and the covariance matrix ``C`` are placeholders that are not
    defined in this snippet: ``C`` is precomputed before the model, and
    ``signal`` is presumably computed from priors drawn with
    ``numpyro.sample`` inside the model.

    Args:
        data: observed vector passed as ``obs`` to the likelihood site.

    Returns:
        The value of the ``'signal'`` sample site.
    """
    # priors: numpyro.sample(...) calls go here
    # compute signal from the sampled priors
    # Likelihood: give MultivariateNormal exactly ONE of covariance_matrix /
    # precision_matrix / scale_tril. NumPyro uses covariance_matrix first when
    # several are supplied, so an extra precision_matrix would be ignored.
    return numpyro.sample(
        'signal',
        dist.MultivariateNormal(signal, covariance_matrix=C),
        obs=data,
    )

Then,

# NOTE: TraceMeanField_ELBO assumes a mean-field guide, q(x, y, z) = q(x) q(y) q(z).
# AutoMultivariateNormal is not mean-field (it models the latent sites
# jointly), so the mean-field check fails and the "Failed to verify mean field
# restriction" warning below is emitted. The suggested fix is to use AutoNormal
# with TraceMeanField_ELBO (or keep AutoMultivariateNormal with Trace_ELBO).
guide = autoguide.AutoMultivariateNormal(model_spl)
svi = SVI(model_spl, guide,optimizer,loss=numpyro.infer.TraceMeanField_ELBO(),data=data)
svi_result = svi.run(jax.random.PRNGKey(0), 1000)

And I get the Warning message

UserWarning: Failed to verify mean field restriction on the guide. To eliminate this warning, ensure model and guide sites occur in the same order.

But I thought the autoguide creates its numpyro.param sites in the same order in which model_spl creates its numpyro.sample sites. Is there something wrong in my call to the autoguide?

Mean field assumption does not apply to AutoMVN, so you need to use AutoNormal. But the warning is misleading. Could you help me create a github issue for enhancing the warning message? Thanks!

Hello @fehiepsi
Well, I’m a bit puzzled: here is an adaptation of one of your favorite examples, and AutoMVN & TraceMeanField_ELBO seem to work fine together:

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import expit

from jax import random
import jax.numpy as jnp

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, TraceMeanField_ELBO, autoguide
from numpyro.util import enable_x64

import matplotlib.pyplot as plt
import matplotlib as mpl
# Global plot styling: "jet" as the default image colormap and
# 16-pt Times New Roman for all text.
mpl.rcParams.update(
    {
        "image.cmap": "jet",
        "font.size": 16,
        "font.family": "Times New Roman",
    }
)



# squared exponential kernel
def kernel(X, Z, length, jitter=1.0e-6, *, include_jitter=True):
    """Squared-exponential kernel matrix between input sets ``X`` and ``Z``.

    Args:
        X: array of N input locations (indexed as ``X[:, None]``, so assumed 1-D).
        Z: array of M input locations.
        length: kernel length scale.
        jitter: amount added to the diagonal when ``include_jitter`` is true.
        include_jitter: add ``jitter * I`` to the result. This is only
            meaningful for a square kernel such as ``kernel(X, X, ...)``;
            pass ``False`` when computing a cross-covariance.

    Returns:
        (N, M) array with entries ``exp(-0.5 * ((x - z) / length) ** 2)``,
        plus ``jitter`` on the diagonal when requested.

    Raises:
        ValueError: if ``include_jitter`` is true but the kernel is not square.
    """
    deltaXsq = jnp.power((X[:, None] - Z) / length, 2.0)
    k = jnp.exp(-0.5 * deltaXsq)
    if include_jitter:
        # jitter * eye(N) would silently broadcast an (N, 1) kernel up to
        # (N, N), so only allow the diagonal jitter on square kernels.
        if k.shape[0] != k.shape[1]:
            raise ValueError(
                "jitter can only be added to a square kernel, got shape {}; "
                "pass include_jitter=False for cross-covariances".format(k.shape)
            )
        k = k + jitter * jnp.eye(X.shape[0])
    return k


def model(X, Y, length=0.2):
    """Gaussian-process binary classifier with a cubed-logit Bernoulli likelihood.

    Args:
        X: input locations; the GP prior has one latent value per entry.
        Y: observed 0/1 labels for the ``"obs"`` site.
        length: length scale forwarded to ``kernel``.
    """
    num_points = X.shape[0]
    gp_prior = dist.MultivariateNormal(
        loc=jnp.zeros(num_points),
        covariance_matrix=kernel(X, X, length),
    )
    latent = numpyro.sample("f", gp_prior)

    # Cubing the latent function is a deliberately non-standard link,
    # used to make the posterior less Gaussian.
    cubic_logits = jnp.power(latent, 3.0)
    numpyro.sample("obs", dist.Bernoulli(logits=cubic_logits), obs=Y)


# create artificial binary classification dataset
def get_data(N=16, seed=0):
    """Create an artificial 1-D binary-classification dataset.

    Args:
        N: number of points, evenly spaced on [-1, 1].
        seed: seed for the private NumPy RNG used to draw the labels.

    Returns:
        Tuple ``(X, Y)`` of arrays with shape ``(N,)``: the inputs and the
        sampled 0/1 labels.
    """
    # A private RandomState avoids reseeding the process-global np.random
    # state. RandomState(seed) produces the same draws as
    # np.random.seed(seed) followed by np.random.binomial.
    rng = np.random.RandomState(seed)

    X = np.linspace(-1, 1, N)
    # smooth, wiggly latent function, standardised to mean 0 and std 1
    Y = X + 0.2 * np.power(X, 3.0) + 0.5 * np.power(0.5 + X, 2.0) * np.sin(4.0 * X)
    Y -= np.mean(Y)
    Y /= np.std(Y)
    # labels ~ Bernoulli(sigmoid(latent))
    Y = rng.binomial(1, expit(Y))

    assert X.shape == (N,)
    assert Y.shape == (N,)

    return X, Y


# helper function for running SVI with a particular autoguide
def run_svi(rng_key, X, Y, guide_family="AutoDiagonalNormal", K=8, loss=None):
    """Fit ``model`` to ``(X, Y)`` with SVI using the requested autoguide.

    Args:
        rng_key: JAX PRNG key used both for SVI and for the final ELBO estimate.
        X, Y: data forwarded to ``model``.
        guide_family: one of "AutoDiagonalNormal", "AutoDAIS",
            "AutoMultivariateNormal".
        K: forwarded to ``AutoDAIS`` as ``K``; ignored by the other guides.
        loss: ELBO *class* (not an instance), e.g. ``Trace_ELBO`` or
            ``TraceMeanField_ELBO``; defaults to ``Trace_ELBO``. It is
            instantiated once for training and once with
            ``num_particles=1000`` for the final ELBO estimate. Note that
            ``TraceMeanField_ELBO`` assumes a mean-field guide; with
            ``AutoMultivariateNormal`` that only holds trivially because
            ``model`` has a single latent site.

    Returns:
        Posterior samples (``sample_shape=(1000,)``) from the fitted guide.

    Raises:
        ValueError: if ``guide_family`` is not one of the supported names.
    """
    if loss is None:
        loss = Trace_ELBO

    if guide_family == "AutoDAIS":
        guide = autoguide.AutoDAIS(model, K=K, eta_init=0.02, eta_max=0.5)
        step_size = 5e-4
    elif guide_family == "AutoDiagonalNormal":
        # the default family needs its own branch, otherwise guide/step_size
        # would be unbound below
        guide = autoguide.AutoDiagonalNormal(model)
        step_size = 3e-3
    elif guide_family == "AutoMultivariateNormal":
        guide = autoguide.AutoMultivariateNormal(model)
        step_size = 3e-3
    else:
        raise ValueError(
            "unsupported guide_family {!r}; expected one of 'AutoDiagonalNormal', "
            "'AutoDAIS', 'AutoMultivariateNormal'".format(guide_family)
        )

    optimizer = numpyro.optim.Adam(step_size=step_size)
    svi = SVI(model, guide, optimizer, loss=loss())
    svi_result = svi.run(rng_key, 20_000, X, Y)
    params = svi_result.params

    # many-particle estimate so runs with different guides are comparable
    final_elbo = -loss(num_particles=1000).loss(
        rng_key, params, model, guide, X, Y
    )

    guide_name = guide_family
    if guide_family == "AutoDAIS":
        guide_name += "-{}".format(K)

    print("[{}] final elbo: {:.2f}".format(guide_name, final_elbo))

    return guide.sample_posterior(
        random.PRNGKey(1), params, sample_shape=(1000,)
    )


# helper function for running mcmc
def run_nuts(mcmc_key, args, X, Y):
    """Sample the posterior of ``model`` with NUTS.

    Runs 1000 warmup and 1000 sampling iterations, prints a summary and
    returns ``MCMC.get_samples()``. ``args`` is accepted but not used.
    """
    nuts_kernel = NUTS(model)
    sampler = MCMC(nuts_kernel, num_warmup=1000, num_samples=1000)
    sampler.run(mcmc_key, X, Y)
    sampler.print_summary()
    return sampler.get_samples()

enable_x64()
X, Y = get_data()
rng_keys = random.split(random.PRNGKey(0), 4)

# Compare guide/ELBO combinations; every run reuses rng_keys[1]
# (keys 0, 2 and 3 are currently unused).
for run_kwargs in (
    dict(guide_family="AutoDAIS", K=8, loss=Trace_ELBO),
    dict(guide_family="AutoMultivariateNormal", loss=Trace_ELBO),
    dict(guide_family="AutoMultivariateNormal", loss=TraceMeanField_ELBO),
):
    run_svi(rng_keys[1], X, Y, **run_kwargs)

There is no traceback for the last call.

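By mean field, we assume q(x, y, z) = q(x) q(y) q(z), i.e. the guide factorizes over the latent sites. Your last example works because it has only one latent variable (`f`), for which this factorization holds trivially. As mentioned, the warning message is misleading.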