IndexError with Predictive but not when fitting

There is something I still don't understand about shapes in Pyro. Specifically, I keep running into errors when using Predictive that I don't get when I fit the model. For example, a model that contains:

mu = pyro.sample("mu", Normal(loc=dtensor(mean_priors), scale=dtensor(10.)).expand([P]).to_event(1))

creates a vector of length P whose prior means are the values in mean_priors. Further down in the model, it is indexed to get an expected value

theta = mu[p_index]

which works fine during fitting for the index tensor p_index. However, when I run Predictive after fitting:

ppc = pyro.infer.Predictive(
    model, 
    guide=guide, 
    num_samples=100, 
    return_sites=["pred_values"]
)
ppc_samples = ppc(*data[:-1], None)

it fails with an indexing error:

     426 
     427     # Expected value
---> 428     theta = mu[p_index]       

IndexError: index 2038 is out of bounds for dimension 0 with size 1

It's not clear to me what is causing this, given that the inputs are exactly the same.

If I print the shape of mu inside the model and run Predictive again, I get the following output:

torch.Size([5353])
torch.Size([5353])
torch.Size([5353])
torch.Size([1, 5353])

So this explains the error to some degree, but I'm not sure why the shape changes in the middle of the run. Do I need to squeeze mu here, or change the shape of the input data?
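The IndexError itself can be reproduced in plain PyTorch, without Pyro. In this sketch the index values are hypothetical stand-ins for the real data. Indexing with mu[p_index] selects along dimension 0, which is the length-P axis when mu has shape [5353] but a size-1 axis when mu has shape [1, 5353].

```python
import torch

P = 5353  # matches the printed shape
p_index = torch.tensor([0, 2038, 0, 2038])  # hypothetical observed indices

mu_fit = torch.randn(P)       # shape seen during fitting: [5353]
mu_ppc = mu_fit.unsqueeze(0)  # shape seen in the last Predictive call: [1, 5353]

theta_fit = mu_fit[p_index]   # dim 0 is the length-P axis, so this works
print(theta_fit.shape)        # torch.Size([4])

error_message = None
try:
    theta_ppc = mu_ppc[p_index]  # dim 0 now has size 1, so index 2038 is invalid
except IndexError as err:
    error_message = str(err)
    print(error_message)
```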

It looks like I can work around this by adding ellipses everywhere indexing happens. For example,

theta = mu[..., p_index]
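A small sketch of why the ellipsis helps, with hypothetical sizes and values: mu[..., p_index] always indexes the last (length-P) axis, so it picks the same entries whether mu is a single vector, has an extra size-1 dimension, or carries a batch of draws stacked in front.

```python
import torch

P = 6
p_index = torch.tensor([0, 1, 1, 4, 5])           # hypothetical observed indices
mu = torch.arange(P, dtype=torch.float32) * 10.0  # shape [P]

theta_plain = mu[..., p_index]               # shape [5]
theta_unsq = mu.unsqueeze(0)[..., p_index]   # shape [1, 5]

mu_batch = torch.stack([mu, mu + 1.0])       # shape [2, P], e.g. two posterior draws
theta_batch = mu_batch[..., p_index]         # shape [2, 5]

# Each leading slice holds the same selection as the unbatched case.
print(theta_plain.shape, theta_unsq.shape, theta_batch.shape)
```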

Is this really the appropriate way to write Pyro code? It seems like something users should not have to worry about. If Pyro changes the shape of my variables at runtime (here, by adding a dimension), then it should take care of making the code work.

Based on the error message, I'd guess so. If p_index is a latent variable, then the code that uses it needs to allow for p_index having batch dimensions. If it is just a plain tensor, then there's likely a bug somewhere.

No, p_index is an observed tensor of indices for pulling out the appropriate value of mu:

tensor([  0,   1,   1,  ...,  67, 216, 216])

Oops, my bad, sorry. If p_index is a 1D tensor and mu is derived from latent variables, then indexing into mu needs to allow for mu having extra leading batch dimensions. So you are right that mu[..., p_index] is the recommended way to write Pyro code (at least until we can use vectorized map or named tensors in PyTorch).
