Hi, I am trying to implement a relative simply model, with the goal of making it more complex to fit my needs. The simple model is a Gaussian HMM with one state and N measurements. The more complex model has censored, non-Gaussian, time-inhomogeneous measurements with a lot of missing values.
The simple model is:
x[t] ~ N(x[t - 1] * phi, sigma_x), x[0] ~ N(0, sigma_x)
y[t] ~ N(x[t], sigma_y)
where x[t] are scalar and y[t] are length N.
In pyro I am using GaussianHMM as the distribution, but that won’t work for my complex case. In numpyro I am using lax.scan functionality to gain speed when I move to the more complex case.
SVI on the GaussianHMM gets close to the true parameters very quickly, but the numpyro implementation does not. I have also tried other versions (both pyro and numpyro), but they all suffer from the same problem. I am evidently doing something wrong here. My code is below. Hope someone has any ideas.
Simulation part:
import numpy as np
import pandas as pd
T = 500
N = 10
phi = 0.9
sigma_x = 0.03
sigma_y = 0.1
np.random.seed(42)
e = np.random.normal(size=T, scale=sigma_x)
x = np.empty(T)
x[0] = e[0]
for t in range(1, T):
x[t] = x[t - 1] * phi + e[t]
y = x + np.random.normal(size=(N, T), scale=sigma_y)
pd.Series(y.mean(0)).plot()
pd.Series(x).plot()
The pyro implementation using GaussianHMM:
import torch
import pyro
import pyro.distributions as dist
def model(data):
N, T = data.shape
sigma_x = pyro.param('sigma_x', torch.tensor(1.0), constraint=dist.constraints.positive)
sigma_y = pyro.param('sigma_y', torch.tensor(1.0), constraint=dist.constraints.positive)
init_dist = dist.Normal(0, sigma_x).expand([1]).to_event(1)
obs_dist = dist.Normal(0, sigma_y).expand([N]).to_event(1)
trans_dist = dist.Normal(0, sigma_x).expand([1]).to_event(1)
obs_matrix = torch.ones((1, N))
phi = pyro.param('phi', torch.tensor(0.0))
trans_matrix = phi.reshape((1, 1))
noise_dist = dist.GaussianHMM(init_dist, trans_matrix, trans_dist, obs_matrix, obs_dist, duration=T)
pyro.sample('obs', noise_dist, obs=data.T)
pyro.clear_param_store()
pyro.set_rng_seed(42)
guide = pyro.infer.autoguide.AutoNormal(model)
optim = pyro.optim.Adam({'lr': 0.02})
svi = pyro.infer.SVI(model, guide, optim, pyro.infer.Trace_ELBO())
y_torch = torch.tensor(y, dtype=torch.float)
for step in range(501):
loss = svi.step(y_torch)
if step % 100 == 0:
print(f'Epoch {step:4}: Elbo loss: {loss / N:3.2f}')
print(pyro.get_param_store()['phi'].item())
print(pyro.get_param_store()['sigma_x'].item())
print(pyro.get_param_store()['sigma_y'].item())
And the numpyro implementation using lax.scan:
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
def model(y):
N, T = y.shape
phi = numpyro.param('phi', jnp.array(0.0))
sigma_x = numpyro.param('sigma_x', jnp.array(1.0), constraint=dist.constraints.positive)
sigma_y = numpyro.param('sigma_y', jnp.array(1.0), constraint=dist.constraints.positive)
def transition(x, e):
x_new = x * phi + e
return x_new, x_new
x0 = numpyro.sample('x0', dist.Normal(np.zeros(1), sigma_x))
e = numpyro.sample('e', dist.Normal(np.zeros(T), sigma_x))
_, x = jax.lax.scan(transition, x0, e)
numpyro.deterministic('x', x)
numpyro.sample('obs', dist.Normal(x.squeeze(-1), sigma_y), obs=y)
y_jax = jnp.array(y)
guide = numpyro.infer.autoguide.AutoNormal(model)
optim = numpyro.optim.Adam(step_size=0.002)
svi = numpyro.infer.SVI(model, guide, optim, loss=numpyro.infer.Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 10000, y_jax)
params = svi_result.params
guide_samples = guide.sample_posterior(jax.random.PRNGKey(1), svi_result.params)
print(params['phi'])
print(params['sigma_x'])
print(params['sigma_y'])