Pyro model that is fit without error fails during predict

I have a Pyro model, fit using SVI, that includes the following among other components:

    with pyro.plate('year_effect', len(seasons)-1):
        gamma = pyro.sample('gamma', Normal(loc=dtensor(0.), scale=dtensor(10.)))
    _gamma = torch.cat([dtensor([0.]), gamma])

(dtensor is just my own convenience function for creating a tensor on the CUDA device)

The model fits without error, but when I try to sample from it via pyro.infer.Predictive, it fails at this point in the model:

    RuntimeError                              Traceback (most recent call last)
    <ipython-input-26-ba8d4e5a9b46> in ra_model(age_idx, pitcher_idx, season_idx, stuff_data, command_data, dN_data, n_data, r_data)
         72     with pyro.plate('year_effect', len(seasons)-1):
         73         gamma = pyro.sample('gamma', Normal(loc=dtensor(0.), scale=dtensor(10.)))
    ---> 74     _gamma = torch.cat([dtensor([0.]), gamma])

    RuntimeError: Tensors must have same number of dimensions: got 1 and 2

Is there any obvious reason why a dimension would be added to gamma after it is fit? This is clearly a 1-dimensional tensor as specified in the plate.
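One way to make this line indifferent to extra batch dimensions is to build the zero with the same leading shape as gamma and concatenate along the last axis. The following is a minimal sketch in plain PyTorch, assuming the extra dimension is a leading batch dimension (as the "got 1 and 2" message suggests). `prepend_zero` is a hypothetical helper name, and the random tensors only stand in for the sampled values.

```python
import torch

def prepend_zero(gamma: torch.Tensor) -> torch.Tensor:
    """Prepend a zero along the last axis, keeping any leading batch dims."""
    # new_zeros keeps gamma's dtype and device (e.g. CUDA)
    zero = gamma.new_zeros(gamma.shape[:-1] + (1,))
    return torch.cat([zero, gamma], dim=-1)

gamma_fit = torch.randn(5)       # stand-in for the 1-D value seen during fitting
gamma_pred = torch.randn(1, 5)   # stand-in for a value with an extra leading dim

print(prepend_zero(gamma_fit).shape)   # torch.Size([6])
print(prepend_zero(gamma_pred).shape)  # torch.Size([1, 6])
```

This does not explain where the extra dimension comes from; it only keeps the concatenation valid for both shapes.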

I started using Pyro very recently. When I hit a similar problem, I realised it was because I was running my code in a Jupyter notebook and did not clear the param store.

From the documentation for clear_param_store():

Clears the global ParamStoreDict.

This is especially useful if you’re working in a REPL. We recommend calling this before each training loop (to avoid leaking parameters from past models), and before each unit test (to avoid leaking parameters across tests).

The following example also calls clear_param_store() in cell 14.

Of course, you may already be doing this, or may not be running in a REPL, so this may not apply to you. I am sharing it because I made this mistake when I modified my model in the REPL.

Unfortunately, this does not fix the problem. It’s not clear to me why there are shape issues with Predictive and not during fitting. If I remove that particular line of code altogether, it simply shifts to another line, with a similar error:

     17     f_aging = pyro.deterministic(
---> 18         "f_aging", (cov_age + torch.eye(A, device=device) * jitter).cholesky() @ f_tilde_age
     19     )

    RuntimeError: mat1 and mat2 shapes cannot be multiplied (26x26 and 1x26)
       Trace Shapes:
        Param Sites:
       Sample Sites:
           ages dist      |
               value   26 |
    f_tilde_age dist   26 |
               value 1 26 |

Again, it's not clear why a dimension is being added to one component of this operation and not the other.

Adding a squeeze call to f_tilde_age makes it run, though it's not clear why this is necessary for prediction.
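As an alternative to squeeze, the matrix-vector product can be written so that it accepts both shapes. The following is a minimal sketch in plain PyTorch, assuming the extra dimension on f_tilde_age is a leading batch dimension, as in the trace above. `apply_cholesky` is a hypothetical helper, the scaled identity matrix is only a stand-in for cov_age, and torch.linalg.cholesky is used in place of the older Tensor.cholesky() method.

```python
import torch

def apply_cholesky(cov: torch.Tensor, f_tilde: torch.Tensor, jitter: float = 1e-6) -> torch.Tensor:
    """Compute L @ f_tilde for f_tilde of shape (..., A), where L = chol(cov + jitter * I)."""
    A = cov.shape[-1]
    eye = torch.eye(A, device=cov.device, dtype=cov.dtype)
    L = torch.linalg.cholesky(cov + eye * jitter)
    # Treat the last axis as a column vector so leading batch dims broadcast
    return (L @ f_tilde.unsqueeze(-1)).squeeze(-1)

A = 26
cov = torch.eye(A) * 2.0                              # stand-in covariance matrix
print(apply_cholesky(cov, torch.randn(A)).shape)      # torch.Size([26])
print(apply_cholesky(cov, torch.randn(1, A)).shape)   # torch.Size([1, 26])
```

Unlike squeeze, this keeps whatever leading dimensions the input has, so the result has the same shape as f_tilde_age in both the fitting and the prediction case.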