Batch shape when calling pyro.deterministic

Hi there! I have a small question about how batched variables are handled when calling pyro.deterministic. Here is a toy example that samples a batch of variances and correlation (Cholesky) factors and computes the corresponding covariance factor. I believe this example is reproducible:

import torch
import pyro
import pyro.poutine as poutine
from pyro.distributions import Chi2, LKJCholesky
from pyro.infer import config_enumerate

d = 2  # dimension of each covariance factor

@config_enumerate
def toy_model(batch_size=100):
    with pyro.plate("component", batch_size):
        # component of prior for the covariance (Cholesky) factor
        theta = pyro.sample("theta", Chi2(df=torch.ones(d) * (d + 1)).to_event(1))
        omega = pyro.sample("omega", LKJCholesky(d, concentration=1))
        Omega = pyro.deterministic("Omega", torch.bmm(theta.sqrt().diag_embed(), omega))

trace = poutine.trace(toy_model).get_trace()
print(trace.format_shapes())

Running the code to check the shapes of the variables, we get

 Trace Shapes:              
  Param Sites:              
 Sample Sites:              
component dist     |        
         value 100 |        
       nu dist 100 |        
         value 100 |        
    theta dist 100 |   2    
         value 100 |   2    
    omega dist 100 |   2 2  
         value 100 |   2 2  
    Omega dist 100 | 100 2 2
         value     | 100 2 2

I was expecting the shape of Omega to be the same as omega's. However, the batch size seems to be duplicated. So I was wondering: am I not using the deterministic primitive correctly? How should I deal with this? (Currently I'm working around it by using a Normal distribution with very small variance, which seems immune to the problem above, although it introduces some small noise.) Thanks for any advice!

Hi @Evan. Can you check what the shape of torch.bmm(theta.sqrt().diag_embed(), omega) is? And what do you expect it to be?

I checked that with print, and it is torch.Size([100, 2, 2]) (just as expected).

I thought perhaps pyro.deterministic() is treating this tensor as one event and therefore adding another dim of size batch_size? However, I haven't figured out how to solve this yet.

(By the way, this also happens in other examples, like when I use pyro.deterministic to record a set of variables that are linked to a set of values through a discrete index. There I also get an extra dim.)

It looks like pyro.deterministic sets event_dim to value.ndim by default. What if you set event_dim=2?
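In the toy model above, that would be something like:

# mark the trailing (d, d) dims as the event; the "component" plate dim stays a batch dim
Omega = pyro.deterministic("Omega", torch.bmm(theta.sqrt().diag_embed(), omega),
                           event_dim=2)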

Oh I see! It works!

 Sample Sites:          
component dist     |    
         value 100 |    
    theta dist 100 | 2  
         value 100 | 2  
    omega dist 100 | 2 2
         value 100 | 2 2
    Omega dist 100 | 2 2
         value 100 | 2 2

Thanks a lot!
(I think I should read the docs more carefully lol :rofl: )


Oops, sorry, but when I tried out pyro.deterministic() with indexed values, I still ran into trouble at inference time:

w = pyro.deterministic("w", pc[u], event_dim=1) 

It was of shape torch.Size([50, 2]), but later during enumeration it turns out to be torch.Size([9, 1, 2]), which caused an IndexError. I also tried

w = pyro.sample("w", MultivariateNormal(pc[u], noise_ub * torch.eye(d)))

which I thought is simply a noisy version of the statement above, with the desired values being the mean of a MultivariateNormal. This alternative seems to work. But I'm confused about the difference: what causes the difference between deterministic and sample, and how should I handle this?

Can you show your code and the full error message?

Sure! (Thanks so much for your patience~) Here’s my model:

import torch
import torch.nn.functional as F
import pyro
from pyro.distributions import Beta, Categorical, Gamma, MultivariateNormal, Normal, constraints
from pyro.infer import config_enumerate
from pyro.ops.indexing import Vindex

# d (data dimension) and N (number of data points) are defined earlier in my notebook
T = 50
T_pc = 9
def mix_weights(beta):
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)

@config_enumerate
def model(data=None, gamma=0.1, alpha=0.1, noise_ub=0.001): 
    alpha_mu = pyro.param("alpha_mu", lambda: Gamma(1, 1).sample([1]), constraint=constraints.positive)
    alpha_w = pyro.param("alpha_w", lambda: Gamma(1, 1).sample([1]), constraint=constraints.positive)
    tau = pyro.param("tau", lambda: Gamma(1, 1).sample([1]), constraint=constraints.positive)

    with pyro.plate("sticks", T-1):
        beta = pyro.sample("beta", Beta(1, gamma))
    with pyro.plate("PC_sticks", T_pc-1):
        beta_ = pyro.sample("beta_", Beta(1, alpha))
    with pyro.plate("PCs", T_pc):
        pc = pyro.sample("pc", Normal(torch.zeros(d), 1/alpha_w.unsqueeze(-1)).to_event(1))

    with pyro.plate("component", T) as idx:
        u = pyro.sample("u", Categorical(mix_weights(beta_)), infer={'enumerate': 'parallel'})
        mu = pyro.sample("mu", Normal(torch.zeros(d), 1/alpha_mu.unsqueeze(-1)).to_event(1))
        w = pyro.deterministic("w", pc[u], event_dim=1)

    with pyro.plate("data", N) as idx:
        z = pyro.sample("z", Categorical(mix_weights(beta)), infer={'enumerate': 'parallel'}) 
        pyro.sample(
            "obs", 
            MultivariateNormal(
                mu[z], 
                precision_matrix=tau * torch.eye(d) + Vindex(w)[z].unsqueeze(-1)*Vindex(w)[z].unsqueeze(-2)),
            obs=data)

It is a CRP mixture model whose covariance is modeled with the PPCA framework; in other words, each component's covariance is approximated by isotropic noise plus a low-rank matrix (here simply rank 1: a single vector with a normal prior). The error is:

     25     pyro.sample(
     26         "obs", 
     27         MultivariateNormal(
     28             mu[z], 
---> 29             precision_matrix=tau * torch.eye(d) + Vindex(w)[z].unsqueeze(-1)*Vindex(w)[z].unsqueeze(-2)),
     30         obs=data)
IndexError: index 9 is out of bounds for dimension 0 with size 9

And I see the shape issue with w mentioned before: its shape changes when running SVI, picking up extra enumeration dimensions.

If you print out the shapes of w and z you will probably find out why you get the IndexError. It looks like there is a lot of math in your code; just be careful to make sure that the shapes are correct and work with enumeration.

Yes, I just don’t understand how enumeration is handled at downstream sites: here MultivariateNormal seems to work well with an enumerated loc parameter, while deterministic seems to simply take the enumerated shape as its batch shape.

What is the shape of w when you use MultivariateNormal (with enumeration)?

deterministic should just return the value of pc[u].
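For example (a standalone sketch, unrelated to your model, just to illustrate that deterministic records the given tensor in the trace and returns it unchanged):

def f():
    x = torch.randn(9, 1, 2)
    return pyro.deterministic("x", x, event_dim=1)

tr = poutine.trace(f).get_trace()
print(tr.nodes["x"]["value"].shape)  # torch.Size([9, 1, 2]), exactly the shape of the input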

Can you clarify this more?

I used print to check the shapes:

  • when I directly print pc[u].shape, I get torch.Size([9, 1, 2])
  • when using MultivariateNormal with pc[u] as its loc parameter, I get a torch.Size([50, 2]) tensor, with the size at dim 0 consistent with the plate size
  • when using deterministic, the shape of the variable is identical to that of pc[u], i.e. torch.Size([9, 1, 2])

Setting event_dim=1 makes the last dimension be recognized as the event_shape, but I don't know how to properly handle the other dimensions.

I think the difference in shape is due to the fact that when you use pyro.sample, the value of w is sampled by the guide (where, I suppose, it has shape torch.Size([50, 2])) and then replayed into the model. The MultivariateNormal(pc[u], ...) distribution in the model is then only used to calculate the log density. If you print out the trace shapes you should see that w has the dist shape of MultivariateNormal(pc[u], ...) (something like (9, 1, 2)), a value shape of (50, 2) (note that this value is sampled by the guide), and a log_prob shape which is the broadcast of the two.
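A rough sketch of how to see this, assuming guide is whatever guide you train with SVI and max_plate_nesting=1 (your plates are not nested), so enumeration dims start at -2:

guide_trace = poutine.trace(guide).get_trace(data)
model_trace = poutine.trace(
    poutine.replay(poutine.enum(model, first_available_dim=-2), trace=guide_trace)
).get_trace(data)
model_trace.compute_log_prob()
print(model_trace.format_shapes())  # dist, value, and log_prob shape for every site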

Oh :hushed: thanks! I wasn’t fully aware of these differences. I think I’m getting an intuitive understanding now.
Just one more question: how can I perform a similar sample-then-replay process with deterministic sites? For now I’m using a noisy version of the expression (which is somewhat okay, at least it runs without errors, but it doesn’t feel very elegant):

w = pyro.sample("w", MultivariateNormal(pc[u], noise_ub * torch.eye(d)))

(I’ve tried this out, and it doesn’t cause much trouble for my model: it can eventually converge to a good posterior. But I thought using deterministic would be more faithful to my original specification of the model.)

Maybe you can expand (repeat) pc[u] along dim=-2 so that it has the shape (9, 50, 2). Then make sure that when you use Vindex(w)[z] you are vindexing along that dim (since it is not the first dim anymore).

You can also just forget that you are using deterministic, since it does not affect inference in any way.