Hierarchical Bayesian Model with data of varying length

Hi, I have been working with a hierarchical Bayesian model that looks like the partially pooled hierarchical model in the eight schools example. However, I am training the model to predict different trajectories of a feature that come from several different subjects.

In my code, my feature (e.g., time) is in X, whose shape is (100, 50), where 100 is the number of time samples and 50 is the number of subjects. The target output is in y. This code works when every subject has the same number of time samples. However, if the number of time samples differs between subjects, how should I set up the pyro.plate?

I could probably zero-pad X, but unless I specify the exact length of each time series, it might learn an incorrect model.

Can someone suggest a solution to this problem?

# Time grid shared by every subject: 100 samples evenly spaced on [0, 10].
kk = np.linspace(0, 10, num=100)
# Replicate the grid column-wise so that samp_t[:, i] is the full grid for
# subject i. (np.repeat followed by an order='F' reshape does not do this: it
# fills each column with only two distinct time values, 50 copies of each.)
samp_t = np.tile(kk[:, None], (1, 50))

# Each subject follows one of two linear trajectories (slope 0.2 or slope 5),
# selected by a fair coin flip.
samp_traj = np.zeros((100, 50))
for i in range(50):
    traj_type = bool(np.random.rand() > .5)
    if traj_type:
        samp_traj[:, i] = 0.2 * samp_t[:, i]
    else:
        samp_traj[:, i] = 5 * samp_t[:, i]

# Float32 tensors of shape (100, 50): rows are time samples, columns subjects.
X = torch.tensor(samp_t, dtype=torch.float)
y = torch.tensor(samp_traj, dtype=torch.float)
# %% Training

@config_enumerate
def model(X, y):
    """Partially pooled linear model with one slope per trajectory.

    ``mu`` and ``sigma`` are population-level latents. Inside the ``trajs``
    plate (sized by the second axis of ``y``) every trajectory draws its own
    slope ``k`` and noise scale ``bigsigma``. Inside the nested ``data`` plate
    (sized by the first axis of ``y``) ``y`` is observed as
    ``Normal(k * X, bigsigma)``.
    """
    # Population-level latents shared by all trajectories.
    mu = pyro.sample("mu", dist.Normal(0., 5.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))

    num_steps, num_trajs = y.shape[0], y.shape[1]
    with pyro.plate('trajs', num_trajs):
        # Per-trajectory slope and observation noise.
        k = pyro.sample("k", dist.Normal(mu, sigma))
        bigsigma = pyro.sample("bigsigma", dist.Uniform(0., 10.))
        # k broadcasts against the trailing (trajectory) axis of X.
        predicted = k * X
        with pyro.plate("data", num_steps):
            likelihood = dist.Normal(predicted, bigsigma)
            pyro.sample("obs", likelihood, obs=y)


# %% Using SVI
pyro.set_rng_seed(1524)

# The autoguide is built over the four latent sites of the model.
latent_sites = ['mu', 'sigma', 'k', 'bigsigma']
guide = AutoDiagonalNormal(poutine.block(model, expose=latent_sites))

optimizer = optim.Adam({"lr": .01})
svi = SVI(model, guide, optimizer, loss=TraceEnum_ELBO())

pyro.clear_param_store()
num_iters = 2000
losses = []
j = -1
# Each step returns the loss estimate for that iteration; keep it and echo it.
for i in range(num_iters):
    elbo = svi.step(X, y)
    losses.append(elbo)
    print(elbo)

EDIT: I have tried to change my model like so:

@config_enumerate
def model(X, y, series_lengths):
    """Partially pooled linear model for trajectories of differing length.

    Trajectories are handled one at a time with a sequential ``pyro.plate``.
    For trajectory ``i`` only the first ``series_lengths[i]`` rows of
    ``X[:, i]`` and ``y[:, i]`` are used.

    The per-trajectory sample sites are named ``k_<i>``, ``bigsigma_<i>`` and
    ``obs_<i>``; a guide has to cover the ``k_<i>`` and ``bigsigma_<i>`` sites
    (not bare ``k`` / ``bigsigma``) for them to be fitted.

    Args:
        X: inputs, indexed as ``X[time, trajectory]``.
        y: targets with the same layout as ``X``.
        series_lengths: indexable of per-trajectory lengths.
    """
    mu = pyro.sample("mu", dist.Normal(0., 5.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    number_of_traj = y.size()[1]

    for i in pyro.plate('trajs', number_of_traj):
        # Plain locals instead of writing into pre-allocated (50,) / (50, 50)
        # buffers: those hard-coded sizes break for any X that does not have
        # exactly 50 rows and 50 columns, and the buffers were never read back.
        k_i = pyro.sample("k_{}".format(i), dist.Normal(mu, sigma))
        bigsigma_i = pyro.sample(
            "bigsigma_{}".format(i), dist.Uniform(0., 10.))
        # Plain int so it can be used both as a plate size and a slice bound.
        data_length = int(series_lengths[i])
        mean_i = k_i * X[:data_length, i]
        with pyro.plate("data_{}".format(i), data_length):
            pyro.sample("obs_{}".format(i), dist.Normal(mean_i, bigsigma_i),
                        obs=y[:data_length, i])


# %% Using SVI
pyro.set_rng_seed(1524)
# The sequential model names its per-trajectory latents "k_0", "k_1", ... and
# "bigsigma_0", "bigsigma_1", ...; an expose list of
# ['mu', 'sigma', 'k', 'bigsigma'] matches none of them, so poutine.block would
# hide those sites from the guide and they would never be fitted. Autoguides
# only build posteriors for latent (unobserved) sites, so the model can be
# passed in directly. (AutoDelta(model) is a drop-in alternative.)
guide = AutoDiagonalNormal(model)

svi = SVI(model,
          guide,
          optim.Adam({"lr": .01}),
          loss=TraceEnum_ELBO())
# One random length in [10, 50) for each of the 50 trajectories.
series_lengths = np.random.randint(10, 50, 50)
pyro.clear_param_store()
num_iters = 200
losses = []
j = -1
for i in range(num_iters):
    elbo = svi.step(X, y, series_lengths)
    losses.append(elbo)
    print(elbo)

where I pass in the lengths of the different trajectories using an argument called series_lengths. But it never converges. I am sure that what I have done is wrong; could someone suggest a solution, please? Should I be using pyro.markov instead of the outermost pyro.plate?

Hi @grishabhg,
my usual approach to handle ragged arrays is to zero-pad a single big tensor and use poutine.mask to include only the real observations. I would recommend against the sequential for i in pyro.plate because it is much slower than the vectorized with pyro.plate version. I think something like this should work:

def model(X, y, series_lengths):
    """Vectorised partially pooled linear model for end-padded ragged data.

    ``X`` and ``y`` are indexed as ``[time, trajectory]`` and are assumed to be
    padded at the end of the time axis. Entry ``[t, i]`` contributes to the
    likelihood only when ``t < series_lengths[i]``.

    Args:
        X: padded inputs.
        y: padded targets with the same layout as ``X``.
        series_lengths: per-trajectory lengths; a tensor or anything
            ``torch.as_tensor`` accepts (e.g. a NumPy integer array).
    """
    # Normalise to a tensor so .max() and the mask comparison below behave the
    # same for NumPy and torch inputs.
    series_lengths = torch.as_tensor(series_lengths)
    T = int(series_lengths.max().item())

    mu = pyro.sample("mu", dist.Normal(0., 5.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    number_of_traj = y.size()[1]
    with pyro.plate('trajs', number_of_traj):  # traj dim is -1
        k = pyro.sample("k", dist.Normal(mu, sigma))
        # Rows at index >= T are masked out for every trajectory, so drop them:
        # the "data" plate has size T and the site's shape has to agree with it.
        mean = k * X[:T]
        bigsigma = pyro.sample("bigsigma", dist.Uniform(0., 10.))
        # Column vector of time indices, since time dim is -2.
        t = torch.arange(T, device=series_lengths.device).unsqueeze(-1)
        with pyro.plate("data", T):  # time dim is -2
            with poutine.mask(mask=t < series_lengths):
                pyro.sample("obs", dist.Normal(mean, bigsigma), obs=y[:T])

The (t < series_lengths) assumes zero padding at the end. For some forecasting applications I instead need to zero pad at the beginning so I need to use (t >= T - series_lengths).

Hey @fritzo, thanks for replying to my query. It works like a charm and much much faster. Thanks a lot.

1 Like

I have been trying to do something similar in numpyro, so far without success.
I wanted to fit a gaussian process to multiple trajectories with shared kernel parameters.

Below is my script.
When ragged=False is passed to fit then it runs without error. When ragged=True I get an Incompatible shapes error.

How do I go about this?

import jax.numpy as jnp
from jax.lax import scan, dynamic_slice, slice
import jax.random as random
from jax.experimental import stax

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.infer import MCMC, NUTS, Predictive
import numpy as onp


def get_data():
    """Build a toy set of N=50 trajectories, each with M=100 time points.

    Returns:
        tuple: ``(X, y, lengths)`` where ``X`` has shape ``(N, M)`` and every
        row equals ``0.2 * linspace(0, 10, M)``, ``y`` is ``sin(X)``, and
        ``lengths`` holds one random integer in ``[2, M)`` per trajectory.
    """
    N = 50
    M = 100
    kk = onp.linspace(0, 10, num=M)
    # One full copy of the grid per trajectory, so samp_t[:, i] == kk.
    # (onp.repeat followed by an order='F' reshape instead fills each column
    # with just two distinct grid values, N copies of each.)
    samp_t = onp.tile(kk[:, None], (1, N))

    samp_traj = 0.2 * samp_t
    lengths = onp.random.randint(2, M, size=N)
    # Transpose so the leading axis indexes trajectories.
    return samp_traj.T, onp.sin(samp_traj.T), lengths


def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    """Batched squared-exponential kernel.

    ``X`` and ``Z`` are indexed as ``[batch, point]``; the result is indexed as
    ``[batch, point_of_X, point_of_Z]``. With ``include_noise`` set,
    ``noise + jitter`` is added on the diagonal of every batch entry, using an
    identity matrix sized by ``X.shape[1]``.
    """
    # Pairwise differences within each batch entry, scaled by the length scale.
    scaled_diff = (X[:, :, None] - Z[:, None, :]) / length
    gram = var * jnp.exp(-0.5 * jnp.power(scaled_diff, 2.0))
    if not include_noise:
        return gram
    diagonal = jnp.eye(X.shape[1])[None, :, :]
    return gram + (noise + jitter) * diagonal

def fit(X, y, lens, ragged=False):
    """Run NUTS on ``model`` and return the ``MCMC`` object after sampling.

    Uses a fixed PRNG seed (0), 1000 warm-up and 1000 sampling iterations, and
    4 chains. ``X``, ``y``, ``lens`` and ``ragged`` are forwarded to the model
    as keyword arguments.
    """
    # Split the seed key once and use the second half for the sampler.
    _, run_key = random.split(random.PRNGKey(0))

    sampler = MCMC(NUTS(model), num_samples=1000, num_warmup=1000, num_chains=4)
    sampler.run(run_key, X=X, y=y, lens=lens, ragged=ragged)
    return sampler


def model(X, y, lens, ragged=False):
    """Zero-mean GP per trajectory with kernel parameters shared by all.

    Args:
        X: inputs, indexed as ``[trajectory, time]``.
        y: observations with the same layout as ``X``.
        lens: per-trajectory number of real (non-padded) leading time points;
            only read when ``ragged`` is true.
        ragged: when true, positions ``>= lens[i]`` of trajectory ``i`` are
            treated as padding.
    """
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))
    num_trajs = y.shape[0]
    with numpyro.plate('trajs', num_trajs):
        K = kernel(X, X, var, length, noise)
        if ragged:
            # A MultivariateNormal produces a single log-density per
            # trajectory, so an element-wise mask cannot remove individual
            # time points. Instead, turn every padded point into an
            # independent N(0, 1) variable observed at 0: the covariance
            # becomes block-diagonal, the real points keep their GP density,
            # and each padded point only adds a constant that does not depend
            # on the kernel parameters.
            T = X.shape[1]
            valid = jnp.arange(T)[None, :] < jnp.asarray(lens)[:, None]
            both_valid = valid[:, :, None] & valid[:, None, :]
            K = jnp.where(both_valid, K, jnp.eye(T)[None, :, :])
            y = jnp.where(valid, y, 0.0)
        numpyro.sample(
            "obs",
            dist.MultivariateNormal(loc=jnp.zeros(X.shape), covariance_matrix=K),
            obs=y,
        )

if __name__ == '__main__':
    # Generate the toy data and fit using the non-ragged likelihood branch.
    X, y, lens = get_data()
    mcmc = fit(X, y, lens, ragged=False)

Hi @JMellor, I was thinking of doing something similar. Were you able to find a solution to your problem?

@fritzo can we do something like this with GPRegression module ?
Thanks

@JMellor I think with ragged=True, you will want to replace MultivariateNormal with Normal. MVN’s log probability of a vector is a scalar, hence the mask will not work correctly. In addition, inside plate(size=T), MVN(np.zeros(T)).sample() will return a matrix with shape T x T, which, I think, is not what you want. plate(size=T) + Normal(np.zeros(T)) will return a vector with shape T. Finally, MVN is not smart enough to marginalize out masked variables, i.e. if z = (x, y) ~ MVN, then we can’t derive MVN.log_prob(y) from MVN.log_prob(z) using the mask handler.

1 Like

Hi @fehiepsi, thanks for your answer. Would it be possible to do something similar with GPRegression module in Pyro? i.e., if i wanted to build gp regression for trajectories of varying lengths with hierarchical structure of parameters how would I go about it? I wanted a model for which I would get a population mean value for parameters as well as each trajectory would have its own parameter value. Could you give some pointers.

Thanks

@grishabhg GPR does not support batching so I doubt that you can do it. Assuming that batch GPR is supported, I also don’t know how to deal with trajectories of varying lengths. :frowning: (but maybe there are some references showing how to tackle that issue).

Hmm, that's a shame. Maybe something like GPyTorch could be used. I think it supports batching, but probably not time series of varying length.

Yes, GPyTorch supports batching. Basically, you just need to add an ellipsis for the algebraic operators. Or you can write code similar to the above comment by @JMellor, with MVN replaced by Normal. It should work.

1 Like

Right! Thanks @fehiepsi :slight_smile:

Btw, I reopened the batch GP issue https://github.com/pyro-ppl/pyro/issues/1679. IMO, it is a good exercise for anyone interested in the GP module. :slight_smile:

Thank you for the replies.

@fehiepsi unfortunately I’m not sure changing to a Normal would retain the characteristics of the model I wanted. But I think I understand that what I want to do is not possible using the MVN interface. I didn’t follow why batch was a necessary condition (edit: actually I think I have understood the requirement for this particular approach now).

@grishabhg I haven’t got round to another solution yet. This snippet was an attempt to scale up the following code, which runs for very small toy datasets.

I wondered if I might be better trying to implement a version of MVN that worked on block diagonal matrices so each block could be a given trajectory. I had also considered instead using something like numpyro.infer.hmc_util.consensus with each trajectory as a submodel (although it wasn’t clear to me how to use that function).

import pandas as pd
import jax.numpy as jnp
import jax.random as random
from jax.experimental import stax

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.infer import MCMC, NUTS
import numpy as onp


def get_data():
    """Return a toy long-format frame of 50 noiseless sine observations.

    Columns: ``time`` (uniform draws in ``[0, 3)``), ``num.value``
    (``sin(time)``) and ``serialno`` (series label ``0..9``, assigned
    round-robin by row position).
    """
    N = 50
    times = 3. * onp.random.random(size=N)
    values = onp.sin(times)
    # Row r belongs to series r mod 10.
    series = onp.array([row % 10 for row in range(N)])

    return pd.DataFrame({'time': times, 'num.value': values, 'serialno': series})


def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    """Squared-exponential kernel between two 1-D sets of points.

    Returns a matrix indexed as ``[point_of_X, point_of_Z]``. With
    ``include_noise`` set, ``noise + jitter`` is added on the diagonal, using
    an identity matrix sized by ``X.shape[0]``.
    """
    # Pairwise differences, scaled by the length scale.
    scaled_diff = (X[:, None] - Z[None, :]) / length
    gram = var * jnp.exp(-0.5 * jnp.power(scaled_diff, 2.0))
    if not include_noise:
        return gram
    return gram + (noise + jitter) * jnp.eye(X.shape[0])


def model(X, y, ind):
    """Independent zero-mean GPs, one per series, with shared kernel parameters.

    ``ind`` labels the series each element of ``X`` / ``y`` belongs to. For
    every distinct label ``i`` the matching elements are observed under a
    ``MultivariateNormal`` at a site named ``y<i>``.
    """
    # Kernel hyper-parameters shared by every series.
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))

    for series_id in onp.unique(ind):
        # Boolean selector for the rows of this series.
        members = ind == series_id
        X_group, y_group = X[members], y[members]
        cov = kernel(X_group, X_group, var, length, noise)
        likelihood = dist.MultivariateNormal(
            loc=jnp.zeros(X_group.shape[0]), covariance_matrix=cov)
        numpyro.sample(f"y{series_id}", likelihood, obs=y_group)


def fit(data):
    """Run NUTS on ``model`` for the given data frame and return the ``MCMC``.

    Args:
        data: frame with ``time``, ``num.value`` and ``serialno`` columns,
            forwarded to the model as ``X``, ``y`` and ``ind``.

    Returns:
        The ``MCMC`` object after sampling (1000 warm-up, 1000 samples,
        4 chains, PRNG seed 0).
    """
    # Both helpers are documented to take effect only at the beginning of the
    # program, so they must run before the first JAX call (random.PRNGKey
    # below) rather than after it.
    # NOTE(review): set_host_device_count configures host (CPU) devices;
    # confirm it is still wanted alongside the 'gpu' platform.
    numpyro.set_platform('gpu')
    numpyro.set_host_device_count(4)

    rng_key = random.PRNGKey(0)
    rng_key, rng_key_ = random.split(rng_key)

    nuts = NUTS(model)
    mcmc = MCMC(nuts, num_samples=1000, num_warmup=1000, num_chains=4)
    mcmc.run(rng_key_,
             X=data['time'].values,
             y=data['num.value'].values,
             ind=data['serialno'].values
             )
    return mcmc


# Build the toy data frame and run NUTS on it.
data = get_data()
mcmc = fit(data)