Hi, I think I’m severely misunderstanding or misusing sample within SVI. Here is (a skeleton of) some of the code.
def model():
    """Skeleton model: Gaussian prior over 100 `gammas` plus a loss factor."""
    sample_shape = (100, )
    # Prior: 100 i.i.d. Normal(mean=-20, std=5) draws via .expand().
    gamma_prior = dist.Normal(-20, 5).expand(sample_shape)
    # NOTE(review): under SVI this sample site is replayed with values from the
    # guide (variational distribution), not fresh prior draws.
    gammas = numpyro.sample('gammas', gamma_prior)
    ### below, run simulation with parameters and compute loss with output
    .
    .
    .
    # Adds -total_loss as an unnormalized log-probability term to the model density.
    numpyro.factor("loss", -total_loss)
# Fixed-step Adam drives the ELBO maximization.
optimizer = Adam(step_size=1e-2)
# Normalizing-flow variational family (single-flow BNAF) for this model.
guide = AutoBNAFNormal(model, num_flows=1)
# SVI pairs model and guide under the standard Trace_ELBO objective.
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
.
.
### below is training loop
What’s expected is that gammas are drawn from a prior (Gaussian) distribution with the above parameters. However, whatever I do, these samples seem to only come from the standard Gaussian when printing, e.g. it is ignoring the mean and variance I have defined.
Hard to say, because so much of your code is missing. Maybe you're doing the training wrong? E.g. it is unclear how many training steps you do. It's also entirely unclear where `total_loss` in `numpyro.factor("loss", -total_loss)` comes from.
When I am printing my samples though, these are draws from my prior at initialization – so at t=0 before any loss is computed or any training is done. My understanding is that training shouldn’t factor into this? If I’m wrong about this, then that may explain some of my confusion.
I guess what I am after is trying to understand if user-defined priors are actually constructed at initialization, or if they require some evolution when doing SVI in order to be achieved.
In case this helps, here is the full code. Note that there is no batching here; otherwise I am in general batching to compute the loss, and vmap/pmap-ing everything.
def main(args):
    """Fit a flow guide to the stasis model with SVI, checkpointing the best
    parameters (lowest loss) seen during training.

    Expects `args` to carry: N_species, observed_stasis_val,
    total_matter_abundance, num_flows, lr, num_epochs.
    """
    timestr = time.strftime("%Y%m%d-%H%M%S")
    # BUG FIX: PATH was previously bound only inside an
    # `if not os.path.exists(...)` branch, so a pre-existing results directory
    # raised NameError at the open() below. Bind it unconditionally and let
    # makedirs tolerate an existing directory.
    PATH = f'./svi_results/{timestr}'
    os.makedirs(PATH, exist_ok=True)
    # Persist the run configuration alongside the results.
    with open(f'{PATH}/args.pkl', 'wb') as file:
        pickle.dump(args.__dict__, file)

    def model():
        # Independent per-species priors.
        shape = (args.N_species, )
        gamma_prior = dist.Normal(-30, 5).expand(shape)
        omega_prior = dist.Uniform(0, 1).expand(shape)
        # NOTE: during SVI these sites are substituted with guide samples,
        # not fresh prior draws.
        gammas = numpyro.sample('gammas', gamma_prior)
        omegas = numpyro.sample('omegas', omega_prior)
        simulation = StasisSolver(omegas, gammas)
        stasis_val = simulation.return_stasis()
        matter_abundance = simulation.get_asymptote()
        # Squared-error penalties against the observed targets.
        stasis_loss = (stasis_val - args.observed_stasis_val)**2
        matter_abundance_loss = (matter_abundance - args.total_matter_abundance)**2
        total_loss = stasis_loss + matter_abundance_loss
        # Site renamed from "guide" to "loss" for consistency with the skeleton
        # version above; the factor's site name is arbitrary, but "guide" was
        # misleading (it is not related to the SVI guide object).
        numpyro.factor("loss", -total_loss)

    guide = AutoBNAFNormal(model, num_flows=args.num_flows)
    optimizer = Adam(step_size=args.lr)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

    best_loss, best_epoch, best_loss_counter = jnp.inf, 0, 0
    losses = []
    state = svi.init(jax.random.PRNGKey(0))
    # NOTE(review): `tqdm` must be imported (`from tqdm import tqdm`) for this
    # loop to run -- no import is visible in this snippet.
    for i in tqdm(range(args.num_epochs)):
        best_loss_counter += 1
        # One ELBO gradient step; `loss` is the negative ELBO estimate.
        state, loss = svi.update(state)
        losses.append(loss)
        if loss < best_loss:
            best_loss = loss
            best_epoch = i
            # Checkpoint the best parameters seen so far.
            params = svi.get_params(state)
            with open(f'{PATH}/best_params.pkl', 'wb') as file:
                pickle.dump(params, file)
            best_loss_counter = 0
        if best_loss_counter > 200:
            # Early stop after 200 consecutive epochs without improvement.
            print(f"Early stopping at epoch {i+1} with loss {loss}")
            break
The printed draws from gamma_prior seem to come from a standard normal (not the distribution I've defined), which is my confusion. Thanks!
Thanks! Yes, that is true. But this isn’t exactly what I’m after.
Maybe this code (that should run on your machine) will confirm my confusion. The draws that are printed from the line print('gamma prior draws', gammas), which happens during training, do not properly represent the prior distribution they are drawn from. This is my confusion.
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoBNAFNormal
from numpyro.optim import Adam
import jax.numpy as jnp
import jax
def dummy_simulation(omegas, gammas):
    """Stand-in for the real solver: report the mean of each parameter vector."""
    return omegas.mean(), gammas.mean()
def model():
    """Toy model: 100-dim Gaussian/uniform priors plus a squared-error factor."""
    shape = (100, )
    gamma_prior = dist.Normal(-30, 5).expand(shape) ## prior on gamma is a normal distribution with mean -30 and std 5
    omega_prior = dist.Uniform(0, 1).expand(shape) ## prior on omega is a uniform distribution between 0 and 1
    gammas = numpyro.sample('gammas', gamma_prior)
    # NOTE(review): when this model runs inside SVI, sample sites are replayed
    # with draws from the guide (variational distribution), so the values
    # printed here track the guide -- not Normal(-30, 5) -- which is exactly
    # the behavior being asked about in this thread.
    print('gamma prior draws', gammas) ### should be drawn from a normal distribution with mean -30 and std 5
    omegas = numpyro.sample('omegas', omega_prior)
    val_1, val_2 = dummy_simulation(omegas, gammas) ## simulation outputs two values
    loss_1 = (val_1 - 1.)**2 ## just a dummy loss
    loss_2 = (val_2 - 1.)**2 ## just a dummy loss
    total_loss = loss_1 + loss_2
    # Adds -total_loss to the model log-density; the site name "guide" is
    # arbitrary here (a name like "loss" would be less confusing).
    numpyro.factor("guide", -total_loss)
# Flow-based variational family; the guide is what SVI actually samples from.
guide = AutoBNAFNormal(model, num_flows=1)
optimizer = Adam(step_size=1e-2)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
# Book-keeping for the best loss seen so far (early-stopping scaffolding;
# best_loss/best_epoch/best_loss_counter are not used further in this snippet).
best_loss, best_epoch, best_loss_counter = jnp.inf, 0, 0
losses = []
# Initialize the SVI state (guide parameters + RNG) from a fixed seed.
state = svi.init(jax.random.PRNGKey(0))
# NOTE(review): `tqdm` is used here but never imported above -- add
# `from tqdm import tqdm` (or drop the wrapper) for this snippet to run.
for i in tqdm(range(1000)):
    best_loss_counter += 1
    # One ELBO gradient step; `loss` is the negative ELBO for this step.
    state, loss = svi.update(state)
    losses.append(loss)
here is an example output, which is not representative of a mean=-30 Gaussian. I am trying to understand why this is happening.
This is because SVI maximizes the ELBO, which is an expectation w.r.t. the guide (i.e. the variational distribution) — so you're seeing the model executed with samples from the guide, not the prior. E.g. see this tutorial.