Bug in Sampling from Normal Distribution in SVI?

Hi, I think I’m severely misunderstanding or misusing sample within SVI. Here is (a skeleton of) some of the code.

def model():
    sample_shape = (100, )

    gamma_prior = dist.Normal(-20, 5).expand(sample_shape)
    gammas = numpyro.sample('gammas', gamma_prior)

    ### below, run simulation with parameters and compute loss with output
    ...

    numpyro.factor("loss", -total_loss)

guide = AutoBNAFNormal(model, num_flows=1)
optimizer = Adam(step_size=1e-2)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
...
### below is training loop

What’s expected is that gammas are drawn from a prior (Gaussian) distribution with the above parameters. However, whatever I do, when I print them these samples seem to come only from a standard Gaussian, i.e. it is ignoring the mean and standard deviation (scale) I have defined.

If I draw parameters from the prior this way:

gamma_prior = dist.Normal(-20, 5).expand(sample_shape)
gammas_samples = gamma_prior.sample(jax.random.PRNGKey(1))

Then the behavior is as expected. What am I doing wrong?

Hard to say because so much of your code is missing. Maybe you’re doing training wrong? E.g. it’s unclear how many training steps you do. It’s also not entirely clear where total_loss in numpyro.factor("loss", -total_loss) comes from.

When I print my samples, though, these are draws from my prior at initialization, i.e. at t=0, before any loss is computed or any training is done. My understanding is that training shouldn’t factor into this? If I’m wrong about this, that may explain some of my confusion.

I guess what I’m trying to understand is whether user-defined priors are actually used from initialization onwards, or whether they only take effect after some number of SVI steps.

In case this helps, here is the full code. Note that there is no batching here; otherwise I generally batch the loss computation and vmap/pmap everything.

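import os
import pickle
import time

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal
from numpyro.optim import Adam
from tqdm import tqdm

# StasisSolver is the user's own simulation class (not shown in the thread)

def main(args):

    timestr = time.strftime("%Y%m%d-%H%M%S")

    # Define PATH unconditionally so it is set even if the directory already exists
    PATH = f'./svi_results/{timestr}'
    os.makedirs(PATH, exist_ok=True)

    with open(f'{PATH}/args.pkl', 'wb') as file:
        pickle.dump(args.__dict__, file)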
        
    def model():
        shape = (args.N_species, )

        gamma_prior = dist.Normal(-30, 5).expand(shape)
        omega_prior = dist.Uniform(0, 1).expand(shape)
        
        gammas = numpyro.sample('gammas', gamma_prior)
        omegas = numpyro.sample('omegas', omega_prior)
        
        simulation = StasisSolver(omegas, gammas)
        stasis_val = simulation.return_stasis()
        matter_abundance = simulation.get_asymptote()
        
        stasis_loss = (stasis_val - args.observed_stasis_val)**2
        matter_abundance_loss = (matter_abundance - args.total_matter_abundance)**2

        total_loss = stasis_loss + matter_abundance_loss
        numpyro.factor("guide", -total_loss)

    guide = AutoBNAFNormal(model, num_flows=args.num_flows)
    optimizer = Adam(step_size=args.lr)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    best_loss, best_epoch, best_loss_counter = jnp.inf, 0, 0
    losses = []
    state = svi.init(jax.random.PRNGKey(0))

    for i in tqdm(range(args.num_epochs)):
        best_loss_counter += 1
        state, loss = svi.update(state)
        losses.append(loss)
        
        if loss < best_loss:
            best_loss = loss
            best_epoch = i
            params = svi.get_params(state)
            with open(f'{PATH}/best_params.pkl', 'wb') as file:
                pickle.dump(params, file)
            best_loss_counter = 0
            
        if best_loss_counter > 200:
            print(f"Early stopping at epoch {i+1} with loss {loss}")
            break
            

Printing draws from gamma_prior gives samples that seem to come from a standard normal (not the distribution I’ve defined), which is the source of my confusion. Thanks

How are you drawing samples from the model?

As in, after training? I am doing the following. If this isn’t what you’re referring to, some more detail would be great.

    num_samples = 1000
    samples = guide.sample_posterior(jax.random.PRNGKey(1), params, sample_shape=(num_samples,))

where params is loaded with pickle from ‘best_params.pkl’, which is saved during training.

After training you get a posterior, which is in general different from the prior, so this is expected.

Thanks! Yes, that is true. But this isn’t exactly what I’m after.

Maybe this code (which should run on your machine) will illustrate my confusion. The draws printed by the line print('gamma prior draws', gammas), which runs during training, do not look like draws from the prior distribution they are supposed to be sampled from.

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal
from numpyro.optim import Adam
import jax.numpy as jnp
import jax
from tqdm import tqdm

def dummy_simulation(omegas, gammas): ## dummy simulation for purposes of sharing code
    return jnp.mean(omegas), jnp.mean(gammas)

def model():
    shape = (100, )

    gamma_prior = dist.Normal(-30, 5).expand(shape) ## prior on gamma is a normal distribution with mean -30 and std 5
    omega_prior = dist.Uniform(0, 1).expand(shape) ## prior on omega is a uniform distribution between 0 and 1
    
    gammas = numpyro.sample('gammas', gamma_prior)
    print('gamma prior draws', gammas) ### should be drawn from a normal distribution with mean -30 and std 5
    omegas = numpyro.sample('omegas', omega_prior)
    
    val_1, val_2 = dummy_simulation(omegas, gammas) ## simulation outputs two values
    
    loss_1 = (val_1 - 1.)**2 ## just a dummy loss
    loss_2 = (val_2 - 1.)**2 ## just a dummy loss

    total_loss = loss_1 + loss_2
    numpyro.factor("guide", -total_loss)

guide = AutoBNAFNormal(model, num_flows=1)
optimizer = Adam(step_size=1e-2)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

best_loss, best_epoch, best_loss_counter = jnp.inf, 0, 0
losses = []
state = svi.init(jax.random.PRNGKey(0))

for i in tqdm(range(1000)):
    best_loss_counter += 1
    state, loss = svi.update(state)
    losses.append(loss)

Here is an example output, which does not look like a sample from a Gaussian with mean -30. I am trying to understand why this is happening.

gamma prior draws [-0.77792835 1.5513997 1.4689226 -0.8891797 1.6586142 1.8516874 -1.9143991 -1.3280873 1.602704 1.6196852 -0.48313093 0.60648584 -1.9278011 -1.5421586 1.147028 0.6957445 -0.24402332 1.9511538 -0.41310024 -1.3717875 0.5415411 0.9201145 -0.65974474 -0.60156584 -1.5836191 1.5495195 -1.9793987 -1.6504655 1.0285363 -0.7308345 -1.1712937 1.0799356 1.7527733 -1.7281809 0.89215994 0.4020362 1.6646938 1.8816423 -0.8738308 -1.3663263 1.8639073 -0.3134818 -1.0524917 -1.5728779 1.146738 0.73643875 1.1724048 0.6571684 0.0166111 -1.9081511 -0.19199896 0.73003006 0.2905984 -0.97260857 0.34668446 0.18595839 1.1871386 -0.49424887 1.9507704 -0.7336378 1.2888699 0.49221563 0.03837824 -1.2219481 0.8104315 -0.49491262 -1.0660887 -1.9435906 -0.10712671 0.46022224 1.276413 0.6822219 0.20741844 1.5192184 -0.01124239 0.7681494 1.4568133 -1.1505575 -0.07627535 0.5596223 1.6404142 0.7693238 -1.031949 -1.2256227 1.0243969 0.43393898 -0.20942831 1.5414958 -0.530324 -0.5331273 0.9794178 1.8405337 0.54413795 1.334086 -0.2682891 -0.1226387 -0.71883917 0.75752306 -0.5270543 1.14258 ]

Sorry for the bother. Hopefully this helps clarify what I’m asking.

This is because SVI maximizes the ELBO, which is an expectation with respect to the guide (the variational distribution). That is, you’re seeing the model executed with samples from the guide, not from the prior. See, e.g., this tutorial.