In HMM with SVI, numpyro doesn't find true params, pyro does

Hi, I am trying to implement a relatively simple model, with the goal of making it more complex later to fit my needs. The simple model is a Gaussian HMM with one state and N measurements. The more complex model has censored, non-Gaussian, time-inhomogeneous measurements with a lot of missing values.

The simple model is:

x[t] ~ N(x[t - 1] * phi, sigma_x), x[0] ~ N(0, sigma_x)
y[t] ~ N(x[t], sigma_y)

where x[t] are scalar and y[t] are length N.

In pyro I am using GaussianHMM as the distribution, but that won’t work for my complex case. In numpyro I am using lax.scan functionality to gain speed when I move to the more complex case.

SVI on the GaussianHMM gets close to the true parameters very quickly, but the numpyro implementation does not. I have also tried other versions (both pyro and numpyro), but they all suffer from the same problem. I am evidently doing something wrong here. My code is below. Hope someone has any ideas.

Simulation part:

import numpy as np
import pandas as pd

T = 500
N = 10

phi = 0.9
sigma_x = 0.03
sigma_y = 0.1

np.random.seed(42)

e = np.random.normal(size=T, scale=sigma_x)

x = np.empty(T)

x[0] = e[0]
for t in range(1, T):
    x[t] = x[t - 1] * phi + e[t]
    
y = x + np.random.normal(size=(N, T), scale=sigma_y)

pd.Series(y.mean(0)).plot()
pd.Series(x).plot()

The pyro implementation using GaussianHMM:

import torch
import pyro
import pyro.distributions as dist

def model(data):
    
    N, T = data.shape
    
    sigma_x = pyro.param('sigma_x', torch.tensor(1.0), constraint=dist.constraints.positive)
    sigma_y = pyro.param('sigma_y', torch.tensor(1.0), constraint=dist.constraints.positive)
    
    init_dist = dist.Normal(0, sigma_x).expand([1]).to_event(1)
    obs_dist = dist.Normal(0, sigma_y).expand([N]).to_event(1)
    trans_dist = dist.Normal(0, sigma_x).expand([1]).to_event(1)
    
    obs_matrix = torch.ones((1, N))
    phi = pyro.param('phi', torch.tensor(0.0))
    trans_matrix = phi.reshape((1, 1))
    
    noise_dist = dist.GaussianHMM(init_dist, trans_matrix, trans_dist, obs_matrix, obs_dist, duration=T)
    pyro.sample('obs', noise_dist, obs=data.T)

pyro.clear_param_store()
pyro.set_rng_seed(42)

guide = pyro.infer.autoguide.AutoNormal(model)
optim = pyro.optim.Adam({'lr': 0.02})

svi = pyro.infer.SVI(model, guide, optim, pyro.infer.Trace_ELBO())
y_torch = torch.tensor(y, dtype=torch.float)

for step in range(501):
    loss = svi.step(y_torch)
    if step % 100 == 0:
        print(f'Epoch {step:4}: Elbo loss: {loss / N:3.2f}')

print(pyro.get_param_store()['phi'].item())
print(pyro.get_param_store()['sigma_x'].item())
print(pyro.get_param_store()['sigma_y'].item())

And the numpyro implementation using lax.scan:

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(y):
    
    N, T = y.shape
    
    phi = numpyro.param('phi', jnp.array(0.0))
    sigma_x = numpyro.param('sigma_x', jnp.array(1.0), constraint=dist.constraints.positive)
    sigma_y = numpyro.param('sigma_y', jnp.array(1.0), constraint=dist.constraints.positive)
    
    def transition(x, e):
        x_new = x * phi + e
        return x_new, x_new
    
    x0 = numpyro.sample('x0', dist.Normal(np.zeros(1), sigma_x))
    e = numpyro.sample('e', dist.Normal(np.zeros(T), sigma_x))
    _, x = jax.lax.scan(transition, x0, e)
    numpyro.deterministic('x', x)
    numpyro.sample('obs', dist.Normal(x.squeeze(-1), sigma_y), obs=y)

y_jax = jnp.array(y)

guide = numpyro.infer.autoguide.AutoNormal(model)
optim = numpyro.optim.Adam(step_size=0.002)
svi = numpyro.infer.SVI(model, guide, optim, loss=numpyro.infer.Trace_ELBO())
svi_result = svi.run(jax.random.PRNGKey(0), 10000, y_jax)

params = svi_result.params
guide_samples = guide.sample_posterior(jax.random.PRNGKey(1), svi_result.params)

print(params['phi'])
print(params['sigma_x'])
print(params['sigma_y'])

SVI on the GaussianHMM gets close to the true parameters very quickly, but the numpyro implementation does not.

not sure what kind of discrepancy you’re talking about, but the GaussianHMM integrates out the gaussian latent variables exactly, whereas the numpyro code does not (instead it does mean-field variational inference). the latter is an approximation and does not, e.g., fully propagate uncertainty from one time step to the next. as such there’s no reason a priori to expect that it performs super well.
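
To make "integrates out the latent variables exactly" concrete, here is a minimal sketch of the marginal log-likelihood for the simple model above, computed with a scalar Kalman filter in plain Python. This is not how GaussianHMM is implemented internally; the function name kalman_loglik and the list-of-lists data layout are assumptions made for illustration.

```python
import math

def kalman_loglik(y, phi, sigma_x, sigma_y):
    """Exact log p(y | phi, sigma_x, sigma_y) with x[0..T-1] integrated out.

    y is a list of T time steps, each a list of N measurements of x[t].
    """
    mean, var = 0.0, sigma_x ** 2      # x[0] ~ N(0, sigma_x)
    loglik = 0.0
    for t, y_t in enumerate(y):
        if t > 0:
            # Predict step: x[t] = phi * x[t-1] + e[t]
            mean = phi * mean
            var = phi ** 2 * var + sigma_x ** 2
        # Update with one measurement at a time. This is exact because the
        # measurement noise is independent across the N measurements.
        for obs in y_t:
            s = var + sigma_y ** 2     # predictive variance of this measurement
            resid = obs - mean
            loglik += -0.5 * (math.log(2 * math.pi * s) + resid ** 2 / s)
            gain = var / s
            mean += gain * resid
            var *= 1.0 - gain
    return loglik
```

Maximizing this quantity over phi, sigma_x and sigma_y involves no approximation of the posterior over x, which is the key difference from fitting a mean-field guide over x0 and e.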

Sorry I should have specified the discrepancy. It’s the difference between finding something like phi=0.88 (close to true value of 0.9), and something like 0.6. Your explanation makes a lot of sense, many thanks! I understand that it’s indeed possible to calculate the conditional log_prob for a Gaussian state space model, but not in general.

I tried to switch to MCMC and indeed now I can find the true value. I have also included the additional complexity, primarily a censored student-t observation. Censoring works fine for a normal distribution, but the switch to a student-t somehow pushes the phi to zero and gives me a very noisy process, even when I choose df=1000 (so effectively I have a Normal distribution). Is the switch causing the algorithms to make different choices under the hood? I would expect that since the distributions are numerically pretty much identical for df=1000, there should be no difference.
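
The expectation that the two densities nearly coincide at df=1000 can be checked directly. The sketch below compares standardized log-densities in plain Python; the helper names are illustrative and it only looks at the density, not at the cdf or its gradient.

```python
import math

def normal_logpdf(x):
    # Log-density of a standard normal at x
    return -0.5 * math.log(2 * math.pi) - 0.5 * x * x

def student_t_logpdf(x, df):
    # Log-density of a standard Student-t with df degrees of freedom at x
    return (math.lgamma(0.5 * (df + 1)) - math.lgamma(0.5 * df)
            - 0.5 * math.log(df * math.pi)
            - 0.5 * (df + 1) * math.log1p(x * x / df))

for x in (0.0, 1.0, 2.0):
    gap = student_t_logpdf(x, 1000.0) - normal_logpdf(x)
    print(f"x={x}: log-density gap = {gap:.2e}")
```

If the gaps are small, as expected, a large change in the fitted phi is unlikely to come from the observed-data term and more likely to come from the term that uses the cdf.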

My code is below for reference.

def model(y_obs, idx_y):
    
    N, L = y_obs.shape
    
    phi = numpyro.sample('phi', dist.Uniform(jnp.array(-1.0), jnp.array(1.0)))
    sigma_x = numpyro.sample('sigma_x', dist.HalfCauchy(jnp.array(1.0)))
    sigma_y = numpyro.sample('sigma_y', dist.HalfCauchy(jnp.array(1.0)))
    
    x_0 = numpyro.sample('x_0', dist.Normal(jnp.zeros(1), sigma_x))
    eps = numpyro.sample('eps', dist.Normal(jnp.zeros(T), sigma_x))
    
    def transition(x, e):
        x_new = x * phi + e
        return x_new, x_new
    
    _, x = jax.lax.scan(transition, x_0, eps)
    numpyro.deterministic('x', x)
    
    x_cmn = jnp.take_along_axis(x.T, idx_y, 1)
    
    y_lat = dist.Normal(x_cmn, sigma_y)
    #y_lat = dist.StudentT(df=2, loc=x_cmn, scale=sigma_y)
    
    with numpyro.handlers.mask(mask=y_obs > 0):
        numpyro.sample('obs', y_lat, obs=y_obs)
        
    with numpyro.handlers.mask(mask=y_obs == 0):
        numpyro.sample('trunc_label', dist.Bernoulli(1 - y_lat.cdf(jnp.nan_to_num(y_obs, nan=0.0))), obs=y_obs)
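
For reference, the likelihood that the two masked sample statements are meant to encode can be written out by hand. This is a plain-Python sketch for the Normal case only, assuming left-censoring at 0 and NaN for missing values; the function names and the limit argument are illustrative, not part of numpyro.

```python
import math

def normal_logpdf(y, mu, sigma):
    z = (y - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)

def normal_logcdf(y, mu, sigma):
    # log P(Y <= y), written with erfc so the lower tail stays accurate
    return math.log(0.5 * math.erfc(-(y - mu) / (sigma * math.sqrt(2.0))))

def censored_loglik(y_obs, mu, sigma, limit=0.0):
    total = 0.0
    for y, m in zip(y_obs, mu):
        if math.isnan(y):
            continue                                  # missing: no contribution
        if y > limit:
            total += normal_logpdf(y, m, sigma)       # observed: density term
        else:
            total += normal_logcdf(limit, m, sigma)   # censored: P(Y <= limit)
    return total
```

The Bernoulli trick in the model gives the same censored term: with obs equal to 0, the Bernoulli log-probability is the log of the cdf at 0.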

i have no idea why you’re seeing that behavior, but cdfs can be numerically unstable in some cases. maybe it would help to use 64 bit precision.
the inference algorithm itself is agnostic as to whether you’re using e.g. dist.Normal(...).log_prob(...) or dist.StudentT(...).log_prob(...)
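
As one illustration of how a cdf can be unstable, here is a small plain-Python sketch of the cancellation in 1 - cdf far in the upper tail, using the standard normal because its cdf is available through math.erf. The function names are illustrative, and this is only one possible failure mode, not a diagnosis of the behavior above.

```python
import math

def upper_tail_naive(z):
    # P(Z > z) computed as 1 - cdf(z); loses all precision once cdf(z) rounds to 1
    cdf = 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))
    return 1.0 - cdf

def upper_tail_stable(z):
    # Same probability computed from the complementary error function
    return 0.5 * math.erfc(z / math.sqrt(2.0))

z = 10.0
print(upper_tail_naive(z))   # taking the log of this value would fail
print(upper_tail_stable(z))
```

At z = 10 the naive version evaluates to exactly zero even in 64-bit floats, so higher precision only pushes the problem further into the tail rather than removing it.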

Many thanks for the suggestion. Unfortunately, increasing the precision did not help. I think the problem is with the definition of the cdf, which is written in terms of the betainc function. My guess is that its gradient is too inaccurate for the MCMC chain to handle.

My workaround was to define my own gradient on the basis of the cdf source code (the gradient of the cdf is the pdf). I haven’t done a whole lot of testing, but this seems to work:

import jax
import jax.numpy as jnp
import jax.scipy.special
import jax.scipy.stats
import numpyro
import numpyro.distributions as dist

@jax.custom_vjp
def tcdf(x, df):
    # Ref: https://en.wikipedia.org/wiki/Student's_t-distribution#Related_distributions
    # X^2 ~ F(1, df) -> df / (df + X^2) ~ Beta(df/2, 0.5)
    return 0.5 + 0.5 * jnp.sign(x) * (1.0 - jax.scipy.special.betainc(0.5 * df, 0.5, df / (df + x * x)))

tcdf.defvjp(
    lambda x, df: (tcdf(x, df), (x, df)), # forward function
    lambda res, g: (jax.scipy.stats.t.pdf(*res) * g, None)) # backward function; no gradient for df

def cdf_new(self, value):
    return tcdf((value - self.loc) / self.scale, self.df)

dist.StudentT.cdf = cdf_new