A subtle shape mismatch between deep layers


#1

I tried to follow the architecture of the Bayesian Regression tutorial to implement a more complex LSTM model.
In the model, I use two LSTM layers as the encoder, with 128 and 64 hidden units respectively. I use pyro.random_module to lift the weights to random variables, and I call .independent(1) on each prior, like this:
priors_dist[layer_name] = pyro.distributions.Normal(weights_loc, weights_scale).independent(1)

However, when I run the code I get RuntimeError: The size of tensor a (512) must match the size of tensor b (256) at non-singleton dimension 0.
I found that the error is raised from _compute_log_r(model_trace, guide_trace), used by Trace_ELBO, at the call log_r.add((stacks[name], log_r_term.detach())).
The reason is that for the first layer log_r_term has shape (512,), while for the second layer it has shape (256,), so log_r.add fails with a shape mismatch.
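A minimal sketch that reproduces the same broadcasting failure with plain torch.distributions (Pyro's distributions wrap these); it is not Pyro's internal code. PyTorch's nn.LSTM stacks its four gate matrices, so weight_ih_l0 has shape (4 * hidden_size, input_size): 512 rows for 128 hidden units and 256 rows for 64. With Independent(..., 1) only the last dim becomes an event dim, so each log_prob keeps a batch dim of 512 or 256. The input size of 10 and the helper name `prior` are assumptions for illustration.

```python
import torch
from torch import nn
from torch.distributions import Independent, Normal

# Hypothetical two-layer encoder; input_size=10 is an assumption
lstm1 = nn.LSTM(input_size=10, hidden_size=128)
lstm2 = nn.LSTM(input_size=128, hidden_size=64)

w1 = lstm1.weight_ih_l0.detach()  # shape (4 * 128, 10) = (512, 10)
w2 = lstm2.weight_ih_l0.detach()  # shape (4 * 64, 128) = (256, 128)


def prior(w, event_dims):
    """Hypothetical helper: standard Normal prior shaped like w."""
    return Independent(Normal(torch.zeros_like(w), torch.ones_like(w)), event_dims)


# Reinterpreting only 1 dim leaves the first dim as a batch dim
p1, p2 = prior(w1, 1), prior(w2, 1)
lp1 = p1.log_prob(w1)
lp2 = p2.log_prob(w2)
print(lp1.shape, lp2.shape)  # torch.Size([512]) torch.Size([256])

# Accumulating per-site terms with different batch shapes cannot broadcast
try:
    lp1 + lp2
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err
    print("RuntimeError:", err)
```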

It seems that I have to use .independent(2) to eliminate the log_r_term shape mismatch. Is there a better way to solve the problem?


#2

I think using .independent(2) is right, because your weights are 2-dimensional.
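Generalising that answer, here is a minimal sketch (using torch.distributions directly, with a hypothetical `weight_prior` helper and the same assumed input size of 10) that takes the number of reinterpreted dims from each tensor's own dimensionality: 2 for weight matrices, 1 for bias vectors. Every prior then has an empty batch shape, so the per-site log-probs are scalars and add up cleanly. In the Pyro version from this thread that corresponds to .independent(w.dim()); later Pyro releases renamed the method .to_event().

```python
import torch
from torch import nn
from torch.distributions import Independent, Normal


def weight_prior(w, scale=1.0):
    """Hypothetical helper: Normal prior over all of w, with batch_shape ()."""
    loc = torch.zeros_like(w)
    scl = torch.full_like(w, scale)
    # w.dim() is 2 for weight matrices and 1 for bias vectors
    return Independent(Normal(loc, scl), w.dim())


lstm1 = nn.LSTM(input_size=10, hidden_size=128)  # input size is an assumption
lstm2 = nn.LSTM(input_size=128, hidden_size=64)

priors = {}
for prefix, module in (("lstm1", lstm1), ("lstm2", lstm2)):
    for name, param in module.named_parameters():
        priors[f"{prefix}.{name}"] = weight_prior(param.detach())

# Every log_prob is now a scalar, so summing across layers no longer broadcasts
total_log_prob = sum(p.log_prob(p.sample()) for p in priors.values())
print(len(priors), total_log_prob.shape)  # 8 torch.Size([])
```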


#3

Thanks, @fehiepsi! I did it the same way, i.e. when the weight tensor is 2D I use .independent(2), and when it is 1D I use .independent(1). The model now runs.
However, the ELBO loss was very high, about 4 million. I had to change the scale of the observation Normal distribution from 0.2 to 1.5, which reduced the loss to about 70,000, but during the SVI steps it never dropped below 50,000. I suspect the KL(q(Z)||p(Z)) term is large and drives the loss, since there are about 0.3 million parameter distributions. But if I use a conventional deep network with dropout and L2 regularization, which in theory can be interpreted as approximate VI, the MSE loss stays low.
Has anyone had the same experience?
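A rough back-of-the-envelope check of the KL(q(Z)||p(Z)) guess above: assuming a mean-field Normal guide with scale 0.1 and a standard Normal prior on each of about 0.3 million weights (both scales are assumptions, not values from the thread), the closed-form KL per weight is log(1/0.1) + 0.1²/2 - 1/2 ≈ 1.8076, so the KL term alone is on the order of 5 × 10⁵.

```python
import torch
from torch.distributions import Normal, kl_divergence

n_params = 300_000  # roughly the number of weights mentioned in the post

# Assumed guide and prior; real guide scales are learned and will differ
q = Normal(torch.zeros(n_params), torch.full((n_params,), 0.1))
p = Normal(torch.zeros(n_params), torch.ones(n_params))

kl_per_weight = kl_divergence(q, p)  # elementwise closed form, shape (n_params,)
kl_total = kl_per_weight.sum()       # mean-field KL sums over all weights
print(kl_per_weight[0].item())       # about 1.8076
print(kl_total.item())               # about 5.4e5
```

Because the mean-field KL is a sum over weights, it grows linearly with the parameter count, unlike an L2 penalty whose coefficient is usually tuned to stay small.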


#4

I usually observe high ELBO losses in my experiments too. ^^


#5

The returned ELBO loss is summed over the entire mini-batch, not averaged, which may help explain the large loss values you are seeing. See a discussion here on a similar issue.
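A minimal sketch of how a summed loss scales, using only the Gaussian likelihood term (no KL) on synthetic data: the summed negative log-likelihood is batch_size times the per-point mean, so dividing the reported loss by the number of data points gives a figure closer in scale to a per-example MSE-style loss. The batch size, targets, and residual noise are made up for illustration.

```python
import torch
from torch.distributions import Normal

torch.manual_seed(0)
batch_size = 256                          # hypothetical mini-batch size
y = torch.randn(batch_size)               # hypothetical targets
pred = y + 0.5 * torch.randn(batch_size)  # hypothetical predictions, residual std about 0.5

results = {}
for sigma in (0.2, 1.5):  # the two observation scales from post #3
    nll = -Normal(pred, sigma).log_prob(y)  # per-point negative log-likelihood
    results[sigma] = (nll.sum().item(), nll.mean().item())
    print(f"sigma={sigma}: summed={results[sigma][0]:.1f}, "
          f"per-point={results[sigma][1]:.3f}")
```

The same sketch also suggests why raising the observation scale from 0.2 to 1.5 lowered the loss: with residuals around 0.5, sigma=0.2 makes the squared-error term 0.5 * (r / sigma)**2 dominate each per-point term.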