Error in broadcasting during enumeration

Hi,

Assume that I have 45 subjects, each with 68 features. I want to assign each subject to one of K = 3 clusters by creating a categorical variable called assignments whose size equals the number of subjects: 45.

In my code, xt_collection is a tensor of size [45, 68, 3] = [nsubs, nfeatures, K classes], i.e. for each of the 45 subjects with 68 features there are three possible assignments. Without enumeration, xt_collection[np.arange(45),:,assignments] yields an outcome matrix of size [45, 68]: for each subject i = 0,…,44 it picks the 68 elements of the [68, 3] slice xt_collection[i] that sit in column assignments[i]. So in the toy example I would expect the estimate of assignments to equal torch.zeros(45), since my actual observations are identical to column 0 of the xt_collection tensor (all the subjects in this case come from cluster 0).
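To make the expected behaviour concrete, here is a small toy check of that (non-enumerated) indexing; the shapes are the toy shapes above and the values are arbitrary:

import torch

xt_collection = torch.randn(45, 68, 3)            # [nsubs, nfeatures, K]
assignments = torch.zeros(45, dtype=torch.long)   # every subject assigned to cluster 0

# advanced indexing: row i keeps the 68 features xt_collection[i, :, assignments[i]]
xt_keep = xt_collection[torch.arange(45), :, assignments]
assert xt_keep.shape == (45, 68)
assert torch.equal(xt_keep, xt_collection[..., 0])  # all assignments are 0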

Without enumeration, the sampling works fine. However, when I use enumeration, there is a dimension error. One possible solution I came up with is to replace the nested pyro.plate with a nested for loop; however, this leads to very slow inference.

Could you please give me any suggestions?

Thank you so much for your help!

import torch
import os
import pyro
import numpy as np
from pyro.optim import Adam
from pyro.infer.autoguide import AutoNormal,init_to_value
from pyro.ops.indexing import Vindex
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO,TraceGraph_ELBO, config_enumerate, infer_discrete
from pyro.ops.special import safe_log
from pyro.ops.tensor_utils import convolve
from pyro.util import warn_if_nan
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.1')


x0_obs = 0.01* torch.ones([45,68])
x1_obs = torch.ones([45,68])
x2_obs = 10*torch.ones([45,68])

@config_enumerate
def model(xobs):
    K = 3 # number of classes to be assigned
    sigma = pyro.sample("sigma", dist.HalfNormal(1))
    weights = pyro.sample('weights', dist.Dirichlet(torch.ones(K)/K))
    with pyro.plate('nsubs', 45):
        assignments = pyro.sample("assignments", dist.Categorical(weights))
    
    xt_collection = torch.stack([x0_obs,x1_obs,x2_obs],axis = 2) # torch.Size([45, 68, 3]) = [nsubs, nfeatures, class K]
    xt_keep = xt_collection[torch.arange(45),:,assignments]

    with pyro.plate('Nfeatures',68):
        with pyro.plate('Nsubs',45):
            pyro.sample('obs', dist.Normal(xt_keep, sigma), obs=x0_obs)
pyro.set_rng_seed(31)
pyro.clear_param_store()

optim = pyro.optim.Adam({'lr': 0.01})
global_guide = AutoNormal(poutine.block(model, hide=['assignments']))
elbo = TraceEnum_ELBO()
svi = SVI(model, global_guide, optim, loss=elbo)
losses = []

for i in range(1001 if not smoke_test else 2):
    loss = svi.step(x0_obs)
    losses.append(loss)
    if i % 10 == 0:
        print("ELBO at iter i = "+str(i),loss)
The error I got was:

KeyError: -4

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
ValueError: Invalid tensor shape.
  Allowed dims: -3, -2, -1
  Actual shape: (3, 1, 45, 68)
  Try adding shape assertions for your model's sample values and distribution parameters.

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/pyro/ops/packed.py in pack(value, dim_to_symbol)
     42                     ]
     43                 )
---> 44             ) from e
     45         value = value.squeeze()
     46         value._pyro_dims = dims

ValueError: Error while packing tensors at site 'obs':
  Invalid tensor shape.
  Allowed dims: -3, -2, -1
  Actual shape: (3, 1, 45, 68)
  Try adding shape assertions for your model's sample values and distribution parameters.
   Trace Shapes:              
    Param Sites:              
   Sample Sites:              
      sigma dist           |  
           value           |  
        log_prob           |  
    weights dist           | 3
           value           | 3
        log_prob           |  
assignments dist        45 |  
           value   3  1  1 |  
        log_prob   3  1 45 |  
        obs dist 3 1 45 68 |  
           value     45 68 |  
        log_prob 3 1 45 68 |  

Hi @Sky_X, you might try using Vindex as recommended in the tensor shapes tutorial:

- xt_keep = xt_collection[torch.arange(45),:,assignments]
+ xt_keep = Vindex(xt_collection)[..., assignments]

Also I’d recommend doing as much tensor manipulation as possible outside of the model; that will speed up each iteration. E.g. you should torch.stack your obs tensors outside of the model to avoid the per-step cost.
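A minimal sketch of that refactor (assuming the x0_obs, x1_obs, x2_obs tensors defined earlier; not a complete model):

xt_collection = torch.stack([x0_obs, x1_obs, x2_obs], dim=2)  # built once, outside the model

@config_enumerate
def model(xobs):
    ...  # reuse the precomputed xt_collection instead of stacking at every SVI step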

Thank you so much for your quick response, Fritzo!

When I tried

- xt_keep = xt_collection[torch.arange(45),:,assignments]
+ xt_keep = Vindex(xt_collection)[..., assignments]

With enumeration, I got

/usr/local/lib/python3.7/dist-packages/pyro/ops/indexing.py in vindex(tensor, args)
    195     args = tuple(args)
    196 
--> 197     return tensor[args]
    198 
    199 

IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [45, 1], [68], [45]

And without enumeration, the size of xt_keep = xt_collection[..., assignments] is torch.Size([45, 68, 45]), which is not the [45, 68] I expected.
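For reference, that shape is easy to reproduce on dummy values (no enumeration involved):

import torch
xt_collection = torch.randn(45, 68, 3)
assignments = torch.zeros(45, dtype=torch.long)
print(xt_collection[..., assignments].shape)   # torch.Size([45, 68, 45]): the last dim is indexed 45 times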

So I also tried Vindex(xt_collection)[torch.arange(45),:,assignments]. Without enumeration its shape is now correct, [45, 68]. But with enumeration I still got an error:

KeyError: -4

The above exception was the direct cause of the following exception:

ValueError                                Traceback (most recent call last)
ValueError: Invalid tensor shape.
  Allowed dims: -3, -2, -1
  Actual shape: (3, 1, 45, 68)

Could you please take a further look?

Thank you so much!

Hi @Sky_X, hmm, I seem to have gotten your model working in both the non-enumerated and enumerated cases, but it’s not pretty :thinking:

if assignments.dim() == 1:
    assignments = assignments[:, None]
xt_keep = Vindex(xt_collection)[...,assignments]
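A quick shape check of this workaround on dummy data (assuming an enumerated assignments of shape [3, 1, 1], which is what your trace above shows):

import torch
from pyro.ops.indexing import Vindex

xt_collection = torch.randn(45, 68, 3)

def keep(assignments):
    if assignments.dim() == 1:
        assignments = assignments[:, None]
    return Vindex(xt_collection)[..., assignments]

print(keep(torch.zeros(45, dtype=torch.long)).shape)   # torch.Size([45, 68])    without enumeration
print(keep(torch.arange(3).reshape(3, 1, 1)).shape)    # torch.Size([3, 45, 68]) with enumeration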

I hope this unblocks you. Please let me know if you find a cleaner solution :slightly_smiling_face:

That’s amazing! Thank you so much for solving this problem, Fritzo! I’ve been struggling with this for a long time!

May I ask a follow-up question? Suppose I still want to cluster the 45 subjects into K clusters, but now I want to add some constraints on the categorical variable assignments: some specific subjects must be assigned to the same cluster. For instance, subjects 0, 1 should be bound together, subjects 2, 3, 4 should be bound together, and so on for subjects 7, 8… There are altogether 11 bound groups for the 45 subjects. That means that although there are 45 subjects, I’m actually clustering 11 aggregations, since some subjects are bound together.

One way to achieve this is to sample a categorical assignments variable that is a column vector with 11 elements, and also generate a 45 * 11 transform matrix (containing 0s and 1s). Then transform matrix (45, 11) * assignments (11, 1) gives me 45 * 1 “constrained” assignments for the 45 subjects, where subjects 0, 1 have the same label, etc.

Once again, it works without enumeration. However, when it comes to Pyro enumeration, this idea no longer seems applicable: it’s hard to do the matrix * vector calculation because the shape of the assignments vector changes during enumeration. Another problem is that, in order to carry out such a calculation, the integer variable has to be cast to a double data type.

Could you provide some suggestions on how to sample such a “constrained categorical vector” where some known elements must have the same label?

The modified code for my idea above of using a transform matrix was:

x0_obs = 0.01* torch.ones([45,68])
x1_obs = torch.ones([45,68])
x2_obs = 10*torch.ones([45,68])

import numpy as np
agg_indx = 2*[0]+3*[1]+2*[2]+3*[3]+5*[4]+5*[5]+5*[6]+5*[7]+5*[8]+5*[9]+5*[10]
nsubs = 45
naggs = 11
trans_mat = np.zeros((nsubs,naggs))
for row, col in enumerate(agg_indx):
    trans_mat[row,col] = 1
trans_mat = torch.tensor(trans_mat).double()

@config_enumerate
def model(xobs):
    K = 3 # number of classes to be assigned
    sigma = pyro.sample("sigma", dist.HalfNormal(1))
    weights = pyro.sample('weights', dist.Dirichlet(torch.ones(K)/K))
    with pyro.plate('naggregations', 11):
        assignments = pyro.sample("assignments", dist.Categorical(weights))
    assignments  = (trans_mat @ assignments.double()).long()
    xt_collection = torch.stack([x0_obs,x1_obs,x2_obs],axis = 2) # torch.Size([45, 68, 3]) = [nsubs, nfeatures, class K]
    if assignments.dim() == 1:
        assignments = assignments[:, None]
    xt_keep = Vindex(xt_collection)[...,assignments]

    with pyro.plate('Nfeatures',68):
        with pyro.plate('Nsubs',45):
            pyro.sample('obs', dist.Normal(xt_keep, sigma), obs=x0_obs)

The error was due to the shape change during enumeration:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (3x1 and 11x45)

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-18-983d273eedaa> in model(xobs)
     37     with pyro.plate('naggregations', 11):
     38         assignments = pyro.sample("assignments", dist.Categorical(weights))
---> 39     assignments  = (trans_mat @ assignments.double()).long()
     40     xt_collection = torch.stack([x0_obs,x1_obs,x2_obs],axis = 2) # torch.Size([45, 68, 3]) = [nsubs, nfeatures, class K]
     41     if assignments.dim() == 1:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (3x1 and 11x45)
     Trace Shapes:           
      Param Sites:           
     Sample Sites:           
        sigma dist        |  
             value        |  
      weights dist        | 3
             value        | 3
naggregations dist        |  
             value     11 |  
  assignments dist     11 |  
             value 3 1  1 | 

Thank you so much!