Help - Initialised weights dims with Pyro Sample but when run model weights dims change

Hello,

I am working on a multimodel neural network with Bayesian Layers.
I have been stuck on this problem and have no idea how to fix it.

My model is:
class MMBNN(PyroModule):
def init(self, vocab_size, n_hidden, n_layers, n_out, num_embeddings, embedding_dim):

super(MMBNN,self).__init__()

# LSTM for the text overview
self.vocab_size, self.n_hidden, self.n_out, self.n_layers = vocab_size, n_hidden, n_out, n_layers
self.emb = nn.Embedding(num_embeddings, embedding_dim)
self.lstm = nn.LSTM(embedding_dim, self.n_hidden, self.n_layers, dropout=0.3, batch_first=True)
#self.lstm2 = nn.LSTM(embedding_dim, self.n_hidden, self.n_layers, dropout=0.2, batch_first=True)
self.dropout = nn.Dropout(0.3)
#self.lstm_fc = nn.Linear(self.n_hidden, 128)
self.lstm_fc = PyroModule[nn.Linear](self.n_hidden,128) 
self.lstm_fc.weight = PyroSample(Normal(0., 1.).expand([128,self.n_hidden]).to_event(2))
self.lstm_fc.bias = PyroSample(Normal(0., 10.).expand([128]).to_event(1))
# self.sigmoid = nn.Sigmoid()

# CNN for the posters
self.conv1 = nn.Conv2d(3, 32, 3)
self.max_pool1 = nn.MaxPool2d(2)
self.conv2 = nn.Conv2d(32, 64, 3)
self.max_pool2 = nn.MaxPool2d(2)
self.conv3 = nn.Conv2d(64, 128, 3)
self.max_pool3 = nn.MaxPool2d(2)
self.conv4 = nn.Conv2d(128, 128, 3)
self.max_pool4 = nn.MaxPool2d(2)
self.cnn_dropout = nn.Dropout(0.3)
#self.cnn_fc = nn.Linear(5*2*128, 128)
self.cnn_fc = PyroModule[nn.Linear](5*2*128, 128)
self.cnn_fc.weight = PyroSample(Normal(0., 1.).expand([128,5*2*128]).to_event(2))
self.cnn_fc.bias = PyroSample(Normal(0., 10.).expand([128]).to_event(1))    

self.combined_fc = PyroModule[nn.Linear](256, 128)
self.combined_fc.weight = PyroSample(Normal(0., 1.).expand([128,256]).to_event(2))
self.combined_fc.bias = PyroSample(Normal(0., 10.).expand([128]).to_event(1))   

#self.output_fc = nn.Linear(128, n_out)
self.output_fc = PyroModule[nn.Linear](128, n_out)
self.output_fc.weight = self.combined_fc.weight = PyroSample(Normal(0., 1.).expand([n_out,128]).to_event(2))
self.output_fc.bias = PyroSample(Normal(0., 10.).expand([n_out]).to_event(1))  
#self.softmax = nn.Softmax(2)

def forward(self, lstm_inp, cnn_inp, y=None):
batch_size = lstm_inp.size(0)
hidden = self.init_hidden(batch_size)
lstm_inp = lstm_inp.long()
#print(lstm_inp.size(), lstm_inp)
embeds = self.emb(lstm_inp)

#print(embeds.size(), embeds)
lstm_out, hidden = self.lstm(embeds, hidden)

lstm_out = self.lstm_fc(lstm_out[:,-1])

x = F.relu(self.conv1(cnn_inp))
x = self.max_pool1(x)
#print(x.size())
x = F.relu(self.conv2(x))
x = self.max_pool2(x)
#print(x.size())
x = F.relu(self.conv3(x))
x = self.max_pool3(x)
#print(x.size())
x = F.relu(self.conv4(x))
x = self.max_pool4(x)
#print(x.size())
x = x.view(-1, 5*2*128)
#print(x.size())

cnn_out = self.cnn_fc(x)

print(self.combined_fc.weight.size(),self.combined_fc.bias.size())

combined_inp = torch.cat((cnn_out, lstm_out), 1) 
#print(combined_inp.size())
x_comb = self.combined_fc(combined_inp)
x_comb = F.relu(x_comb)

z = self.output_fc(x_comb)

out = z

with pyro.plate("data2",batch_size):
    obs = pyro.sample("obs", Categorical(logits = out),obs = y)

return out

def init_hidden(self, batch_size):
weight = next(self.parameters()).data
hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_(),
weight.new(self.n_layers, batch_size, self.n_hidden).zero_())
return hidden

As you can see I set:
self.combined_fc = PyroModule[nn.Linear](256, 128)
self.combined_fc.weight = PyroSample(Normal(0., 1.).expand([128,256]).to_event(2))
self.combined_fc.bias = PyroSample(Normal(0., 10.).expand([128]).to_event(1))

Which means my weights should have dimensions [128,256]
However when I run the model I get the following error from this
RuntimeError: The expanded size of the tensor (2) must match the existing size (128) at non-singleton dimension 1. Target sizes: [100, 2]. Tensor sizes: [128]
Trace Shapes:
Param Sites:
Sample Sites:
lstm_fc.weight dist | 128 64
value | 128 64
lstm_fc.bias dist | 128
value | 128
cnn_fc.weight dist | 128 1280
value | 128 1280
cnn_fc.bias dist | 128
value | 128
combined_fc.weight dist | 2 128
value | 2 128
combined_fc.bias dist | 128
value | 128
Trace Shapes:
Param Sites:
Sample Sites:

The combined_fc.weight has now changed to [2,128].
My question is how? and how can I fix this?

The following is the full trace back:
Traceback (most recent call last):

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\trace_messenger.py”, line 174, in call
ret = self.fn(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\messenger.py”, line 12, in _context_wrap
return fn(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\messenger.py”, line 12, in _context_wrap
return fn(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\nn\module.py”, line 426, in call
return super().call(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\torch\nn\modules\module.py”, line 1051, in _call_impl
return forward_call(*input, **kwargs)

File “”, line 77, in forward
x_comb = self.combined_fc(combined_inp)

File “D:\Programs\Anaconda\lib\site-packages\pyro\nn\module.py”, line 426, in call
return super().call(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\torch\nn\modules\module.py”, line 1051, in _call_impl
return forward_call(*input, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\torch\nn\modules\linear.py”, line 96, in forward
return F.linear(input, self.weight, self.bias)

File “D:\Programs\Anaconda\lib\site-packages\torch\nn\functional.py”, line 1847, in linear
return torch._C._nn.linear(input, weight, bias)

RuntimeError: The expanded size of the tensor (2) must match the existing size (128) at non-singleton dimension 1. Target sizes: [100, 2]. Tensor sizes: [128]

The above exception was the direct cause of the following exception:

Traceback (most recent call last):

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\trace_messenger.py”, line 174, in call
ret = self.fn(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\nn\module.py”, line 426, in call
return super().call(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\torch\nn\modules\module.py”, line 1051, in _call_impl
return forward_call(*input, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\infer\autoguide\guides.py”, line 408, in forward
self._setup_prototype(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\infer\autoguide\guides.py”, line 378, in _setup_prototype
super()._setup_prototype(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\infer\autoguide\guides.py”, line 171, in _setup_prototype
self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\messenger.py”, line 12, in _context_wrap
return fn(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\trace_messenger.py”, line 198, in get_trace
self(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\trace_messenger.py”, line 180, in call
raise exc from e

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\trace_messenger.py”, line 174, in call
ret = self.fn(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\messenger.py”, line 12, in _context_wrap
return fn(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\messenger.py”, line 12, in _context_wrap
return fn(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\nn\module.py”, line 426, in call
return super().call(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\torch\nn\modules\module.py”, line 1051, in _call_impl
return forward_call(*input, **kwargs)

File “”, line 77, in forward
x_comb = self.combined_fc(combined_inp)

File “D:\Programs\Anaconda\lib\site-packages\pyro\nn\module.py”, line 426, in call
return super().call(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\torch\nn\modules\module.py”, line 1051, in _call_impl
return forward_call(*input, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\torch\nn\modules\linear.py”, line 96, in forward
return F.linear(input, self.weight, self.bias)

File “D:\Programs\Anaconda\lib\site-packages\torch\nn\functional.py”, line 1847, in linear
return torch._C._nn.linear(input, weight, bias)

RuntimeError: The expanded size of the tensor (2) must match the existing size (128) at non-singleton dimension 1. Target sizes: [100, 2]. Tensor sizes: [128]
Trace Shapes:
Param Sites:
Sample Sites:
lstm_fc.weight dist | 128 64
value | 128 64
lstm_fc.bias dist | 128
value | 128
cnn_fc.weight dist | 128 1280
value | 128 1280
cnn_fc.bias dist | 128
value | 128
combined_fc.weight dist | 2 128
value | 2 128
combined_fc.bias dist | 128
value | 128

The above exception was the direct cause of the following exception:

Traceback (most recent call last):

File “”, line 79, in
epoch_loss += svi.step(text_tensor,img_tensor, labels)

File “D:\Programs\Anaconda\lib\site-packages\pyro\infer\svi.py”, line 145, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\infer\trace_elbo.py”, line 140, in loss_and_grads
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):

File “D:\Programs\Anaconda\lib\site-packages\pyro\infer\elbo.py”, line 186, in _get_traces
yield self._get_trace(model, guide, args, kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\infer\trace_elbo.py”, line 57, in _get_trace
model_trace, guide_trace = get_importance_trace(

File “D:\Programs\Anaconda\lib\site-packages\pyro\infer\enum.py”, line 52, in get_importance_trace
guide_trace = poutine.trace(guide, graph_type=graph_type).get_trace(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\trace_messenger.py”, line 198, in get_trace
self(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\trace_messenger.py”, line 180, in call
raise exc from e

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\trace_messenger.py”, line 174, in call
ret = self.fn(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\nn\module.py”, line 426, in call
return super().call(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\torch\nn\modules\module.py”, line 1051, in _call_impl
return forward_call(*input, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\infer\autoguide\guides.py”, line 408, in forward
self._setup_prototype(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\infer\autoguide\guides.py”, line 378, in _setup_prototype
super()._setup_prototype(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\infer\autoguide\guides.py”, line 171, in _setup_prototype
self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\messenger.py”, line 12, in _context_wrap
return fn(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\trace_messenger.py”, line 198, in get_trace
self(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\trace_messenger.py”, line 180, in call
raise exc from e

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\trace_messenger.py”, line 174, in call
ret = self.fn(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\messenger.py”, line 12, in _context_wrap
return fn(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\poutine\messenger.py”, line 12, in _context_wrap
return fn(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\pyro\nn\module.py”, line 426, in call
return super().call(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\torch\nn\modules\module.py”, line 1051, in _call_impl
return forward_call(*input, **kwargs)

File “”, line 77, in forward
x_comb = self.combined_fc(combined_inp)

File “D:\Programs\Anaconda\lib\site-packages\pyro\nn\module.py”, line 426, in call
return super().call(*args, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\torch\nn\modules\module.py”, line 1051, in _call_impl
return forward_call(*input, **kwargs)

File “D:\Programs\Anaconda\lib\site-packages\torch\nn\modules\linear.py”, line 96, in forward
return F.linear(input, self.weight, self.bias)

File “D:\Programs\Anaconda\lib\site-packages\torch\nn\functional.py”, line 1847, in linear
return torch._C._nn.linear(input, weight, bias)

RuntimeError: The expanded size of the tensor (2) must match the existing size (128) at non-singleton dimension 1. Target sizes: [100, 2]. Tensor sizes: [128]