Hierarchical Bayesian Model with data of varying length

Hi, I am working with a hierarchical Bayesian model that looks like the partially pooled hierarchical model in the eight schools example. However, I am training the model to predict different trajectories of a feature coming from several different subjects.

In my code, my feature (e.g., time) is in X, whose shape is (100, 50), where 100 is the number of time samples and 50 corresponds to 50 different subjects. y holds the target output. This code works when the number of time samples is the same for every subject. However, if I have a different number of time samples per subject, how would I set up the pyro.plate?

I could probably zero-pad X, but unless I tell the model the exact length of each time series, it might learn an incorrect model.

Can someone suggest a solution to this problem?

import numpy as np
import torch

kk = np.linspace(0, 10, num=100)
samp_t = np.repeat(kk, 50).reshape((100, 50), order='F')

# Each trajectory gets one of two slopes at random
samp_traj = np.zeros((100, 50))
for i in range(50):
    traj_type = np.random.rand() > .5
    if traj_type:
        samp_traj[:, i] = 0.2 * samp_t[:, i]
    else:
        samp_traj[:, i] = 5 * samp_t[:, i]

X = torch.tensor(samp_t[:, :50], dtype=torch.float)
y = torch.tensor(samp_traj[:, :50], dtype=torch.float)
# %% Training
import pyro
import pyro.distributions as dist
import pyro.optim as optim
import pyro.poutine as poutine
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.infer.autoguide import AutoDiagonalNormal


@config_enumerate
def model(X, y):
    mu = pyro.sample("mu", dist.Normal(0., 5.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    number_of_traj = y.size()[1]
    with pyro.plate('trajs', number_of_traj):
        k = pyro.sample("k", dist.Normal(mu, sigma))
        mean = k * X
        bigsigma = pyro.sample("bigsigma", dist.Uniform(0., 10.))
        data_length = y.size()[0]
        with pyro.plate("data", data_length):

            pyro.sample("obs", dist.Normal(mean, bigsigma), obs=y)


# %% Using SVI
pyro.set_rng_seed(1524)
guide = AutoDiagonalNormal(poutine.block(model, expose=[
                           'mu', 'sigma', 'k', 'bigsigma']))  
svi = SVI(model,
          guide,
          optim.Adam({"lr": .01}),
          loss=TraceEnum_ELBO())
pyro.clear_param_store()
num_iters = 2000
losses = []
j = -1
for i in range(num_iters):
    elbo = svi.step(X, y)
    losses.append(elbo)
    print(elbo)

EDIT: I have tried to change my model like so:

@config_enumerate
def model(X, y, series_lengths):
    mu = pyro.sample("mu", dist.Normal(0., 5.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    number_of_traj = y.size()[1]
    k = torch.zeros(50)
    bigsigma = torch.zeros(50)
    mean = torch.zeros(50, 50)

    for i in pyro.plate('trajs', number_of_traj):
        k[i] = pyro.sample("k_{}".format(i), dist.Normal(mu, sigma))
        mean[:, i] = k[i] * X[:, i]
        bigsigma[i] = pyro.sample(
            "bigsigma_{}".format(i), dist.Uniform(0., 10.))
        data_length = series_lengths[i]
        with pyro.plate("data_{}".format(i), data_length):

            pyro.sample("obs_{}".format(i), dist.Normal(
                mean[:data_length, i], bigsigma[i]), obs=y[:data_length, i])


# %% Using SVI
pyro.set_rng_seed(1524)
guide = AutoDiagonalNormal(poutine.block(model, expose=[
                           'mu', 'sigma', 'k', 'bigsigma']))  # AutoDelta(model)

svi = SVI(model,
          guide,
          optim.Adam({"lr": .01}),
          loss=TraceEnum_ELBO())
series_lengths = np.random.randint(10, 50, 50)
pyro.clear_param_store()
num_iters = 200
losses = []
j = -1
for i in range(num_iters):
    elbo = svi.step(X, y, series_lengths)
    losses.append(elbo)
    print(elbo)

where I pass in the lengths of the different trajectories through an argument called series_lengths. But it never converges. I am sure that what I have done is wrong; could someone suggest a solution, please? Should I be using pyro.markov instead of the outermost pyro.plate?

Hi @grishabhg,
my usual approach to handling ragged arrays is to zero-pad a single big tensor and use poutine.mask to include only the real observations. I would recommend against the sequential for i in pyro.plate(...) form because it is much slower than the vectorized with pyro.plate(...) version. I think something like this should work:

def model(X, y, series_lengths):
    mu = pyro.sample("mu", dist.Normal(0., 5.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    number_of_traj = y.size()[1]
    with pyro.plate('trajs', number_of_traj):  # traj dim is -1
        k = pyro.sample("k", dist.Normal(mu, sigma))
        mean = k * X
        bigsigma = pyro.sample("bigsigma", dist.Uniform(0., 10.))
        T = series_lengths.max().item()
        t = torch.arange(T).unsqueeze(-1)  # since time dim is -2
        with pyro.plate("data", T):  # time dim is -2
            with poutine.mask(mask=t < series_lengths):
                pyro.sample("obs", dist.Normal(mean, bigsigma), obs=y)

The (t < series_lengths) assumes zero padding at the end. For some forecasting applications I instead need to zero pad at the beginning so I need to use (t >= T - series_lengths).
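
As a small sketch of the two conventions (assuming series_lengths is a tensor of per-trajectory lengths):

T = int(series_lengths.max())
t = torch.arange(T).unsqueeze(-1)           # shape (T, 1), since the time dim is -2
mask_end_pad = t < series_lengths           # real data first, zero padding at the end
mask_front_pad = t >= T - series_lengths    # zero padding first, real data at the end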

Hey @fritzo, thanks for replying to my query. It works like a charm and is much, much faster. Thanks a lot.

I have been trying to do something similar in numpyro, so far without success.
I wanted to fit a Gaussian process to multiple trajectories with shared kernel parameters.

Below is my script.
When ragged=False is passed to fit, it runs without error. When ragged=True, I get an "Incompatible shapes" error.

How do I go about this?

import jax.numpy as jnp
from jax.lax import scan, dynamic_slice, slice
import jax.random as random
from jax.experimental import stax

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.infer import MCMC, NUTS, Predictive
import numpy as onp


def get_data():
    N=50
    M=100
    kk = onp.linspace(0, 10, num=M)
    samp_t = onp.repeat(kk,N).reshape((M,N),order='F')

    samp_traj = onp.zeros((M,N))
    for i in range(N):
        samp_traj[:,i] = 0.2 * samp_t[:,i]
    lengths = onp.random.randint(2, M, size=N)
    return samp_traj.T, onp.sin(samp_traj.T), lengths


def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    deltaXsq = jnp.power((X[:, :, None] - Z[:, None, :]) / length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    if include_noise:
        k += (noise + jitter) * jnp.eye(X.shape[1])[None, :, :]
    return k

def fit(X, y, lens, ragged=False):
    rng_key = random.PRNGKey(0)
    rng_key, rng_key_ = random.split(rng_key)

    nuts = NUTS(model)
    mcmc = MCMC(nuts, num_samples=1000, num_warmup=1000, num_chains=4)
    mcmc.run(rng_key_,
             X=X,
             y=y,
             lens=lens,
             ragged=ragged
             )
    return mcmc


def model(X, y, lens, ragged=False):
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))
    num_trajs = y.shape[0]
    with numpyro.plate('trajs', num_trajs):
        K = kernel(X, X, var, length, noise)
        if not ragged:
            numpyro.sample("obs", dist.MultivariateNormal(loc=jnp.zeros(X.shape), covariance_matrix=K), obs=y)
        else:
            T = X.shape[1]
            t = jnp.tile(jnp.arange(T), X.shape[0]).reshape(X.shape)
            with numpyro.plate('data', T):
                with numpyro.handlers.mask(mask_array=t<lens[:, None]):
                    numpyro.sample("obs", dist.MultivariateNormal(loc=jnp.zeros(X.shape), covariance_matrix=K), obs=y)

if __name__ == '__main__':
    X, y, lens = get_data()
    mcmc = fit(X, y, lens, ragged=False)

Hi @JMellor, I was thinking of doing something similar. Were you able to find a solution to your problem?

@fritzo can we do something like this with the GPRegression module?
Thanks

@JMellor I think with ragged=True, you will want to replace MultivariateNormal with Normal. MVN’s log probability of a vector is a scalar, hence the mask will not work correctly. In addition, inside plate(size=T), MVN(np.zeros(T)).sample() will return a matrix with shape T x T, which, I think, is not what you want; plate(size=T) + Normal(np.zeros(T)) will return a vector with shape T. Finally, MVN is not smart enough to marginalize out masked variables, i.e. if z = (x, y) ~ MVN, then we can’t derive MVN.log_prob(y) from MVN.log_prob(z) using the mask handler.
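
For concreteness, a minimal sketch of that elementwise version, adapting @JMellor’s ragged branch (untested, and note that a plain Normal likelihood treats time points as independent given the mean, which is the trade-off raised further down the thread):

def ragged_model(X, y, lens):
    # Shared parameters as before; only `noise` enters the elementwise likelihood
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    num_trajs, T = y.shape
    t = jnp.arange(T)
    with numpyro.plate("trajs", num_trajs, dim=-2):
        with numpyro.plate("data", T, dim=-1):
            # mask has shape (num_trajs, T), matching the plate dims (-2, -1)
            with numpyro.handlers.mask(mask=t < lens[:, None]):
                # treating `noise` as a variance, hence the sqrt
                numpyro.sample("obs", dist.Normal(0.0, jnp.sqrt(noise)), obs=y)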

Hi @fehiepsi, thanks for your answer. Would it be possible to do something similar with the GPRegression module in Pyro? I.e., if I wanted to build a GP regression for trajectories of varying lengths with a hierarchical structure over the parameters, how would I go about it? I want a model that gives a population-level mean for the parameters while each trajectory also gets its own parameter value. Could you give some pointers?

Thanks

@grishabhg GPR does not support batching, so I doubt that you can do it. Even if batch GPR were supported, I also don’t know how to deal with trajectories of varying lengths. :frowning: (but maybe there are some references showing how to tackle that issue).

Hmm, that’s a shame. Maybe something like GPyTorch can be used. I think it supports batching, but probably not varying-length time series.

Yes, GPyTorch supports batching. Basically, you just need to add ellipses to the algebraic operations. Or you can write code similar to the above comment by @JMellor, with MVN replaced by Normal. It should work.
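
For example, a batched variant of the toy RBF kernel above, written with ellipses so the same algebra works whether or not X has a leading batch dimension (just a sketch, not the GP module’s actual code):

def batched_kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    # Ellipses let X and Z carry arbitrary leading batch dims, e.g. (N,) or (B, N)
    deltaXsq = jnp.power((X[..., :, None] - Z[..., None, :]) / length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    if include_noise:
        k += (noise + jitter) * jnp.eye(X.shape[-1])
    return k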

Right! Thanks @fehiepsi :slight_smile:

Btw, I reopened the batch GP issue: Batching Gaussian Processes · Issue #1679 · pyro-ppl/pyro · GitHub. IMO, it is a good exercise for anyone interested in the GP module. :slight_smile:

Thank you for the replies.

@fehiepsi unfortunately I’m not sure changing to a Normal would retain the characteristics of the model I wanted. But I think I understand that what I want to do is not possible using the MVN interface. I didn’t follow why batching was a necessary condition (edit: actually I think I have understood the requirement for this particular approach now).

@grishabhg I haven’t got round to another solution yet. This snippet was an attempt to scale up the following code, which runs for very small toy datasets.

I wondered if I might be better off trying to implement a version of MVN that works on block-diagonal matrices, so that each block could be a given trajectory (a rough sketch of this is included after the snippet below). I had also considered instead using something like numpyro.infer.hmc_util.consensus with each trajectory as a submodel (although it wasn’t clear to me how to use that function).

import pandas as pd
import jax.numpy as jnp
import jax.random as random
from jax.experimental import stax

import numpyro
import numpyro.distributions as dist
import numpyro.optim as optim
from numpyro.infer import MCMC, NUTS
import numpy as onp


def get_data():
    N = 50
    X = 3.*onp.random.random(size=N)
    y = onp.sin(X)
    ind = onp.array([c % 10 for c, i in enumerate(range(N))])

    return pd.DataFrame({'time':X, 'num.value':y, 'serialno':ind})


def kernel(X, Z, var, length, noise, jitter=1.0e-6, include_noise=True):
    deltaXsq = jnp.power((X[:, None] - Z[None, :]) / length, 2.0)
    k = var * jnp.exp(-0.5 * deltaXsq)
    if include_noise:
        k += (noise + jitter) * jnp.eye(X.shape[0])
    return k


def model(X, y, ind):
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))
    for i in onp.unique(ind):
        inds = ind == i
        X_ = X[inds]
        y_ = y[inds]
        K = kernel(X_, X_, var, length, noise)
        numpyro.sample(f"y{i}", dist.MultivariateNormal(loc=jnp.zeros(X_.shape[0]), covariance_matrix=K), obs=y_)


def fit(data):
    rng_key = random.PRNGKey(0)
    rng_key, rng_key_ = random.split(rng_key)

    numpyro.set_platform('gpu')
    numpyro.set_host_device_count(4)
   
    nuts = NUTS(model)
    mcmc = MCMC(nuts, num_samples=1000, num_warmup=1000, num_chains=4)
    mcmc.run(rng_key_,
             X=data['time'].values,
             y=data['num.value'].values,
             ind=data['serialno'].values
             )
    return mcmc


data = get_data()
mcmc = fit(data)
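
Concretely, the block-diagonal idea might look something like the sketch below, reusing the kernel function above and assuming the per-trajectory inputs and observations have been pre-split into Python lists (untested):

from jax.scipy.linalg import block_diag

def block_diag_model(X_groups, y_groups):
    # X_groups / y_groups: lists with one 1-D array per trajectory,
    # so trajectories may have different (static) lengths
    var = numpyro.sample("kernel_var", dist.LogNormal(0.0, 10.0))
    noise = numpyro.sample("kernel_noise", dist.LogNormal(0.0, 10.0))
    length = numpyro.sample("kernel_length", dist.LogNormal(0.0, 10.0))
    # One kernel block per trajectory, assembled into one big covariance;
    # a block-diagonal MVN factorizes into the per-trajectory MVNs above
    K = block_diag(*[kernel(X_, X_, var, length, noise) for X_ in X_groups])
    y_all = jnp.concatenate(y_groups)
    numpyro.sample("y", dist.MultivariateNormal(loc=jnp.zeros(y_all.shape[0]),
                                                covariance_matrix=K), obs=y_all)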

Sorry to revive this old thread, but I am having trouble implementing the suggested solution to ragged plates. I’m new to the numpyro/pyro universe, so I’m sure I’m making an elementary mistake and I’d appreciate it if anyone could bring it to my attention!

My goal is similar to the OP’s, and I’ve tried to provide a minimal reproducible example. My model assumes that two samples of sizes n1 and n2 are drawn from two different beta distributions.

Note, I assume only two samples for this example, but in application this number will be much larger, which is why I’m interested in the speed and scaling capacity promised by numpyro!

Adapting the mask method described above, I have the following model (with comments identifying what I think is happening in each component).

import time

import numpy as np
import torch
from scipy.stats import beta
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def model(sample_sizes, y = None):
    
    # Define the number of groups from which we've attained samples
    n_groups =  sample_sizes.shape[0]
    
    # Initialize a plate for each group
    with numpyro.plate("n_groups", n_groups):
        
        # Set priors on a and b within each group
        a = numpyro.sample('a', dist.Gamma(.5, .5))
        b =  numpyro.sample('b', dist.Gamma(3, 3))
        
        # Calculate maximum group size
        I = sample_sizes.max().item()
        # Create a range of 0:I
        i = torch.arange(I).unsqueeze(-1).numpy()
        
        # Initialize a plate for each observation
        with numpyro.plate('data', I):
            # Mask observations that exceed sample size in the respective plate
            with numpyro.handlers.mask(mask = i < sample_sizes): 
                # Estimate y_hat
                y_hat = numpyro.sample('y_hat', dist.Beta(a, b), obs=y)

Then I simulate data as follows:

# Generate Data

# specify number of observations for each group
n1 = 400
n2 = 350

# Sample observations from different beta distributions for each group, given specified sample size
y1 = beta.rvs(.5, .5, size=n1)
y2 = beta.rvs(3, 1, size=n2)

nan_pads = np.empty(n1 - n2) * np.nan
y2 = np.concatenate([y2, nan_pads])
# Combine samples
y = np.array([y1, y2], order='K').T

# Count nans to verify the padding
nan_count = sum(np.isnan(y), 0)

assert nan_count[0] == 0
assert nan_count[1] == n1-n2

sample_sizes =  np.array([n1, n2])

And, finally, estimate the model.

rng_key = random.PRNGKey(0)
num_warmup, num_samples = 500, 2000

# Run NUTS.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=3)

start = time.time()
mcmc.run(rng_key, sample_sizes=sample_sizes, y=y)
end = time.time()
mcmc.print_summary()

This model is completely divergent:

                mean       std    median      5.0%     95.0%     n_eff     r_hat
      a[0]      2.03      0.39      1.89      1.63      2.56      1.50  33012.44
      a[1]      1.55      0.95      1.95      0.23      2.45      1.50  80650.39
      b[0]      0.93      0.91      0.41      0.18      2.21      1.50 199630.92
      b[1]      1.89      1.52      1.69      0.14      3.84      1.50 225342.38

Number of divergences: 6000

However, when I set n1=n2, the model converges on the expected parameters.

So, to recap: the model seems to work with equal group sizes, but when n1!=n2, something goes very wrong, though no error is thrown.

@joehoover88 I guess that although handlers.mask masks out the log_prob of those nan values, the gradient is still propagated incorrectly (see this jax faq). A solution is to replace those nan values with some valid value, e.g. 0.5 for a Beta distribution. Could you confirm?
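
For example, something along these lines before calling mcmc.run (assuming the data were padded with nan as above):

# Replace the nan padding with any valid value in (0, 1); the mask already
# removes its log-probability, this just keeps the gradients finite
y = np.where(np.isnan(y), 0.5, y)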

I have seen this issue several times. It would be nice to add a mechanism to automatically replace those invalid values with valid ones under the mask handler. Could you make a GitHub issue for this? I’ll discuss with the other devs to provide a solution. Thanks!

@fehiepsi, thank you so much for the quick response! I’ll open an issue today.

Your suggestion worked! I can also confirm that the specific value used to represent nan does not seem to have an effect on parameter estimates, which is consistent with your note about handlers (i.e. the masking seems to be working properly in my code).

For reproducibility, here’s my updated example with seeds to evaluate invariance over nan values.

import time

import numpy as np
import torch
from scipy.stats import beta
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

# Generate Data

# specify number of observations for each group
n1 = 2000
n2 = 1000

# Sample observations from different beta distributions for each group, given specified sample size
np.random.seed(seed=123)
y1 = beta.rvs(.5, .5, size=n1)
np.random.seed(seed=321)
y2 = beta.rvs(3, 1, size=n2)

pads = np.ones(n1 - n2) * .01
y2 = np.concatenate([y2, pads])

# Combine samples
y = np.array([y1, y2], order='K').T


sample_sizes =  np.array([n1, n2])

def model(sample_sizes, y = None):
    
    # Define the number of groups from which we've attained samples
    n_groups =  sample_sizes.shape[0]
    
    # Initialize a plate for each group
    with numpyro.plate("n_groups", n_groups):
        
        # Set priors on a and b within each group
        a = numpyro.sample('a', dist.Gamma(.5, .5))
        b =  numpyro.sample('b', dist.Gamma(3, 3))
        
        # Calculate maximum group size
        I = sample_sizes.max().item()
        # Create a range of 0:I
        i = torch.arange(I).unsqueeze(-1).numpy()
        
        # Initialize a plate for each observation
        with numpyro.plate('data', I):
            # Mask observations that exceed sample size in the respective plate
            with numpyro.handlers.mask(mask = i < sample_sizes): 
                # Estimate y_hat
                y_hat = numpyro.sample('y_hat', dist.Beta(a, b), obs=y)

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

num_warmup, num_samples = 500, 2000

# Run NUTS.
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup=num_warmup, num_samples=num_samples, num_chains=3)

start = time.time()
mcmc.run(rng_key_, sample_sizes=sample_sizes, y = y)
end = time.time()
mcmc.print_summary()