I have a multivariate normal likelihood with a block-diagonal covariance matrix, with blocks of varying dimension.

Roughly speaking, the distribution has dimension ~1 million, and the covariance matrix has ~1,500 non-zero blocks ranging in dimension from 10 to 1,000. Ultimately, I am interested in inferring parameters related only to the mean.

Due to the size, I am using SVI, and I've factored the model by block using the sequential `pyro.plate`.

I've put the blocked data into Python lists of torch tensors: `mu`, `sigmaL`, and `beta_hat`. Each is a length-`k` list containing the tensors for that block.
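
For concreteness, dummy data with the same layout would look something like this (toy sizes and placeholder values; the `matmul` in the model implies each `mu[k]` is a square `p_k x p_k` matrix):

```
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy block sizes; the real data has ~1,500 blocks of dimension 10-1000.
block_sizes = [10, 250, 1000]

# mu[k] is a square p_k x p_k matrix (the model matmuls it with beta_k),
# sigmaL[k] is the lower Cholesky factor of block k's covariance, and
# beta_hat[k] is the observed p_k-vector for that block.
mu = [torch.randn(p, p, device=device) for p in block_sizes]
sigmaL = [torch.eye(p, device=device) for p in block_sizes]  # placeholder factors
beta_hat = [torch.randn(p, device=device) for p in block_sizes]
```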

My model and guide have the following form:

```
import torch
import pyro
import pyro.distributions as dist

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def blocked_model(mu, sigmaL, beta_hat):
    k_blocks = len(mu)
    # Sequential plate: iterates over block indices one at a time.
    for idx in pyro.plate("data", k_blocks, device=device):
        p_k = mu[idx].shape[0]
        # Standard normal prior on block idx's coefficients; to_event(1)
        # declares the p_k dimensions as one multivariate event.
        beta_k = pyro.sample(
            f"beta_{idx}",
            dist.Normal(
                torch.zeros(p_k, device=device),
                torch.ones(p_k, device=device),
            ).to_event(1),
        )
        # Observed block, using the precomputed Cholesky factor of its covariance.
        pyro.sample(
            f"beta_hat_{idx}",
            dist.MultivariateNormal(
                torch.matmul(mu[idx], beta_k), scale_tril=sigmaL[idx]
            ),
            obs=beta_hat[idx],
        )


def blocked_guide(mu, sigmaL, beta_hat):
    k_blocks = len(beta_hat)
    for k in range(k_blocks):
        p_k = beta_hat[k].shape[0]
        beta_loc = pyro.param(f"beta_{k}_loc", torch.zeros(p_k, device=device))
        beta_scale = pyro.param(
            f"beta_{k}_scale",
            torch.ones(p_k, device=device),
            constraint=dist.constraints.positive,
        )
        # Mean-field normal posterior over each block's coefficients.
        pyro.sample(f"beta_{k}", dist.Normal(beta_loc, beta_scale).to_event(1))
```
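
Training is just the standard SVI loop (the learning rate and step count here are arbitrary placeholders):

```
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

pyro.clear_param_store()
svi = SVI(blocked_model, blocked_guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())

for step in range(2000):
    loss = svi.step(mu, sigmaL, beta_hat)
    if step % 100 == 0:
        print(f"step {step}: ELBO loss = {loss:.1f}")
```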

And this appears to work.

I'm pretty new to PyTorch and Pyro. Is there a better way to write down this model/guide? Is any easy performance being left on the table?

Loading all the data into memory provided a speedup, and so did moving to a GPU. I've also considered distributed/multi-threaded training, since each block is independent, but I'm not sure exactly how that would work; a rough sketch of what I had in mind is below.
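
The idea would be to partition the block lists across worker processes and run an independent SVI instance on each chunk, since the ELBO factorizes over blocks. Untested, and the chunking and merging details are guesses:

```
import pyro
import torch.multiprocessing as mp
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam


def fit_chunk(rank, chunks):
    # Each worker fits its own slice of blocks; the chunks can be trained
    # completely independently because the ELBO factorizes over blocks.
    mu_c, sigmaL_c, beta_hat_c = chunks[rank]
    pyro.clear_param_store()
    svi = SVI(blocked_model, blocked_guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())
    for _ in range(2000):
        svi.step(mu_c, sigmaL_c, beta_hat_c)
    # Param names restart at beta_0 within each chunk, so merging the saved
    # stores afterwards would need some re-indexing.
    pyro.get_param_store().save(f"params_chunk_{rank}.pt")


if __name__ == "__main__":
    # mu, sigmaL, beta_hat are the block lists described above.
    n_workers = 4
    chunks = [
        (mu[i::n_workers], sigmaL[i::n_workers], beta_hat[i::n_workers])
        for i in range(n_workers)
    ]
    mp.spawn(fit_chunk, args=(chunks,), nprocs=n_workers)
```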

I've also considered porting to `numpyro`, but there doesn't appear to be an easy story for these kinds of jagged data structures there either.
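
The closest workaround I can see there is padding every block to the largest dimension and running one vectorized plate over stacked arrays, along the lines of the untested sketch below (the `pad_*` helpers are mine, not numpyro's), but with blocks ranging from 10 to 1,000 that pads away a lot of compute:

```
import jax.numpy as jnp
import numpyro
import numpyro.distributions as ndist


def pad_vec(v, p_max):
    # Zero-pad a length-p vector out to length p_max.
    return jnp.pad(v, (0, p_max - v.shape[0]))


def pad_mat(m, p_max):
    # Zero-pad a p x p matrix out to p_max x p_max.
    d = p_max - m.shape[0]
    return jnp.pad(m, ((0, d), (0, d)))


def pad_tril(L, p_max):
    # Embed a p x p Cholesky factor in the corner of a p_max x p_max
    # identity so it remains a valid scale_tril.
    return jnp.eye(p_max).at[: L.shape[0], : L.shape[0]].set(L)


def padded_model(mu_p, sigmaL_p, beta_hat_p):
    # mu_p: (k, p_max, p_max), sigmaL_p: (k, p_max, p_max),
    # beta_hat_p: (k, p_max) -- stacked from the padded blocks.
    # Padded entries only contribute constants to the likelihood, since the
    # padded mean is zero and the padded scale_tril corner is the identity.
    k, p_max = beta_hat_p.shape
    with numpyro.plate("blocks", k):
        beta = numpyro.sample(
            "beta", ndist.Normal(jnp.zeros(p_max), 1.0).to_event(1)
        )
        numpyro.sample(
            "beta_hat",
            ndist.MultivariateNormal(
                loc=jnp.einsum("kij,kj->ki", mu_p, beta),
                scale_tril=sigmaL_p,
            ),
            obs=beta_hat_p,
        )
```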