I am trying to recreate the enumerate HMM numpyro example with SVI instead of MCMC but am getting the following error.
ValueError: Continuous inference cannot handle discrete sample site 'x'.
Here is my model (taken from the online example)
# x[t-1] --> x[t] --> x[t+1]
# | | |
# V V V
# y[t-1] y[t] y[t+1]
#
# This model includes a plate for the data_dim = 44 keys on the piano. This
# model has two "style" parameters probs_x and probs_y that we'll draw from a
# prior. The latent state is x, and the observed state is y.
def model_1(sequences, lengths, hidden_dim, include_prior=True):
num_sequences, max_length, data_dim = sequences.shape
with mask(mask=include_prior):
probs_x = numpyro.sample(
"probs_x", dist.Dirichlet(0.9 * jnp.eye(hidden_dim) + 0.1).to_event(1)
)
probs_y = numpyro.sample(
"probs_y",
dist.Beta(0.1, 0.9).expand([hidden_dim, data_dim]).to_event(2),
)
def transition_fn(carry, y):
x_prev, t = carry
with numpyro.plate("sequences", num_sequences, dim=-2):
with mask(mask=(t < lengths)[..., None]):
x = numpyro.sample(
"x",
dist.Categorical(probs_x[x_prev]),
infer={"enumerate": "parallel"},
)
with numpyro.plate("tones", data_dim, dim=-1):
numpyro.sample("y", dist.Bernoulli(probs_y[x.squeeze(-1)]), obs=y)
return (x, t + 1), None
x_init = jnp.zeros((num_sequences, 1), dtype=jnp.int32)
# NB swapaxes: we move time dimension of `sequences` to the front to scan over it
scan(transition_fn, (x_init, 0), jnp.swapaxes(sequences, 0, 1))
And code for running SVI on said model
def main(num_samples=1000, hidden_dim=16, truncate=None, num_sequences=None,
kernel='nuts', num_warmup=500, num_chains=1, device='cpu'):
model = model_1
numpyro.set_platform(device)
numpyro.set_host_device_count(num_chains)
_, fetch = load_dataset(JSB_CHORALES, split="train", shuffle=False)
lengths, sequences = fetch()
if num_sequences:
sequences = sequences[0 : num_sequences]
lengths = lengths[0 : num_sequences]
logger.info("-" * 40)
logger.info("Training {} on {} sequences".format(model.__name__, len(sequences)))
# find all the notes that are present at least once in the training set
present_notes = (sequences == 1).sum(0).sum(0) > 0
# remove notes that are never played (we remove 37/88 notes with default args)
sequences = sequences[:, :, present_notes]
if truncate:
lengths = lengths.clip(0, truncate)
sequences = sequences[:, : truncate]
logger.info("Each sequence has shape {}".format(sequences[0].shape))
logger.info("Starting inference...")
rng_key = random.PRNGKey(2)
start = time.time()
# kernel = {"nuts": NUTS, "hmc": HMC}[kernel](model)
# mcmc = MCMC(
# kernel,
# num_warmup=num_warmup,
# num_samples=num_samples,
# num_chains=num_chains,
# progress_bar=False if "NUMPYRO_SPHINXBUILD" in os.environ else True,
# )
# mcmc.run(rng_key, sequences, lengths, hidden_dim)
# mcmc.print_summary()
# logger.info("\nMCMC elapsed time: {}".format(time.time() - start))
optim = Adam({'lr': 0.01, 'betas': [0.8, 0.99]})
elbo = Trace_ELBO(rng_key)
guide = AutoDelta(
poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_"))
)
svi = SVI(model, guide, optim, loss=elbo)
start = time.time()
svi_result = svi.run(rng_key, 25, sequences, lengths, hidden_dim)
logger.info("\nSVI elapsed time: {}".format(time.time() - start))
The model runs (albeit slowly) when I use MCMC and I’ve gotten enumerated HMM’s to work with SVI using pyro. Is there something wrong with my code or is SVI + discrete HMM combination not supported in numpyro? My ultimate goal of exploring numpyro was to speed up runtime through using jax’s scan()
as its documentation makes me think it would be more efficient than writing a python for-loop over each time point in a pyro HMM.
Any help or advice would be greatly appreciated.
Best,
Adam