Extra dim added to some variables when using Predictive

A model that is fit without error using SVI fails when generating samples with Predictive.

     32     # Add jitter to diagoal
     33     cov_beta = K(torch.arange(D).to(device)).contiguous()
---> 34     cov_beta.view(-1)[::D+1] += jitter + noise
     35 
     36     beta = pyro.sample('beta', MultivariateNormal((torch.zeros(D).to(device)) + y.mean(), 

RuntimeError: output with shape [792] doesn't match the broadcast shape [1, 792]
    Trace Shapes:   

It seems odd to change the shape of nodes automatically. Is there an easy fix for this, other than adding squeeze() calls everywhere? I’m at a loss as to why the model runs when being fit but not when being called by Predictive. All I am trying to do is trace a deterministic variable.

Note that this is just a kernel for a single latent GP, so there should not be any need to call to_event or use a plate here.

Hi @fonnesbeck, I agree that adding those extra singleton dimensions is a bit odd. In this case, you can just reshape the right-hand side, but this issue should be addressed, or at least the docs should make clear how to resolve it. (For context, those extra dimensions are used to match the behavior when parallel=True.)
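A minimal PyTorch-only sketch of what goes wrong, assuming that under Predictive the scalar noise picks up a singleton batch dimension (shape `[1, 1]`), which would explain the `[1, 792]` broadcast shape in the first traceback. `D = 4`, the all-ones covariance, and the helper `add_to_diagonal_inplace` are placeholders for illustration. An in-place `+=` cannot enlarge its target, so broadcasting to `[1, D]` fails; reshaping the right-hand side to a scalar first avoids the error, but only because every extra dimension has size 1.

```python
import torch

D = 4  # placeholder size; the model in this thread uses a larger D


def add_to_diagonal_inplace(cov, value):
    """In-place diagonal update in the style of line 34 of the original model."""
    d = cov.shape[-1]
    cov.view(-1)[::d + 1] += value  # strided view over the D diagonal entries
    return cov


noise = torch.tensor(0.5).reshape(1, 1)  # a scalar carrying one extra singleton batch dim

# The diagonal view has shape [D]; adding a [1, 1] tensor broadcasts to [1, D],
# which an in-place operation cannot write back into the [D]-shaped view.
try:
    add_to_diagonal_inplace(torch.ones(D, D), noise)
    failed = False
except RuntimeError as err:
    failed = True
    print(err)

# Workaround: collapse the singleton dims first. This only works when every
# extra dim has size 1; a genuine batch of noise values would still fail.
cov = add_to_diagonal_inplace(torch.ones(D, D), noise.reshape(()))
print(cov.diagonal())
```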

I came across the same problem in a much simpler context. I wanted to modify your intro to Bayesian linear regression example so that the posterior mean of the predictive distribution is stored as a deterministic variable, instead of relying on the _RETURN magic incantation. But I get a mysterious extra dimension of size 1. (This also occurs in other uses of Predictive where I want to return stochastic parameters of the model, such as linear.weight, so it is not unique to deterministic sites.)

Specifically, here is the model (the only change is marked ### NEW):

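import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BayesianRegression(PyroModule):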
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 10.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        mu = pyro.deterministic("mu", mean) ### NEW
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean

I run MCMC inference to get the parameter posterior:

from pyro.infer import MCMC, NUTS

pyro.set_rng_seed(1)
model = BayesianRegression(3, 1)
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=200)
mcmc.run(x_data, y_data)

Finally I compute the predictive posterior:

from pyro.infer import Predictive

predictive = Predictive(model, mcmc.get_samples(), return_sites=("obs", "mu", "_RETURN"))
hmc_samples_pred = predictive(x_data)
print(hmc_samples_pred.keys())
print(hmc_samples_pred['obs'].shape)
print(hmc_samples_pred['mu'].shape)
print(hmc_samples_pred['_RETURN'].shape)

This yields the following output (note that mu has shape (S,1,N) instead of (S,N), for reasons that are not clear):

dict_keys(['obs', 'mu', '_RETURN'])
torch.Size([1000, 170])
torch.Size([1000, 1, 170]) ## very weird
torch.Size([1000, 170])

I see some shenanigans about adding an extra dim on line 72 of pyro/infer/predictive.py, but I don’t understand them…
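As far as we can tell from reading pyro/infer/predictive.py (details may differ between Pyro versions), Predictive pads each returned site to `(num_samples,) + (1,) * (max_plate_nesting - len(fn.batch_shape)) + value.shape`. A pyro.deterministic site created outside any plate wraps a Delta whose batch shape is `()` (the whole value is treated as event), so in a model with one plate it gains one singleton dim. The function `predictive_site_shape` below is hypothetical: a pure-Python restatement of that rule, not a Pyro API.

```python
def predictive_site_shape(num_samples, max_plate_nesting, batch_shape, value_shape):
    """Shape Predictive appears to give a returned site (our reading of the rule)."""
    # One leading sample dim, then enough singleton dims that every site ends up
    # with max_plate_nesting batch dims, then the per-sample value shape.
    append_ndim = max_plate_nesting - len(batch_shape)
    return (num_samples,) + (1,) * append_ndim + tuple(value_shape)


S, N = 1000, 170

# "mu" as a deterministic outside the plate: batch_shape == (), value_shape == (N,).
print(predictive_site_shape(S, 1, (), (N,)))

# "obs" inside the plate: batch_shape == (N,), value_shape == (N,).
print(predictive_site_shape(S, 1, (N,), (N,)))
```

Under this rule the two calls give `(1000, 1, 170)` and `(1000, 170)`, matching the shapes reported above for mu and obs.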


Yes, I’m not sure the aesthetics of parallelization are worth the tradeoff here; I think something more important was broken than what was fixed. Hacking around the problem in my model just moves the issue to other places in the model. As it stands, I basically cannot use Predictive with my model at all.

In the absence of being able to use Predictive, is there any other way to extract deterministic quantities from a fitted model (e.g. via a poutine effect handler)?

Back then, we faced an inconsistency between parallel=True and parallel=False. In this PR, I fixed that issue and also anticipated this confusion. I think we can open an FR and discuss what a good API would be. 🙂 Something this odd would be a very good reason for an FR.


Hi @fonnesbeck and @murphyk, thanks for kicking the tires on Pyro and NumPyro. The appearance of extra batch dimensions is a particularly unintuitive feature of Pyro’s programming model that can be quite frustrating at times, especially for new users.

Please see our tensor shape tutorial for an explanation of this behavior and general tips for writing code that is correct by construction. The TL;DR is that Pyro’s inference machinery expects to be able to add arbitrary batch dimensions to the left of the leftmost plate dimension in your model, and by following a few simple rules of thumb you can usually guarantee that your code preserves these dimensions and thus remains compatible with fancy algorithms like TraceEnum_ELBO and utilities like Predictive without sacrificing readability or performance.
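A PyTorch-only illustration of one of those rules of thumb, indexing and reducing from the right: code that addresses dimensions with negative indices keeps working when extra batch dimensions appear on the left, while code that uses positive indices silently reduces the wrong axis. The function names and shapes are made up for illustration; the `(1000, 1)` prefix merely mimics the kind of dims Predictive adds.

```python
import torch


def row_sums_positive(x):
    # Assumes x is exactly (N, K): sums over dim 1.
    return x.sum(1)


def row_sums_negative(x):
    # Sums over the last dim, whatever batch dims sit to its left.
    return x.sum(-1)


N, K = 5, 3
x = torch.ones(N, K)
x_batched = torch.ones(1000, 1, N, K)  # extra batch dims added on the left

print(row_sums_positive(x).shape)          # torch.Size([5])
print(row_sums_negative(x).shape)          # torch.Size([5])
print(row_sums_positive(x_batched).shape)  # torch.Size([1000, 5, 3]): wrong axis reduced
print(row_sums_negative(x_batched).shape)  # torch.Size([1000, 1, 5]): batch dims preserved
```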

@murphyk in your case those rules of thumb would suggest removing the .squeeze(-1) call on mean, placing the pyro.deterministic call inside the data plate, and passing event_dim=1 to deterministic.

@fonnesbeck it’s hard to say without more context, but it might be enough to rewrite line 34 so that it preserves any extra batch dimensions in jitter and noise. If you can provide runnable code, we can hopefully be more helpful. Alternatively, if you know you won’t ever care about vectorizing inference or prediction in this model, you can use trace directly as in _predictive_sequential.
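One possible rewrite of line 34 along those lines, as a sketch that assumes jitter and noise are either scalars or tensors whose whole shape is a batch shape (for instance `[1000, 1]`): build the diagonal term out of place and let broadcasting carry the batch dims into the covariance. `add_jitter_to_diagonal` is a hypothetical helper, not part of Pyro or PyTorch, and `D = 4` is a placeholder.

```python
import torch


def add_jitter_to_diagonal(cov, jitter):
    """Return cov + jitter * I without modifying cov in place.

    cov has shape (D, D); jitter has any batch shape B, possibly ().
    The result has shape B + (D, D).
    """
    jitter = torch.as_tensor(jitter, dtype=cov.dtype, device=cov.device)
    eye = torch.eye(cov.shape[-1], dtype=cov.dtype, device=cov.device)
    # Two trailing singleton dims let each jitter value scale a full identity matrix.
    return cov + jitter[..., None, None] * eye


D = 4
cov = torch.ones(D, D)

during_fit = add_jitter_to_diagonal(cov, torch.tensor(0.1))                 # noise of shape []
under_predictive = add_jitter_to_diagonal(cov, torch.full((1000, 1), 0.1))  # noise of shape [1000, 1]

print(during_fit.shape)        # torch.Size([4, 4])
print(under_predictive.shape)  # torch.Size([1000, 1, 4, 4])
```

Because the batch dims end up on the left of the `(D, D)` matrix dims, a MultivariateNormal built from the result would simply acquire a matching batch shape instead of raising.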

I agree with @fehiepsi that a feature request issue would be a good place for more discussion. For what it’s worth, I think the current unintuitive behavior of Predictive is correct, in the sense that it is consistent with the may-add-extra-batch-dimensions contract of other inference algorithms. Vectorization in Pyro adds enough value that this consistency is in users’ best interests, somewhat analogous to the way JAX requires users to write immutable code and use XLA control-flow primitives despite the learning curve this imposes on new users. We could certainly do a better job of explaining this and anticipating user frustration throughout the documentation, though.

In the long term, the ideal solution is of course for Pyro and NumPyro to automatically transform model code to be broadcasting-aware (generalizing vmap in JAX) rather than forcing users to hand-write broadcastable code. That would eliminate whole classes of bugs like the ones in this thread and significantly improve the overall user experience.

A working prototype of this fundamentally better approach is available as part of Funsor (see this gist for a specific example), but the timeline for fully integrating this prototype, or some other version of the same transformations, into Pyro and NumPyro is uncertain because of severely limited developer bandwidth (if any readers are interested in contributing, please reach out!).

Thanks for all of the insight @eb8680_2 and @fehiepsi. I think an even better FR would be to make deterministics first-class citizens and have their values traced by default, just like stochastic nodes, so that they show up in calls to quantiles, etc. If the user goes to the trouble of wrapping nodes in a deterministic, they are indicating that those variables are of primary interest and therefore should be traced. But yes, I agree that automating much of the shape handling would be ideal. Fundamentally, though, any model that can be fit without a shape error should also be able to predict without one, so perhaps there are some test coverage issues too.

As for my particular problem, it was mainly confusing because both jitter and noise are scalars, so having shape issues was surprising.

I’m still working on debugging this, without success. Here is some more detail regarding the shape problem:

     355     cov_beta = K(torch.arange(D).to(device)).contiguous()
---> 356     cov_beta.view(-1)[::D+1] += noise
     357 
     358     beta = pyro.sample('beta', MultivariateNormal((torch.zeros(D).to(device)) + y.mean(),

RuntimeError: output with shape [792] doesn't match the broadcast shape [1000, 792]
    Trace Shapes:            
     Param Sites:            
kern0.lengthscale            
   kern0.variance            
kern1.lengthscale            
   kern1.variance            
    Sample Sites:            
        s_mu dist 1000    1 |
            value 1000    1 |
    pitchers dist           |
            value      1726 |
          mu dist 1000 1726 |
            value 1000 1726 |
       gamma dist 1000    1 |
            value 1000    1 |
       noise dist 1000    1 |
            value 1000    1 |

So, the view on the covariance matrix is of shape [792], while the noise is a scalar. The 1000 is the number of predictive samples from:

predictive = pyro.infer.Predictive(bias_model, guide=guide, num_samples=1000, return_sites=['alpha'])

Pyro is clearly not doing the same thing to the view on cov_beta as it is to noise. I have tried squeezing noise, and it’s hard to do much with a view on a matrix, correct?

OK, I was able to get this to work by adding noise with a diagonal matrix rather than with a view:

cov_beta + torch.eye(D)*noise

Is this expected behavior, or a bug?

Hi @fonnesbeck, if this is just a shape mismatch issue, could you provide cov_beta.shape and noise.shape? I believe there are some PyTorch functions that help us add noise to the diagonal of cov_beta, even when we have a collection of noises. Multiplying with eye also looks reasonable to me.
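For the "collection of noises" case, one PyTorch function that fits is `torch.diag_embed`, which turns the last dimension of its input into the diagonal of a square matrix while keeping every other dimension as a batch dimension. The sketch assumes noise has shape `[1000, 1]` (one value per predictive sample, as in the trace shapes above) and uses a placeholder `D = 4` and an all-ones stand-in for cov_beta.

```python
import torch

D = 4
cov = torch.ones(D, D)       # stand-in for cov_beta
noise = torch.rand(1000, 1)  # one noise value per predictive sample

# Repeat each noise value D times along a new last dim: shape (1000, 1, D).
diag_values = noise.unsqueeze(-1).expand(*noise.shape, D)

# diag_embed builds (1000, 1, D, D) matrices with those values on the diagonal;
# adding cov broadcasts it across the two leading batch dims.
cov_batched = cov + torch.diag_embed(diag_values)

print(cov_batched.shape)  # torch.Size([1000, 1, 4, 4])
```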

The sizes are torch.Size([792]) and torch.Size([]), respectively, which is why I’m surprised there is a problem.

It is interesting that the error mentions 1000: RuntimeError: output with shape [792] doesn't match the broadcast shape [1000, 792]. Perhaps noise is a scalar but some of its attributes, like grad, have shape (1000,)?? I don’t actually know what’s going on. 😦 It sounds like a tensor bug to me.

Thanks! So it seems that putting mu inside the plate now works the way I expected (i.e., mu has size (S,N), where I draw S samples for N observations in x), since it tells Pyro that the first dimension is the batch dimension.

    def forward(self, x, y=None):
        N = x.shape[0]
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x)        # (N, 1)
        mean_vec = mean.squeeze(-1)  # (N,)
        with pyro.plate("data", N):
            mu = pyro.deterministic("mu", mean_vec)  # (N,)
            obs = pyro.sample("obs", dist.Normal(mean_vec, sigma), obs=y)
        return mean_vec