Reshaping a Pyro distribution (`pyro.distributions`) instance

Hi,

I’m trying to implement Bayesian linear regression by following the tutorial on the official Pyro website.

I’d like to create a matrix whose columns all come from a common generating distribution (differing only in variance), i.e.
$W = [\mathbf{w}_1 \mid \mathbf{w}_2 \mid \dots \mid \mathbf{w}_K]$ with $\mathbf{w}_i \sim \mathcal{N}(0, V_i)$.

from pyro.infer import MCMC, NUTS
class BayesianRegression(PyroModule):
    """Bayesian linear regression built on ``PyroModule[nn.Linear]`` (question version).

    The weight of the wrapped ``nn.Linear`` gets a Normal prior via
    ``PyroSample``, and the bias is removed. Observations are modelled with a
    multivariate Normal whose covariance matrix is the ``Sigma`` passed in.

    NOTE: this is the version from the question. Constructing it raises the
    RuntimeError quoted in the thread (see the comment on the weight prior).
    """

    def __init__(self, in_features, out_features, Sigma):
        super().__init__()
        # NOTE(review): pyro.param is called once, here in __init__, so the
        # prior below captures the value of "sigma" at construction time. It
        # is not a sample site, so NUTS does not infer it.
        sigma = pyro.param(
            "sigma", torch.ones(out_features), constraint=constraints.positive
        )
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        # This is the line that raises the error quoted in the question.
        # The Normal has batch shape [out_features], and expand() aligns
        # shapes from the right, so out_features is matched against
        # in_features. Adding a trailing singleton dimension to loc/scale,
        # for example:
        #   dist.Normal(torch.zeros(out_features, 1), sigma.unsqueeze(-1))
        #       .expand([out_features, in_features]).to_event(2)
        # makes the shapes compatible. In that form, row k of the weight
        # (the coefficients of output k) is drawn from Normal(0, sigma[k]).
        self.linear.weight = PyroSample(
            dist.Normal(torch.zeros(out_features), sigma)
            .expand([out_features, in_features])
            .to_event(2)
        )
        self.linear.bias = None
        # Covariance matrix of the observation noise; presumably
        # (out_features, out_features) — TODO confirm.
        self.Sigma = Sigma

    def forward(self, x, y=None):
        """Run the model, conditioning on ``y`` when it is given.

        Args:
            x: Inputs; presumably shape (N, in_features) — TODO confirm.
            y: Optional observations, passed as ``obs`` to the "obs" site.

        Returns:
            The regression mean ``self.linear(x).squeeze(-1)``.
        """
        # squeeze(-1) only drops the last dimension when its size is 1
        # (out_features == 1); otherwise it is a no-op.
        mean = self.linear(x).squeeze(-1)
        # One plate entry per row of x. Each row's observation is a
        # multivariate Normal with covariance matrix self.Sigma.
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.MultivariateNormal(mean, self.Sigma), obs=y)
        return mean

# NOTE(review): `in_features`, `out_features`, `Sigma` and `self` are not
# defined in this snippet. `self.X.T` / `self.mu` suggest it was lifted from a
# method of an enclosing class — confirm the intended data arguments.
model = BayesianRegression(in_features, out_features, Sigma)
nuts_kernel = NUTS(model)

# 20 warm-up iterations, then 80 retained posterior draws.
mcmc = MCMC(nuts_kernel, warmup_steps=20, num_samples=80)
mcmc.run(self.X.T, self.mu)

# Turn each site's posterior draws into a NumPy array, keyed by site name.
posterior_draws = mcmc.get_samples()
hmc_samples = {
    site: draws.detach().cpu().numpy()
    for site, draws in posterior_draws.items()
}

The problem is that expanding the prior for `self.linear.weight` to the target shape raises the following error:

RuntimeError: The expanded size of the tensor (in_features) must match the existing size (out_features) at non-singleton dimension 1.  Target sizes: [out_features, in_features].  Tensor sizes: [out_features]

How can I fix this?
Thank you in advance.

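The problem is that `nn.Linear` doesn’t expect expanded samples in that way, so the solution is to not use `nn.Linear`, i.e. do the matrix multiplication yourself without invoking `nn.Linear`. (The error message itself comes from `expand`: the Normal’s batch shape is `[out_features]`, and `expand` aligns shapes from the right, so `out_features` is compared against `in_features`.)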

Okay, thank you! I fixed the problem 😀

class BayesianRegression(PyroModule):
    """Bayesian linear regression with an explicit weight matrix (working version).

    The weight matrix has shape (in_features, out_features). Column k holds
    the coefficients for output k, and every entry of that column is drawn
    from Normal(0, sigma[k]), where sigma is a standard deviation (scale),
    not a variance. The mean is computed with an explicit matmul rather than
    ``nn.Linear``.
    """

    def __init__(self, in_features, out_features, Sigma):
        super().__init__()
        # NOTE(review): pyro.param is called once, here in __init__, so the
        # prior below captures the value of "sigma" at construction time. It
        # is not a sample site, so NUTS does not infer it.
        sigma = pyro.param(
            "sigma", torch.ones(out_features), constraint=constraints.positive
        )
        # Normal(0, sigma) has batch shape [out_features]. expand_by prepends
        # [in_features], giving batch shape [in_features, out_features], and
        # to_event(2) makes the whole matrix a single sample site.
        self.weight = PyroSample(
            dist.Normal(0, sigma).expand_by([in_features]).to_event(2)
        )
        # Covariance matrix of the observation noise; presumably
        # (out_features, out_features) — TODO confirm.
        self.Sigma = Sigma

    def forward(self, x, y=None):
        """Run the model, conditioning on ``y`` when it is given.

        Args:
            x: Inputs; presumably shape (N, in_features) — TODO confirm.
            y: Optional observations, passed as ``obs`` to the "obs" site.

        Returns:
            The regression mean ``x @ weight``, shape (N, out_features) for a
            2-D ``x``.
        """
        # Reading self.weight draws the weight matrix from its PyroSample prior.
        mean = x.matmul(self.weight)
        # One plate entry per row of x. Each row's observation is a
        # multivariate Normal with covariance matrix self.Sigma.
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.MultivariateNormal(mean, self.Sigma), obs=y)
        return mean