Vectorizing sequential Pyro plates for hierarchical models with observations containing multiple features

Hello!

I’m trying to build a hierarchical linear regression model in Pyro. I’ve managed to get it working with sequential plates, but I’d like to use a vectorized implementation. My current model structure is:

import numpy as np
import pyro
import pyro.distributions as dist
import torch


def linear_regression_model(x_normal: torch.Tensor,
                            x_half_normal: torch.Tensor,
                            y: torch.Tensor = None,
                            group_labels: np.ndarray = None) -> None:

    # Global hyperpriors.
    global_mu = pyro.sample(
        'global_mu',
        dist.Normal(0.0, 0.1).expand([x_normal.shape[1]]).to_event(1))
    global_sigma = pyro.sample(
        'global_sigma',
        dist.HalfNormal(scale=0.1).expand([x_normal.shape[1]]).to_event(1))

    num_groups = len(np.unique(group_labels))
    for group in range(num_groups):
        # Indices for each group within the training data.
        # np.where returns a tuple of arrays, so take its first element.
        ind = np.where(group_labels == group)[0]

        # Normal distribution priors.
        weight = pyro.sample(
            f'weight_{group}',
            dist.Normal(global_mu, global_sigma).expand([x_normal.shape[1]]).to_event(1))
        weighted_sum_normal = torch.matmul(x_normal[ind, :], weight.unsqueeze(-1)).squeeze(-1)
        # Half-Normal distribution priors.
        weight_half_normal = pyro.sample(
            f'weight_half_normal_{group}',
            dist.HalfNormal(scale=0.1).expand([x_half_normal.shape[1]]).to_event(1))
        weighted_sum_half_normal = torch.matmul(x_half_normal[ind, :], weight_half_normal.unsqueeze(-1)).squeeze(-1)
        # Bias Term Prior.
        bias = pyro.sample(f'bias_{group}', dist.Normal(0.0, 0.5).expand([1]).to_event(1))
        # Observation Noise Prior.
        sigma = pyro.sample(f'sigma_{group}', dist.HalfNormal(scale=0.5).expand([1]).to_event(1))
        # Expected Mean: A linear combination of the input features, weight matrices, and bias vector.
        mean = weighted_sum_normal + weighted_sum_half_normal + bias

        # Condition on y when it is provided; otherwise sample predictions.
        pyro.sample(f'obs_{group}', dist.Normal(mean, sigma).to_event(1),
                    obs=y[ind] if y is not None else None)

And the input data is created as:

import pandas as pd
from sklearn.preprocessing import LabelEncoder

group_encoder = LabelEncoder()
train_df['group_labels'] = group_encoder.fit_transform(train_df['group_names'].values)
group_labels = train_df['group_labels'].values

x_normal  # Tensor of shape (num_obs, num_features_with_normal_priors)
x_half_normal  # Tensor of shape (num_obs, num_features_with_half_normal_priors)
y  # Tensor of shape (num_obs,)
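
For anyone who wants to reproduce this without my data, placeholder inputs with the same shapes can be generated roughly as below. The sizes and the way labels are assigned are made up for illustration, and group_idx is a name I'm introducing here for a long-tensor copy of the labels.

```python
import numpy as np
import torch

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
num_obs, num_normal, num_half_normal, num_groups = 12, 3, 2, 4

# Stand-in for the LabelEncoder output: one integer label per observation.
group_labels = np.arange(num_obs) % num_groups

x_normal = torch.randn(num_obs, num_normal)           # (num_obs, num_features_with_normal_priors)
x_half_normal = torch.rand(num_obs, num_half_normal)  # (num_obs, num_features_with_half_normal_priors)
y = torch.randn(num_obs)                              # (num_obs,)

# Long-tensor copy of the labels, usable for tensor indexing.
group_idx = torch.as_tensor(group_labels, dtype=torch.long)
```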

I’m not sure whether there is a way to use the plate index, e.g. with pyro.plate('group', num_groups) as ind, and then apply x_normal[ind, :] within each plate. I’ve also tried to “zero-pad a single big tensor and use poutine.mask to include only the real observations”, as in this example, but I can’t get it to work with multiple features per observation.

I’d greatly appreciate any feedback or recommendations. Thanks a lot!