Infer optimizer hyperparameters with NumPyro

Hi! Thanks for this great package!

I am trying to run the following model in NumPyro together with JAX.

However, it fails with the following error:

Traceback (most recent call last):
  File "bayes_optim.py", line 120, in <module>
    mcmc.run(rng_key, images=train_images_subset, labels=train_labels_subset)
  File ".../numpyro/infer/mcmc.py", line 682, in run
    states_flat, last_state = partial_map_fn(map_args)
                              ^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../numpyro/infer/mcmc.py", line 443, in _single_chain_mcmc
    new_init_state = self.sampler.init(
                     ^^^^^^^^^^^^^^^^^^
  File ".../numpyro/infer/hmc.py", line 749, in init
    init_params = self._init_state(
                  ^^^^^^^^^^^^^^^^^
  File ".../numpyro/infer/hmc.py", line 693, in _init_state
    ) = initialize_model(
        ^^^^^^^^^^^^^^^^^
  File ".../numpyro/infer/util.py", line 757, in initialize_model
    raise RuntimeError(
RuntimeError: Cannot find valid initial parameters. Please check your model again.
import jax
import jax.numpy as jnp
import numpyro
import tqdm
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO, autoguide
from numpyro.handlers import seed
from numpyro.infer import MCMC, NUTS, HMC

import numpy as np
from tensorflow.keras.datasets import mnist
from flax import linen as nn
from flax.training import train_state
import optax

class CNN(nn.Module):
    """Single dense layer mapping flattened 28x28 images to 10 logits.

    Despite the name there are no convolutions here: the module is
    multinomial logistic regression written as a Flax module.
    """

    @nn.compact
    def __call__(self, x):
        # Flatten every sample; each one must contain exactly 28*28 values.
        flat = x.reshape((x.shape[0], 28 * 28))
        logits = nn.Dense(features=10)(flat)
        return logits

def train_nn(learning_rate, train_images, train_labels, num_epochs=10, batch_size=32):
    """Train a freshly initialised ``CNN`` with Adam and return the final state.

    The function is written so JAX can trace it and differentiate it with
    respect to ``learning_rate``. NumPyro's HMC/NUTS does exactly that when
    ``learning_rate`` is a latent ``numpyro.sample`` site.

    Args:
        learning_rate: Adam step size. It may be a Python float or a scalar
            JAX array, including a tracer when called from a NumPyro model.
        train_images: Images sliced along axis 0 into mini-batches and fed to
            ``CNN``, which flattens each sample to 28*28 values.
        train_labels: Integer class labels aligned with ``train_images``; they
            are one-hot encoded into 10 classes.
        num_epochs: Number of passes over the data (no shuffling).
        batch_size: Mini-batch size. The last batch of an epoch may be smaller.

    Returns:
        The ``flax.training.train_state.TrainState`` after the final update.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # Fixed key: the initial weights are identical on every call, so the
    # trained state is a deterministic function of ``learning_rate``.
    key = jax.random.PRNGKey(0)

    model = CNN()
    params = model.init(key, jnp.ones((1, 28, 28)))

    # eps_root > 0 keeps sqrt(nu) differentiable where nu == 0. With the
    # default eps_root=0, any parameter whose gradient is exactly zero (e.g.
    # kernel weights fed by pixels that are 0 in every image) makes
    # d(loss)/d(learning_rate) NaN. That NaN is the likely cause of NumPyro's
    # "Cannot find valid initial parameters" error.
    optimizer = optax.adam(learning_rate, eps_root=1e-8)
    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
    )

    @jax.jit
    def train_step(state, batch):
        def loss_fn(params):
            logits = state.apply_fn(params, batch['image'])
            return jnp.mean(optax.softmax_cross_entropy(logits, batch['label']))

        loss, grads = jax.value_and_grad(loss_fn)(state.params)
        return state.apply_gradients(grads=grads), loss

    # One-hot encode once instead of re-encoding every mini-batch.
    one_hot_labels = jax.nn.one_hot(train_labels, 10)
    num_examples = len(train_images)

    # No progress bar: this function runs inside a model that HMC traces and
    # JIT-compiles, where a Python-side progress bar is misleading.
    for _ in range(num_epochs):
        for start in range(0, num_examples, batch_size):
            stop = start + batch_size
            batch = {
                'image': train_images[start:stop],
                'label': one_hot_labels[start:stop],
            }
            state, _ = train_step(state, batch)

    return state

# Define the Bayesian model
def model(images, labels):
    """NumPyro model whose only latent variable is the Adam learning rate.

    The model:
    - places a Gamma(concentration=1.0, rate=1.0 / 0.01) prior on ``lr``;
    - trains ``CNN`` on ``images``/``labels`` via ``train_nn`` with that rate;
    - scores ``labels`` under a Categorical likelihood built from the trained
      network's logits on the same ``images``.

    NOTE(review): the likelihood reuses the training data, so it favours the
    learning rate that minimises training loss rather than held-out loss.
    """
    lr_prior = dist.Gamma(1.0, 1.0 / 0.01)
    lr = numpyro.sample("lr", lr_prior)

    trained_state = train_nn(lr, images, labels)
    predicted_logits = trained_state.apply_fn(trained_state.params, images)

    likelihood = dist.Categorical(logits=predicted_logits)
    numpyro.sample("obs", likelihood, obs=labels)

# Load MNIST at import time and scale pixel intensities by 1/255.
print("Start Loading Mnist")
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images, test_images = (
    split.astype(np.float32) / 255. for split in (train_images, test_images)
)

# Keep only the first n_subset training examples so inference stays cheap.
n_subset = 1000
train_images_subset, train_labels_subset = (
    train_images[:n_subset],
    train_labels[:n_subset],
)

if __name__ == "__main__":
    # HMC(model) below can be swapped for NUTS(model), which is also imported.
    mcmc = MCMC(
        HMC(model),
        num_warmup=2,
        num_samples=2,
        progress_bar=True,
    )
    print("inference started")
    mcmc.run(
        jax.random.PRNGKey(0),
        images=train_images_subset,
        labels=train_labels_subset,
    )
    print(mcmc.get_samples())

I am a bit stuck and not sure what I am doing wrong.

I am using the following package versions:

[tool.poetry]
name = "nodecontraction"
version = "0.1.0"
description = ""
authors = ["Mykola Lukashchuk <nikola.lukashuk@gmail.com>"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.12"
numpyro = "^0.15.2"
torch = "^2.4.0"
jax = "^0.4.31"
flax = "^0.8.5"
tensorflow = "^2.17.0"
keras = "^3.5.0"
tqdm = "^4.66.5"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

Just a couple of thoughts:

  • HMC JIT-compiles the model under the hood, so you shouldn’t use tqdm inside your model.
  • HMC takes gradients with respect to the latent variables. Do you want to differentiate train_nn with respect to lr (through the whole training run that produces the state)? If not, you might want to stop the gradient: train_nn(jax.lax.stop_gradient(lr), ...).
  • Try to inspect why no valid initial parameters could be found. Does a NaN occur? If so, where?
  • Yes, indeed. Thank you.
  • I want to infer the learning rate (LR). Can you elaborate on how it’s possible to infer it without taking the gradient with respect to it? Is there an example of this somewhere? I’m a bit confused.
  • I will try. Thank you.

There is a relevant discussion here: "Nested inference and site names in numpyro". You can experiment with grad(get_likelihood_from_lr)(lr), which is NumPyro-agnostic. Here, get_likelihood_from_lr runs the optimizer, gets the logits, and computes the log likelihood from those logits.

1 Like