Batch computation of log_prob for different conditions

Hello, I have an instance of dist.ConditionalTransformedDistribution that I have conditioned on a batch of conditions, specifically [batch x 32]. For each of these conditions I have a set of points whose log_prob I would like to evaluate, for example [batch x 2000 x 6] (2000 points with 6 features each). Is there a way to do this in one go? The conditioning itself seems to work, but then I get errors. If I condition on only one condition at a time and pass [2000 x 6], it works. If batching isn't possible, would looping over the batch dimension be the best course of action? Thanks in advance.

Hello,

Unfortunately, it's hard to help with shape errors without a code snippet. Please provide one so that forum members are more likely to be able to help you.

Yes, my bad — I just thought I might be missing something really obvious. The error occurs in the computation of log_prob.

# Question's first (abbreviated) training snippet. NOTE(review): dist,
# base_dist, transformations, parameters, Early_stop, patience, tqdm,
# dataloader and Pointnet2 are all defined outside this excerpt.
flow_dist = dist.ConditionalTransformedDistribution(base_dist, transformations)

n_epochs = 3000
early_stop_margin=0.01
optimizer = torch.optim.AdamW(parameters, lr=5e-3) 

early_stop = Early_stop(patience=patience,min_perc_improvement=torch.tensor(early_stop_margin))

for epoch in range(n_epochs):
   for batch in tqdm(dataloader):
        extract_0,enumeration_0,extract_1,enumeration_1 = batch
        # One context vector per sample (reported later in the thread as (3, 32)).
        encodings = Pointnet2(extract_0[:,3:],extract_0[:,:3],enumeration_0)
        # This is the call that raises "The size of tensor a (2000) must match
        # the size of tensor b (3) at non-singleton dimension 1": extract_1 is
        # (3, 2000, 6), so a (3, 32) context does not line up with the
        # 2000-point axis. Conditioning on encodings.unsqueeze(-2), i.e.
        # (3, 1, 32), lets it broadcast and resolves the error.
        loss = -flow_dist.condition(encodings).log_prob(extract_1).mean()

Error:

Exception has occurred: RuntimeError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
The size of tensor a (2000) must match the size of tensor b (3) at non-singleton dimension 1
  File "C:\Users\samme\Anaconda3\envs\flow_change\Lib\site-packages\torch\tensor.py", line 27, in wrapped
return f(*args, **kwargs)
  File "C:\Users\samme\Anaconda3\envs\flow_change\Lib\site-packages\pyro\distributions\transforms\spline.py", line 34, in _searchsorted
values[..., None] >= sorted_sequence,
  File "C:\Users\samme\Anaconda3\envs\flow_change\Lib\site-packages\pyro\distributions\transforms\spline.py", line 146, in _monotonic_rational_spline
bin_idx = _searchsorted(cumheights + eps if inverse else cumwidths + eps, inputs).unsqueeze(-1)
  File "C:\Users\samme\Anaconda3\envs\flow_change\Lib\site-packages\pyro\distributions\transforms\spline.py", line 313, in spline_op
y, log_detJ = _monotonic_rational_spline(x, w, h, d, l, bound=self.bound, **kwargs)
  File "C:\Users\samme\Anaconda3\envs\flow_change\Lib\site-packages\pyro\distributions\transforms\spline.py", line 295, in _inverse
x, log_detJ = self.spline_op(y, inverse=True)
  File "C:\Users\samme\Anaconda3\envs\flow_change\Lib\site-packages\torch\distributions\transforms.py", line 151, in _inv_call
x = self._inverse(y)
  File "C:\Users\samme\Anaconda3\envs\flow_change\Lib\site-packages\torch\distributions\transforms.py", line 219, in __call__
return self._inv._inv_call(x)
  File "C:\Users\samme\Anaconda3\envs\flow_change\Lib\site-packages\torch\distributions\transformed_distribution.py", line 111, in log_prob
x = transform.inv(y)
  File "C:\code\flow_change\conditional_flow_compare.py", line 98, in <module>
loss = -flow_dist.condition(encodings).log_prob(extract_1).mean()
  File "C:\Users\samme\Anaconda3\envs\flow_change\Lib\runpy.py", line 87, in _run_code
exec(code, run_globals)
  File "C:\Users\samme\Anaconda3\envs\flow_change\Lib\runpy.py", line 97, in _run_module_code
_run_code(code, mod_globals, init_globals,
  File "C:\Users\samme\Anaconda3\envs\flow_change\Lib\runpy.py", line 265, in run_path
return _run_module_code(code, init_globals, run_name,
  File "C:\Users\samme\Anaconda3\envs\flow_change\Lib\runpy.py", line 87, in _run_code
exec(code, run_globals)
  File "C:\Users\samme\Anaconda3\envs\flow_change\Lib\runpy.py", line 194, in _run_module_as_main (Current frame)
return _run_code(code, main_globals, None,

If I do this instead, it works:

 # Per-sample workaround: condition on one (context_dim,) encoding at a time
 # and score that sample's (2000, 6) points, summing the per-sample means.
 # NOTE(review): `loss` must already be initialised (e.g. to 0) before this
 # loop; that is not shown in the excerpt. Assuming it starts at 0, the
 # batched form -condition(encodings.unsqueeze(-2)).log_prob(extract_1).mean()
 # averages over samples too, so it equals this total divided by the batch
 # size (every sample has 2000 points here).
 for batch_ind in range(extract_1.shape[0]):
                encoding = encodings[batch_ind,:]
                extract_1_points = extract_1[batch_ind,...]
                loss += -flow_dist.condition(encoding).log_prob(extract_1_points).mean()

From what I can see, I can either batch one point per condition or batch many points for a single condition. I would like to know whether there is a way to do this that is more efficient than the last snippet.

Can you please expand your code snippet to include the important details? In particular:

  • the full constructor for flow_dist
  • the shapes of encodings, extract_1, …
# Number of features per point (extract_1 is reported as (3, 2000, 6)).
input_dim = 6

# Standard-normal base distribution over the input_dim features.
base_dist = dist.Normal(torch.zeros(input_dim).to(device), torch.ones(input_dim).to(device))

# Number of bins in each conditional spline layer.
count_bins = 16

# Width of the conditioning vector; used as Pointnet2's out_dim and as the
# context_dim of every conditional spline.
context_dim= 32

# Passed to Early_stop in the training setup below.
patience = 50

# Number of conditional spline layers (n_layers - 1 permutations go between them).
n_layers = 20

# NOTE(review): this rebinds the name Pointnet2 (presumably the encoder class)
# to the constructed instance, so the class cannot be instantiated again
# afterwards. A distinct name (e.g. `encoder`) would be clearer, but the
# training loop calls this name, so both must change together.
Pointnet2 = Pointnet2(feature_dim=input_dim-3,out_dim=context_dim).to(device)

# n_layers - 1 random feature permutations, one between each pair of splines.
permutations = [torch.randperm(input_dim) for x in range(n_layers-1)]

class conditional_spline_flow:
    """Container for a stack of conditional spline transforms.

    Builds ``len(permutations) + 1`` conditional spline layers
    (``T.conditional_spline``) and inserts a fixed feature permutation between
    each pair of consecutive splines, so the stack starts and ends with a
    spline.

    Attributes:
        transformations: ordered list of transforms, intended for
            ``dist.ConditionalTransformedDistribution(base_dist, transformations)``.
        parameters: flat list of all spline layers' parameters, to pass to an
            optimizer.
    """

    def __init__(self, input_dim, context_dim, permutations, count_bins, device,
                 *, hidden_dims=None, bound=1.0):
        """Create the spline/permutation layers and collect their parameters.

        Args:
            input_dim: number of features per point.
            context_dim: width of the conditioning vector.
            permutations: sequence of index permutations (tensors or lists of
                ints), one inserted after each spline except the last.
            count_bins: number of bins per spline layer.
            device: device the spline modules and permutation indices live on.
            hidden_dims: hidden layer sizes passed to each
                ``T.conditional_spline``; defaults to ``[128, 128]``.
            bound: ``bound`` passed to each ``T.conditional_spline``; defaults
                to ``1.0``.
        """
        if hidden_dims is None:
            hidden_dims = [128, 128]

        self.transformations = []
        self.parameters = []

        n_permutations = len(permutations)
        for layer_idx in range(n_permutations + 1):
            # A fresh list per layer so no two splines share a mutable argument.
            spline = T.conditional_spline(
                input_dim,
                context_dim,
                hidden_dims=list(hidden_dims),
                count_bins=count_bins,
                bound=bound,
            ).to(device)
            self.parameters.extend(spline.parameters())
            self.transformations.append(spline)

            # Permute only between splines; nothing follows the final spline.
            if layer_idx < n_permutations:
                # torch.as_tensor replaces the deprecated torch.LongTensor(...)
                # legacy constructor and accepts either a tensor or a list.
                perm = torch.as_tensor(
                    permutations[layer_idx], dtype=torch.long, device=device
                )
                self.transformations.append(T.permute(input_dim, perm, dim=-1))

    def save(self, path):
        """Pickle this whole object to ``path`` with ``torch.save``.

        Loading it back with ``torch.load`` requires this class to be
        importable under the same name; recent PyTorch versions may also
        need ``weights_only=False`` to unpickle a full object.
        """
        torch.save(self, path)

# Build the flow: conditional splines interleaved with fixed permutations.
conditional_flow_layers = conditional_spline_flow(input_dim,context_dim,permutations,count_bins,device)

# NOTE(review): only the spline layers' parameters are collected here; the
# Pointnet2 encoder's parameters are not passed to the optimizer below. If the
# encoder is meant to be trained jointly with the flow, add them — confirm intent.
parameters = conditional_flow_layers.parameters

transformations = conditional_flow_layers.transformations

flow_dist = dist.ConditionalTransformedDistribution(base_dist, transformations)

n_epochs = 3000

early_stop_margin=0.01

optimizer = torch.optim.AdamW(parameters, lr=5e-3) 

early_stop = Early_stop(patience=patience,min_perc_improvement=torch.tensor(early_stop_margin))

for epoch in range(n_epochs):

   for batch in tqdm(dataloader):

        extract_0,enumeration_0,extract_1,enumeration_1 = batch

        # Reported shapes: extract_0 (6000, 6) -> encodings (3, 32);
        # extract_1 is (3, 2000, 6).
        encodings = Pointnet2(extract_0[:,3:],extract_0[:,:3],enumeration_0)

        # Fails with a size mismatch (2000 vs 3): the (3, 32) context must be
        # unsqueezed to (3, 1, 32) so it broadcasts over the 2000 points, i.e.
        # -flow_dist.condition(encodings.unsqueeze(-2)).log_prob(extract_1).mean().
        # The excerpt ends here; no backward/optimizer step is shown.
        loss = -flow_dist.condition(encodings).log_prob(extract_1).mean()

Sizes:

encodings.shape

torch.Size([3, 32])

extract_1.shape

torch.Size([3, 2000, 6])

extract_0.shape

torch.Size([6000, 6])

That should be everything; let me know if you need anything else.

It's a bit hard to parse all that code, but I'm guessing you just need some unsqueezing to make sure things broadcast correctly. Maybe something like:

# unsqueeze(-2) reshapes encodings from (3, 32) to (3, 1, 32), so each sample's
# context broadcasts across that sample's 2000 points in extract_1 (3, 2000, 6).
# Presumably the result holds one log-prob per point, shape (3, 2000); negate
# and take .mean() for a scalar training loss.
flow_dist.condition(encodings.unsqueeze(-2)).log_prob(extract_1)

That seems to do the trick, and it gives a ~3x performance increase too! Thanks for bearing with me.