I have a multivariate normal likelihood with a block-diagonal covariance matrix, with blocks of varying dimension.

Roughly speaking, the distribution has dimension ~1 million, and the covariance matrix has ~1,500 non-zero blocks ranging in dimension from 10 to 1,000. Ultimately, I am interested in inferring parameters related only to the mean.

Due to the size, I am using SVI, and I've factored the model by block using the sequential `pyro.plate`.

I've put the blocked data into Python lists of torch tensors: `mu`, `sigmaL`, and `beta_hat`. Each is a length-`k` list containing the tensors for that block.
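
For concreteness, dummy data with the same layout would look something like this (toy sizes and placeholder values; the `matmul` in the model implies each `mu[k]` is a square `p_k x p_k` matrix):

```
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy block sizes; the real data has ~1,500 blocks of dimension 10-1000.
block_sizes = [10, 250, 1000]

# mu[k] is a square p_k x p_k matrix (the model matmuls it with beta_k),
# sigmaL[k] is the lower Cholesky factor of block k's covariance, and
# beta_hat[k] is the observed p_k-vector for that block.
mu = [torch.randn(p, p, device=device) for p in block_sizes]
sigmaL = [torch.eye(p, device=device) for p in block_sizes]  # placeholder factors
beta_hat = [torch.randn(p, device=device) for p in block_sizes]
```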

My model and guide have the following form:

```
import torch
import pyro
import pyro.distributions as dist

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def blocked_model(mu, sigmaL, beta_hat):
    k_blocks = len(mu)
    # Sequential plate: iterates over block indices one at a time.
    for idx in pyro.plate("data", k_blocks, device=device):
        p_k = mu[idx].shape[0]
        # Standard normal prior on block idx's coefficients; to_event(1)
        # declares the p_k dimensions as one multivariate event.
        beta_k = pyro.sample(
            f"beta_{idx}",
            dist.Normal(
                torch.zeros(p_k, device=device),
                torch.ones(p_k, device=device),
            ).to_event(1),
        )
        # Observed block, using the precomputed Cholesky factor of its covariance.
        pyro.sample(
            f"beta_hat_{idx}",
            dist.MultivariateNormal(
                torch.matmul(mu[idx], beta_k), scale_tril=sigmaL[idx]
            ),
            obs=beta_hat[idx],
        )


def blocked_guide(mu, sigmaL, beta_hat):
    k_blocks = len(beta_hat)
    for k in range(k_blocks):
        p_k = beta_hat[k].shape[0]
        beta_loc = pyro.param(f"beta_{k}_loc", torch.zeros(p_k, device=device))
        beta_scale = pyro.param(
            f"beta_{k}_scale",
            torch.ones(p_k, device=device),
            constraint=dist.constraints.positive,
        )
        # Mean-field normal posterior over each block's coefficients.
        pyro.sample(f"beta_{k}", dist.Normal(beta_loc, beta_scale).to_event(1))
```
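
Training is just the standard SVI loop (the learning rate and step count here are arbitrary placeholders):

```
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

pyro.clear_param_store()
svi = SVI(blocked_model, blocked_guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())

for step in range(2000):
    loss = svi.step(mu, sigmaL, beta_hat)
    if step % 100 == 0:
        print(f"step {step}: ELBO loss = {loss:.1f}")
```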

And this appears to work.

I'm pretty new to PyTorch and Pyro. Is there a better way to write down this model/guide? Is any easy performance being left on the table?

Loading all the data into memory provided a speedup, and so did moving to a GPU. I've also considered distributed/multi-threaded training, since each block is independent, but I'm not sure exactly how that would work; a rough sketch of what I had in mind is below.
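
The idea would be to partition the block lists across worker processes and run an independent SVI instance on each chunk, since the ELBO factorizes over blocks. Untested, and the chunking and merging details are guesses:

```
import pyro
import torch.multiprocessing as mp
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam


def fit_chunk(rank, chunks):
    # Each worker fits its own slice of blocks; the chunks can be trained
    # completely independently because the ELBO factorizes over blocks.
    mu_c, sigmaL_c, beta_hat_c = chunks[rank]
    pyro.clear_param_store()
    svi = SVI(blocked_model, blocked_guide, Adam({"lr": 1e-2}), loss=Trace_ELBO())
    for _ in range(2000):
        svi.step(mu_c, sigmaL_c, beta_hat_c)
    # Param names restart at beta_0 within each chunk, so merging the saved
    # stores afterwards would need some re-indexing.
    pyro.get_param_store().save(f"params_chunk_{rank}.pt")


if __name__ == "__main__":
    # mu, sigmaL, beta_hat are the block lists described above.
    n_workers = 4
    chunks = [
        (mu[i::n_workers], sigmaL[i::n_workers], beta_hat[i::n_workers])
        for i in range(n_workers)
    ]
    mp.spawn(fit_chunk, args=(chunks,), nprocs=n_workers)
```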

I've also considered porting to `numpyro`, but there doesn't appear to be an easy story for these kinds of jagged data structures there either.
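
The closest workaround I can see there is padding every block to the largest dimension and running one vectorized plate over stacked arrays, along the lines of the untested sketch below (the `pad_*` helpers are mine, not numpyro's), but with blocks ranging from 10 to 1,000 that pads away a lot of compute:

```
import jax.numpy as jnp
import numpyro
import numpyro.distributions as ndist


def pad_vec(v, p_max):
    # Zero-pad a length-p vector out to length p_max.
    return jnp.pad(v, (0, p_max - v.shape[0]))


def pad_mat(m, p_max):
    # Zero-pad a p x p matrix out to p_max x p_max.
    d = p_max - m.shape[0]
    return jnp.pad(m, ((0, d), (0, d)))


def pad_tril(L, p_max):
    # Embed a p x p Cholesky factor in the corner of a p_max x p_max
    # identity so it remains a valid scale_tril.
    return jnp.eye(p_max).at[: L.shape[0], : L.shape[0]].set(L)


def padded_model(mu_p, sigmaL_p, beta_hat_p):
    # mu_p: (k, p_max, p_max), sigmaL_p: (k, p_max, p_max),
    # beta_hat_p: (k, p_max) -- stacked from the padded blocks.
    # Padded entries only contribute constants to the likelihood, since the
    # padded mean is zero and the padded scale_tril corner is the identity.
    k, p_max = beta_hat_p.shape
    with numpyro.plate("blocks", k):
        beta = numpyro.sample(
            "beta", ndist.Normal(jnp.zeros(p_max), 1.0).to_event(1)
        )
        numpyro.sample(
            "beta_hat",
            ndist.MultivariateNormal(
                loc=jnp.einsum("kij,kj->ki", mu_p, beta),
                scale_tril=sigmaL_p,
            ),
            obs=beta_hat_p,
        )
```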