Hierarchical Bayesian Model with data of varying length

Hi @grishabhg,
my usual approach to handling ragged arrays is to zero-pad them into a single big tensor and use poutine.mask so that only the real observations contribute to the likelihood. I would recommend against the sequential "for i in pyro.plate(...)" loop because it is much slower than the vectorized "with pyro.plate(...)" version. I think something like this should work:

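import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model(X, y, series_lengths):
    # y: zero-padded, time-major tensor of shape (T, number_of_traj); X must broadcast against it
    # series_lengths: integer tensor of shape (number_of_traj,)
    mu = pyro.sample("mu", dist.Normal(0., 5.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
    number_of_traj = y.size(1)
    with pyro.plate('trajs', number_of_traj):  # traj dim is -1
        k = pyro.sample("k", dist.Normal(mu, sigma))
        mean = k * X
        bigsigma = pyro.sample("bigsigma", dist.Uniform(0., 10.))
        T = series_lengths.max().item()  # assumes y is padded to exactly the longest series
        t = torch.arange(T).unsqueeze(-1)  # shape (T, 1), since time dim is -2
        with pyro.plate("data", T):  # time dim is -2
            # mask is True for real observations and False for padding
            with poutine.mask(mask=t < series_lengths):
                pyro.sample("obs", dist.Normal(mean, bigsigma), obs=y)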
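In case it helps, here is a rough sketch of how the padded inputs could be built from a list of ragged series. The xs and ys values are made up for illustration, and using torch.nn.utils.rnn.pad_sequence is just one option I am assuming here, not something the model requires:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical ragged data: three trajectories of lengths 3, 1 and 2.
xs = [torch.tensor([0., 1., 2.]), torch.tensor([0.]), torch.tensor([0., 1.])]
ys = [torch.tensor([0.5, 2.0, 4.5]), torch.tensor([0.25]), torch.tensor([-0.5, 1.0])]

# One length per trajectory, shape (number_of_traj,).
series_lengths = torch.tensor([len(s) for s in ys])

# With batch_first=False (the default) the result is time-major,
# i.e. shape (T, number_of_traj), with zeros appended after each series.
X = pad_sequence(xs)
y = pad_sequence(ys)
```

These X, y and series_lengths have the shapes that model(X, y, series_lengths) expects, with the padding at the end of each column.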

The mask (t < series_lengths) assumes zero padding at the end of each series. For some forecasting applications I instead need to zero-pad at the beginning, in which case the mask becomes (t >= T - series_lengths).
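To make the two masks concrete, here is a small sketch with made-up lengths (3, 1 and 2) that uses only plain torch broadcasting:

```python
import torch

# Hypothetical lengths for three trajectories.
series_lengths = torch.tensor([3, 1, 2])
T = int(series_lengths.max())
t = torch.arange(T).unsqueeze(-1)  # shape (T, 1), broadcasts against (number_of_traj,)

# Padding at the end: the first series_lengths[j] rows of column j are real.
mask_end = t < series_lengths

# Padding at the beginning: the last series_lengths[j] rows of column j are real.
mask_start = t >= T - series_lengths

print(mask_end)
print(mask_start)
```

Both masks have shape (T, number_of_traj), and in either case the number of True entries in column j equals series_lengths[j].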
