VAE using NumPyro + haiku_module gives hk.transform error

Hello all,

I’m building a VAE for continuous data. The networks are written with dm-haiku, and the probabilistic model uses NumPyro and its haiku_module wrapper. Here is how I build the model:

import numpy as np
import jax
import jax.numpy as jnp
import haiku as hk
from typing import Tuple, Sequence
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.module import haiku_module

numpyro.set_platform("cpu")
numpyro.set_host_device_count(1)

class Encoder(hk.Module):
  """Encoder model."""

  def __init__(self, hidden_dim1: int = 35, hidden_dim2: int = 30, z_dim: int = 10):
    super().__init__()
    self._hidden_dim1 = hidden_dim1
    self._hidden_dim2 = hidden_dim2
    self._z_dim = z_dim
    self.act = jax.nn.relu

  def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: # return tuple of mean, stddev
    x = hk.Flatten()(x)
    x = hk.Sequential([
        hk.Linear(self._hidden_dim1),
        self.act,
        hk.Linear(self._hidden_dim2),
        self.act,
    ])(x)

    mean = hk.Linear(self._z_dim)(x)
    log_stddev = hk.Linear(self._z_dim)(x)
    stddev = jnp.exp(log_stddev)

    return mean, stddev


class Decoder(hk.Module):
  """Decoder model."""

  def __init__(self, hidden_dim1: int = 30, hidden_dim2: int = 35, output_dim: Sequence[int] = (100,)):
    super().__init__()
    self._hidden_dim1 = hidden_dim1
    self._hidden_dim2 = hidden_dim2
    self._output_shape = output_dim
    self.act = jax.nn.relu

  def __call__(self, z: jnp.ndarray) -> jnp.ndarray:
    output = hk.Sequential([
        hk.Linear(self._hidden_dim1),
        self.act,
        hk.Linear(self._hidden_dim2),
        self.act,
        hk.Linear(np.prod(self._output_shape))
    ])(z)

    output = jnp.reshape(output, (-1, *self._output_shape))

    return output

def vae_model(batch, hidden_dim1, hidden_dim2, z_dim):
    batch = jnp.reshape(batch, (batch.shape[0], -1))
    batch_dim, out_dim = jnp.shape(batch)
    decode = haiku_module(
        "decoder",
        hk.transform(
          Decoder(hidden_dim1=hidden_dim1, hidden_dim2=hidden_dim2, output_dim=(out_dim, 1))
        ),
        input_shape=(batch_dim, z_dim)
    )
    with numpyro.plate("batch", batch_dim):
        z = numpyro.sample("z", dist.Normal(0, 1).expand([z_dim]).to_event(1))
        gen_loc = decode(z)
        return numpyro.sample("obs", dist.Normal(gen_loc, .1).to_event(1), obs=batch) # allow some noise around observations

When I call Encoder and Decoder on their own, they work as expected with no errors. I have also used this vae_model successfully before, but with a stax neural network wrapped in module(). However, when I run anything involving this haiku-based model, e.g.:

with numpyro.handlers.seed(rng_seed=1):
  trace = numpyro.handlers.trace(vae_model).get_trace(
      batch=jnp.ones([100, 10]), # test data
      hidden_dim1=35,
      hidden_dim2=30,
      z_dim=10
  )
print(numpyro.util.format_shapes(trace))

I get the error "All hk.Modules must be initialized inside an hk.transform." I don’t understand this, because I do pass the Decoder through hk.transform before handing it to haiku_module, as required.

Any ideas on this weird error?

numpyro==0.9.0
dm-haiku==0.0.6
jax==0.3.1
jaxlib==0.3.0+cuda11.cudnn82

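Hi @theo, please follow the pattern in the ProdLDA tutorial. An hk.Module has to be constructed inside the function that hk.transform wraps, but in your code Decoder(...) is constructed before hk.transform is even called. So I think Decoder should not be a subclass of hk.Module if you want to pass an instance of it to hk.transform directly. Alternatively, I think you can write a wrapper function that constructs the hk.Module, then transform that wrapper. See the Haiku basics docs.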

That’s great, thanks. For those reading this in the future: you’re right. Either:

  • use class Decoder: rather than class Decoder(hk.Module):, or
  • wrap the Decoder construction in a lambda passed to hk.transform, which is what I did:
decode = haiku_module(
    "decoder",
    hk.transform(
        lambda x: Decoder(hidden_dim1=hidden_dim1, hidden_dim2=hidden_dim2, output_dim=(out_dim,))(x)
    ),
    input_shape=(z_dim,)
)

Thanks!