Simple shapes confusion with NUTS

Hi folks,

I am trying to fit the following model (with reprex / toy data) using NUTS:

def model(s, m, R, b):
    p = se.shape[0]    
    a = pyro.sample('a', dist.LogNormal(0, 1))    
    pi = pyro.sample('pi', dist.Beta(1, 1))    
    scale = torch.tensor([a, 1e-13])    
    with pyro.plate('assign', p):
        z = pyro.sample('z', dist.Bernoulli(probs=pi))        
        scale_z = scale[z.long()]        
        bp = pyro.sample('bp', dist.Normal(0., scale_z))
    print(f"beta shape: {bp.shape}")
    print(f"z-score shape: {(bp / s).T.shape}")
    mean = torch.matmul(R, (bp / s).T)
    print(f"mean shape: {mean.shape}")
    pyro.sample('obs', dist.MultivariateNormal(mean, scale_tril=m), obs=(b/s))

def run_inference(s, m, R, b, num_samples=1000, warmup_steps=200):
    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_samples=num_samples, warmup_steps=warmup_steps)
    mcmc.run(s, m, R, b)
    return mcmc.get_samples()

s = torch.tensor([0.1, 0.2, 0.3])
m = torch.eye(3)
R = torch.eye(3)
b = torch.tensor([0.5, -0.2, 0.1])

samples = run_inference(s, m, R, b)

Unfortunately, I am getting this error (abridged version of output), along with these shapes:

bp shape: torch.Size([3])
z-score shape: torch.Size([3])
mean shape: torch.Size([3])
bp shape: torch.Size([2, 3])
z-score shape: torch.Size([3, 2])
mean shape: torch.Size([3, 2])
RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 1

As I am new to NUTS, I am not sure how to fix this as the dimension of the key players changes at different points in the NUTS process. Any advice would be appreciated, thank you!