Hi folks,
I am trying to fit the following model (with reprex / toy data) using NUTS:
def model(s, m, R, b):
p = se.shape[0]
a = pyro.sample('a', dist.LogNormal(0, 1))
pi = pyro.sample('pi', dist.Beta(1, 1))
scale = torch.tensor([a, 1e-13])
with pyro.plate('assign', p):
z = pyro.sample('z', dist.Bernoulli(probs=pi))
scale_z = scale[z.long()]
bp = pyro.sample('bp', dist.Normal(0., scale_z))
print(f"beta shape: {bp.shape}")
print(f"z-score shape: {(bp / s).T.shape}")
mean = torch.matmul(R, (bp / s).T)
print(f"mean shape: {mean.shape}")
pyro.sample('obs', dist.MultivariateNormal(mean, scale_tril=m), obs=(b/s))
def run_inference(s, m, R, b, num_samples=1000, warmup_steps=200):
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps)
mcmc.run(s, m, R, b)
return mcmc.get_samples()
s = torch.tensor([0.1, 0.2, 0.3])
m = torch.eye(3)
R = torch.eye(3)
b = torch.tensor([0.5, -0.2, 0.1])
samples = run_inference(s, m, R, b)
Unfortunately, I am getting this error (abridged version of output), along with these shapes:
bp shape: torch.Size([3])
z-score shape: torch.Size([3])
mean shape: torch.Size([3])
bp shape: torch.Size([2, 3])
z-score shape: torch.Size([3, 2])
mean shape: torch.Size([3, 2])
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1
As I am new to NUTS, I am not sure how to fix this as the dimension of the key players changes at different points in the NUTS process. Any advice would be appreciated, thank you!