Hi! Thanks for this great package!
I am trying to run the following model in NumPyro together with JAX.
However, it fails with the following error:
Traceback (most recent call last):
  File "bayes_optim.py", line 120, in <module>
    mcmc.run(rng_key, images=train_images_subset, labels=train_labels_subset)
  File ".../numpyro/infer/mcmc.py", line 682, in run
    states_flat, last_state = partial_map_fn(map_args)
                              ^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../numpyro/infer/mcmc.py", line 443, in _single_chain_mcmc
    new_init_state = self.sampler.init(
                     ^^^^^^^^^^^^^^^^^^
  File ".../numpyro/infer/hmc.py", line 749, in init
    init_params = self._init_state(
                  ^^^^^^^^^^^^^^^^^
  File ".../numpyro/infer/hmc.py", line 693, in _init_state
    ) = initialize_model(
        ^^^^^^^^^^^^^^^^^
  File ".../numpyro/infer/util.py", line 757, in initialize_model
    raise RuntimeError(
RuntimeError: Cannot find valid initial parameters. Please check your model again.
Here is the full script:

import jax
import jax.numpy as jnp
import numpyro
import tqdm
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO, autoguide
from numpyro.handlers import seed
from numpyro.infer import MCMC, NUTS, HMC
import numpy as np
from tensorflow.keras.datasets import mnist
from flax import linen as nn
from flax.training import train_state
import optax
class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = x.reshape(x.shape[0], 28 * 28)
        x = nn.Dense(features=10)(x)
        return x
def train_nn(learning_rate, train_images, train_labels, num_epochs=10):
    key = jax.random.PRNGKey(0)
    model = CNN()
    params = model.init(key, jnp.ones((1, 28, 28)))
    optimizer = optax.adam(learning_rate)
    state = train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer,
    )

    @jax.jit
    def train_step(state, batch):
        def loss_fn(params):
            logits = state.apply_fn(params, batch['image'])
            return jnp.mean(optax.softmax_cross_entropy(logits, batch['label']))

        grad_fn = jax.value_and_grad(loss_fn)
        loss, grads = grad_fn(state.params)
        return state.apply_gradients(grads=grads), loss

    for _ in tqdm.tqdm(range(num_epochs)):
        for i in range(0, len(train_images), 32):
            batch = {
                'image': train_images[i:i+32],
                'label': jax.nn.one_hot(train_labels[i:i+32], 10),  # one-hot encode the labels
            }
            state, loss = train_step(state, batch)
    return state
# Define the Bayesian model
def model(images, labels):
    # Prior over the learning rate
    lr = numpyro.sample("lr", dist.Gamma(1.0, 1.0 / 0.01))
    # Train the network with the sampled learning rate
    state = train_nn(lr, images, labels)
    # Likelihood: categorical over the trained network's logits
    logits = state.apply_fn(state.params, images)
    numpyro.sample("obs", dist.Categorical(logits=logits), obs=labels)
print("Start Loading Mnist")
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images.astype(np.float32) / 255.
test_images = test_images.astype(np.float32) / 255.
n_subset = 1000
train_images_subset = train_images[:n_subset]
train_labels_subset = train_labels[:n_subset]
if __name__ == "__main__":
    rng_key = jax.random.PRNGKey(0)
    # kernel = NUTS(model)
    kernel = HMC(model)
    mcmc = MCMC(kernel, num_warmup=2, num_samples=2, progress_bar=True)
    print("inference started")
    mcmc.run(rng_key, images=train_images_subset, labels=train_labels_subset)
    samples = mcmc.get_samples()
    print(samples)
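For reference, here is a minimal sketch (an extra debugging snippet, not part of the script above) of how I would run the model once outside MCMC with NumPyro's seed/trace handlers, in case that helps narrow down where initialization fails:

from numpyro.handlers import seed, trace

# Sanity check (assumes the same train_images_subset / train_labels_subset as above):
# run the model once with a fixed seed and inspect the sampled sites.
seeded_model = seed(model, rng_seed=0)
model_trace = trace(seeded_model).get_trace(
    images=train_images_subset, labels=train_labels_subset
)
for name, site in model_trace.items():
    if site["type"] == "sample":
        print(name, jnp.asarray(site["value"]).shape)
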
I am a bit stuck and not sure what I am doing wrong.
I am using the following package versions (from my pyproject.toml):
[tool.poetry]
name = "nodecontraction"
version = "0.1.0"
description = ""
authors = ["Mykola Lukashchuk <nikola.lukashuk@gmail.com>"]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.12"
numpyro = "^0.15.2"
torch = "^2.4.0"
jax = "^0.4.31"
flax = "^0.8.5"
tensorflow = "^2.17.0"
keras = "^3.5.0"
tqdm = "^4.66.5"
[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"