Advanced plate usage question with non-tensor inputs

Hi everyone,

Thanks for the fantastic library.

I’m working with a model for hierarchical forecasting that looks like this:

def diag_model(X_train_list, X_test_list, y_train, target_cols, hierarchy_idx=None):
    """Per-series Bayesian linear regressions, softly reconciled across a hierarchy.

    For every target column ``i`` an independent regression is declared::

        beta_i  ~ Normal(0, 1)                      (one weight per covariate)
        sigma_i ~ HalfCauchy(1)
        y[:, i] ~ Normal(X_train_list[i] @ beta_i, sigma_i)

    Afterwards, for every ``(parent_idx, childrens_idx)`` pair in
    ``hierarchy_idx``, a zero-valued observation is placed on
    ``mu_parent - sum(mu_children)`` (computed on the test design matrices).
    This pulls the predictions towards coherence.

    Args:
        X_train_list: one design matrix per target column. Entries may have
            different numbers of columns (different covariates per series).
        X_test_list: test design matrices. Each must have the same number of
            columns as the matching entry of ``X_train_list`` (they share
            ``beta``).
        y_train: observed targets indexed as ``y_train[:, i]``, or ``None``
            to leave the ``train_obs_*`` sites unobserved.
        target_cols: names used to build unique sample-site names.
        hierarchy_idx: iterable of ``(parent_idx, childrens_idx)`` pairs of
            column indices into ``target_cols``. ``None`` (the default) or an
            empty iterable skips reconciliation.

    Returns:
        ``(mu_train, mu_test)``: the per-series means stacked along ``dim=1``.
    """
    mu_train = []
    mu_test = []
    sigma_full = []

    for i, col_name in enumerate(target_cols):
        y_obs = None if y_train is None else y_train[:, i]

        X_train = X_train_list[i]
        X_test = X_test_list[i]
        n = X_train.shape[1]

        beta = pyro.sample(f'beta_{col_name}', dist.Normal(loc=torch.zeros(n), scale=torch.ones(n)))
        sigma = pyro.sample(f'sigma_{col_name}', dist.HalfCauchy(scale=torch.ones(1)))

        _mu_train = X_train @ beta
        _mu_test = X_test @ beta

        pyro.sample(f"train_obs_{col_name}",
                    dist.Normal(loc=_mu_train, scale=sigma), obs=y_obs)

        mu_train.append(_mu_train)
        mu_test.append(_mu_test)
        sigma_full.append(sigma)

    mu_train = torch.stack(mu_train, dim=1)
    mu_test = torch.stack(mu_test, dim=1)
    # Normal(scale=sigma) has variance sigma**2, so the (diagonal) covariance
    # matrix must hold squared scales, not the raw sigmas.
    cov = torch.diag(torch.cat(sigma_full) ** 2)

    ################################################################################################
    # Reconciliation:
    ################################################################################################

    # `or ()` makes the documented default hierarchy_idx=None a no-op instead of a TypeError.
    for i, (parent_idx, childrens_idx) in enumerate(hierarchy_idx or ()):
        childrens_idx = list(childrens_idx)  # accept tuples as well as lists

        rec_diff = mu_test[:, parent_idx] - torch.sum(mu_test[:, childrens_idx], dim=1)

        subset_idx = [parent_idx, *childrens_idx]
        # Select rows AND columns; assigning the two selections separately would
        # throw away the row selection and leave a (K, m) slice.
        cov_subset = cov[subset_idx][:, subset_idx]

        # Var(parent - sum(children)) = w^T Cov w with w = (1, -1, ..., -1).
        weights = torch.cat([cov.new_ones(1), -cov.new_ones(len(childrens_idx))])
        var_rec = weights @ cov_subset @ weights
        sigma_rec = torch.sqrt(var_rec)

        pyro.sample(f"reconciliation_obs_{i}", dist.Normal(loc=rec_diff, scale=sigma_rec),
                    obs=torch.zeros_like(rec_diff))

    return mu_train, mu_test

I wanted to try to rewrite the model with plate syntax, but it’s unclear to me how to go about it. The initial observations (f"train_obs_{col_name}") are independent of each other, but the reconciliation observations depend on the means and noise scales of all series having been computed first.

Further, I cannot store all of my model inputs in a single tensor, because each train_obs site uses different covariates; as a result, X_train_list is a list of multi-dimensional tensors of different sizes.