Hello all,
I’m building a VAE trained on continuous data. The VAE is built using dm-haiku and the training model using numpyro and its haiku_module
. I’m building the model as follows:
import numpy as np
import jax
import jax.numpy as jnp
import haiku as hk
from typing import Tuple, Sequence
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import haiku_module
numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)
class Encoder(hk.Module):
"""Encoder model."""
def __init__(self, hidden_dim1: int = 35, hidden_dim2: int = 30, z_dim: int = 10):
super().__init__()
self._hidden_dim1 = hidden_dim1
self._hidden_dim2 = hidden_dim2
self._z_dim = z_dim
self.act = jax.nn.relu
def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: # return tuple of mean, stddev
x = hk.Flatten()(x)
x = hk.Sequential([
hk.Linear(self._hidden_dim1),
self.act,
hk.Linear(self._hidden_dim2),
self.act,
])(x)
mean = hk.Linear(self._z_dim)(x)
log_stddev = hk.Linear(self._z_dim)(x)
stddev = jnp.exp(log_stddev)
return mean, stddev
class Decoder(hk.Module):
"""Decoder model."""
def __init__(self, hidden_dim1: int = 30, hidden_dim2: int = 35, output_dim: Sequence[int] = (100,)):
super().__init__()
self._hidden_dim1 = hidden_dim1
self._hidden_dim2 = hidden_dim2
self._output_shape = output_dim
self.act = jax.nn.relu
def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
output = hk.Sequential([
hk.Linear(self._hidden_dim1),
self.act,
hk.Linear(self._hidden_dim2),
self.act,
hk.Linear(np.prod(self._output_shape))
])(z)
output = jnp.reshape(output, (-1, *self._output_shape))
return output
def vae_model(batch, hidden_dim1, hidden_dim2, z_dim):
batch = jnp.reshape(batch, (batch.shape[0], -1))
batch_dim, out_dim = jnp.shape(batch)
decode = haiku_module(
"decoder",
hk.transform(
Decoder(hidden_dim1=hidden_dim1, hidden_dim2=hidden_dim2, output_dim=(out_dim, 1))
),
input_shape=(batch_dim, z_dim)
)
with numpyro.plate("batch", batch_dim):
z = numpyro.sample("z", dist.Normal(0, 1).expand([z_dim]).to_event(1))
gen_loc = decode(z)
return numpyro.sample("obs", dist.Normal(gen_loc, .1).to_event(1), obs=batch) # allow some noise around observations
When I call Encoder/Decoder, they work as expected with no errors. I have also used this vae_model
successfully before but with a stax neural network and module()
. But when I execute anything to do with this haiku-based model, e.g:
with numpyro.handlers.seed(rng_seed=1):
trace = numpyro.handlers.trace(vae_model).get_trace(
batch=jnp.ones([100, 10]), # test data
hidden_dim1=35,
hidden_dim2=30,
z_dim=10
)
print(numpyro.util.format_shapes(trace))
I get All hk.Modules must be initialized inside an hk.transform.
I don’t understand this error because I have clearly used hk.transform
in haiku_module
, as required.
Any ideas on this weird error?
numpyro==0.9.0
dm-haiku==0.0.6
jax==0.3.1
jaxlib==0.3.0+cuda11.cudnn82