Error with Predictive when using JitTrace_ELBO()

Next step in the project I’ve posted about several times is trying to speed up performance. I’m trying to use JitTrace_ELBO() instead of Trace_ELBO() to see if I get significant speed up.

Model:

class BayesianRegression_LogGamma_shape_zeroInf_thetaFunc3(PyroModule):
    @name_count
    def __init__(self, in_features, mu_l, sh_l, th_l, out_features = 1):
        super().__init__()
        
        # parameter names list
        self.parameter_names = []
        
        layers = []
        for i in range(len(mu_l)-1):
            #print(i)
            layers.append(('mu_fc' + str(i), nn.Linear(mu_l[i], mu_l[i+1])))
            if i != (len(mu_l)-2): layers.append(('mu_ReLU' + str(i), nn.ReLU()))
        mu = OrderedDict(layers)
        self.mu = nn.Sequential(mu)
        
        for name, param in self.mu.named_parameters():
            self.parameter_names.append(name)
        
        pyro.nn.module.to_pyro_module_(self.mu)
        for m in self.mu.modules():
            if m._pyro_name == 'mu_fc' + str(len(mu_l)-2):
                for name, value in list(m.named_parameters(recurse=False)):
                    setattr(m, name, PyroSample(prior=dist.Normal(0., 1.).expand(value.shape).to_event(value.dim())))
                
        layers = []
        for i in range(len(sh_l)-1):
            layers.append(('sh_fc' + str(i), nn.Linear(sh_l[i], sh_l[i+1])))
            if i != (len(sh_l)-2): layers.append(('sh_ReLU' + str(i), nn.ReLU()))
        shape = OrderedDict(layers)
        self.shape = nn.Sequential(shape)
        
        for name, param in self.shape.named_parameters():
            self.parameter_names.append(name)
        
        pyro.nn.module.to_pyro_module_(self.shape)
        for m in self.shape.modules():
            if m._pyro_name == 'sh_fc' + str(len(sh_l)-2):
                for name, value in list(m.named_parameters(recurse=False)):
                    setattr(m, name, PyroSample(prior=dist.Laplace(0., 1.).expand(value.shape).to_event(value.dim())))
        
        layers = []
        for i in range(len(th_l)-1):
            layers.append(('th_fc' + str(i), nn.Linear(th_l[i], th_l[i+1])))
            if i != (len(th_l)-2): layers.append(('th_ReLU' + str(i), nn.ReLU()))
        layers.append(('theta_Sigmoid', nn.Sigmoid()))
        theta = OrderedDict(layers)
        self.theta = nn.Sequential(theta)
        
        for name, param in self.theta.named_parameters():
            self.parameter_names.append(name)
        
        pyro.nn.module.to_pyro_module_(self.theta)
        for m in self.theta.modules():
            if m._pyro_name == 'th_fc' + str(len(th_l)-2):
                for name, value in list(m.named_parameters(recurse=False)):
                    setattr(m, name, PyroSample(prior=dist.Laplace(0., 1.).expand(value.shape).to_event(value.dim())))

    def forward(self, x, y=None):
        
        x = x.reshape(-1, 2)
        mu = self.mu(x).squeeze(-1).exp().clamp(min = .000001)
        shape = self.shape(x).squeeze(-1).exp().clamp(min = 0.000001)
        theta = self.theta(x).squeeze(-1)
        
        # will need to add GPU device
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", GammaHurdle(concentration = shape, rate = shape / mu, theta = theta), obs=y)
        return  torch.cat((mu, shape, theta), 0)

GammaHurdle (previously discussed here):

class GammaHurdle(dist.torch_distribution.TorchDistribution):
    
    # theta: probability of a zero
    arg_constraints = {'concentration': torch.distributions.constraints.positive, 
                       'rate': torch.distributions.constraints.positive, 
                       'theta': torch.distributions.constraints.interval(0., 1.)}
    #support = torch.distributions.constraints._Interval(0, float("inf"))
    support = torch.distributions.constraints.greater_than_eq(0.)
    has_rsample = True
    
    def __init__(self, concentration, rate, theta, validate_args=None):
        self.concentration, self.rate, self.theta = broadcast_all(concentration, rate, theta)
        if isinstance(concentration, Number) and isinstance(rate, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.concentration.size()
        super(GammaHurdle, self).__init__(batch_shape, validate_args=validate_args)
            
    def log_prob(self, value):
        value = torch.as_tensor(value, dtype=self.rate.dtype, device=device) #self.rate.device
        #print('log_prob value',(torch.min(value), torch.max(value)))
        if self._validate_args:
            self._validate_sample(value)
        
        safe_value = torch.where(value > 0., value, torch.tensor(1.))
        ret = torch.where(value > 0., 
                           (torch.log(1. - self.theta) + self.concentration * torch.log(self.rate) +
                            (self.concentration - 1.) * torch.log(safe_value) -
                            self.rate * safe_value - torch.lgamma(self.concentration)), 
                            torch.log(self.theta))

        return ret#.clamp(min=-5., max=1)
    
    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        mask = torch.rand(shape) < self.theta.expand(shape)        
        value = torch._standard_gamma(self.concentration.expand(shape)) / self.rate.expand(shape)
        value[mask] = 0.
        #print('rsample', (torch.min(ret), torch.max(ret)))
        value.detach().clamp_(min=torch.finfo(value.dtype).tiny)  # do not record in autograd graph
        return value
    
    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(GammaHurdle, _instance)
        batch_shape = torch.Size(batch_shape)
        new.concentration = self.concentration.expand(batch_shape)
        new.rate = self.rate.expand(batch_shape) 
        new.theta = self.theta.expand(batch_shape)
        super(GammaHurdle, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new

Training works:

model = BayesianRegression_LogGamma_shape_zeroInf_thetaFunc3(2, mu_l = [2, 4, 1], sh_l = [2, 4, 1], th_l = [2, 4, 1], out_features = 1)
guide = AutoDiagonalNormal(model)
adam = pyro.optim.ClippedAdam({"lr": 0.01, 'betas': (.95, .999), 'weight_decay' : .2, 'clip_norm' : 5.})
if USE_JIT:
    guide(torch.zeros([2000, 2]))  # Do any lazy initialization before compiling.
    svi = SVI(model, guide, adam, loss=JitTrace_ELBO())
train_and_evaluate_SVI(svi=svi,  criterion = Trace_ELBO(), model = model, guide = guide, bs = BATCH_SIZE, ne = 10, lr=0.01)

But I run into an error when trying to use Predictive.

rs = ["obs", "_RETURN"]
predictive = Predictive(model=model, guide=guide, num_samples = 500, return_sites = rs)
with converter_val.make_torch_dataloader(batch_size=5e5) as val_dataloader:
  with torch.no_grad():
    val_dataloader_iter = iter(val_dataloader)
    pd_batch = next(val_dataloader_iter)
    pd_batch['features'] = torch.transpose(torch.stack([pd_batch[x] for x in x_feat]), 0, 1)#.double()    
    inputs = pd_batch['features'].to(device)
    labels = pd_batch[y_name].to(device)
    samples = predictive(inputs)

Error:

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
RuntimeError                              Traceback (most recent call last)
/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    173             try:
--> 174                 ret = self.fn(*args, **kwargs)
    175             except (ValueError, RuntimeError) as e:

/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 

/databricks/python/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     27             with self.__class__():
---> 28                 return func(*args, **kwargs)
     29         return cast(F, decorate_context)

/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 

/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used

<command-3087964165514358> in forward(self, x, y)
     63         x = x.reshape(-1, 2).to(device)
---> 64         mu = self.mu(x).squeeze(-1).exp().clamp(min = .000001, max = 1e6)#.to(device)
     65         shape = self.shape(x).squeeze(-1).exp().clamp(min = 0.000001)#.to(device)

/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    140         for module in self:
--> 141             input = module(input)
    142         return input

/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/linear.py in forward(self, input)
    102     def forward(self, input: Tensor) -> Tensor:
--> 103         return F.linear(input, self.weight, self.bias)
    104 

/databricks/python/lib/python3.8/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1847         return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848     return torch._C._nn.linear(input, weight, bias)
   1849 

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<command-3621413792290015> in <module>
     14     inputs = pd_batch['features'].to(device)
     15     labels = pd_batch[y_name].to(device)
---> 16     samples = predictive(inputs)
     17 #     pred_summary = summary(samples)
     18 #     param_summary = parameters_summary_spark(pred_summary, inputs, labels)

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/databricks/python/lib/python3.8/site-packages/pyro/infer/predictive.py in forward(self, *args, **kwargs)
    271                 model_kwargs=kwargs,
    272             )
--> 273         return _predictive(
    274             self.model,
    275             posterior_samples,

/databricks/python/lib/python3.8/site-packages/pyro/infer/predictive.py in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
    125 
    126     if not parallel:
--> 127         return _predictive_sequential(
    128             model,
    129             posterior_samples,

/databricks/python/lib/python3.8/site-packages/pyro/infer/predictive.py in _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes, return_trace)
     46     ]
     47     for i in range(num_samples):
---> 48         trace = poutine.trace(poutine.condition(model, samples[i])).get_trace(
     49             *model_args, **model_kwargs
     50         )

/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    196         Calls this poutine and returns its trace instead of the function's return value.
    197         """
--> 198         self(*args, **kwargs)
    199         return self.msngr.get_trace()

/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    178                 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
    179                 exc = exc.with_traceback(traceback)
--> 180                 raise exc from e
    181             self.msngr.trace.add_node(
    182                 "_RETURN", name="_RETURN", type="return", value=ret

/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    172             )
    173             try:
--> 174                 ret = self.fn(*args, **kwargs)
    175             except (ValueError, RuntimeError) as e:
    176                 exc_type, exc_value, traceback = sys.exc_info()

/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

/databricks/python/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     26         def decorate_context(*args, **kwargs):
     27             with self.__class__():
---> 28                 return func(*args, **kwargs)
     29         return cast(F, decorate_context)
     30 

/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    424     def __call__(self, *args, **kwargs):
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 
    428     def __getattr__(self, name):

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

<command-3087964165514358> in forward(self, x, y)
     62 
     63         x = x.reshape(-1, 2).to(device)
---> 64         mu = self.mu(x).squeeze(-1).exp().clamp(min = .000001, max = 1e6)#.to(device)
     65         shape = self.shape(x).squeeze(-1).exp().clamp(min = 0.000001)#.to(device)
     66         theta = self.theta(x).squeeze(-1)#.to(device)

/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    424     def __call__(self, *args, **kwargs):
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 
    428     def __getattr__(self, name):

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    139     def forward(self, input):
    140         for module in self:
--> 141             input = module(input)
    142         return input
    143 

/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    424     def __call__(self, *args, **kwargs):
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 
    428     def __getattr__(self, name):

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/linear.py in forward(self, input)
    101 
    102     def forward(self, input: Tensor) -> Tensor:
--> 103         return F.linear(input, self.weight, self.bias)
    104 
    105     def extra_repr(self) -> str:

/databricks/python/lib/python3.8/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1846     if has_torch_function_variadic(input, weight, bias):
   1847         return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848     return torch._C._nn.linear(input, weight, bias)
   1849 
   1850 

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
        Trace Shapes:        
         Param Sites:        
     mu.mu_fc0.weight 4 2    
       mu.mu_fc0.bias   4    
        Sample Sites:        
mu.mu_fc1.weight dist   | 1 4
                value 1 | 1 4
  mu.mu_fc1.bias dist   | 1  
                value 1 | 1  

I’m trying to figure out where t() is even being used but I must be missing it.

I can also successfully run model.mu(inputs).squeeze(-1).exp() without error.

Hi @yoshy, I believe torch.nn.Linear is finicky about batch shape. My guess is under the hood there is some jit translation of nn.Linear() to something involving .t() rather than the more robust torch.transpose().

As a workaround you might try replacing torch.nn.Linear with a hand-implementation of a Linear module. I usually avoid nn.Linear because of its batching issues.

Also I’ve had limited luck with JitTrace_ELBO in torch>=1.10 due to an increasingly picky jit compiler. I’ve had better luck speeding up inference by:

  • subsampling data where possible in large models
  • batching multiple particles in small models
  • tuning the learning rate of individual parameters to keep overall learning rate high
  • using an exponential learning rate schedule via ClippedAdam

good luck!

As always, thank you so much for being so helpful. I’ll look into writing my own Linear module then.

I forgot to mention I received some warnings. I need to figure out if jit can even be used in my case.

warnings.warn(
<command-1625998740881369>:20: TracerWarning: torch.as_tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  value = torch.as_tensor(value, dtype=self.rate.dtype, device=device) #self.rate.device
<command-1625998740881369>:25: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
  safe_value = torch.where(value > 0., value, torch.tensor(1.))
  • subsampling - I have been trying this as well but didn’t observe any speed increase, so I assume I have mistake somewhere. I used the following:
def forward(self, x, y=None):
        
        x = x.reshape(-1, 2).to(device)
        mu = self.mu(x).squeeze(-1).exp().clamp(min = .000001, max = 1e6)#.to(device)
        shape = self.shape(x).squeeze(-1).exp().clamp(min = 0.000001)#.to(device)
        theta = self.theta(x).squeeze(-1)#.to(device)
        
        if subsample_svi_N is None:
          with pyro.plate("data", x.shape[0], device = device):
              obs = pyro.sample("obs", GammaHurdle(concentration = shape, rate = shape / mu, theta = theta), obs=y)
        else:
          with pyro.plate("data", x.shape[0], subsample_size = subsample_svi_N, device = device) as ind:
              obs = pyro.sample("obs", GammaHurdle(concentration = shape.index_select(0, ind), 
                                                   rate = shape.index_select(0, ind) / mu.index_select(0, ind), 
                                                   theta = theta.index_select(0, ind)), 
                                obs=y.index_select(0, ind))
        
        return  torch.cat((mu, shape, theta), 0)         

Where subsample_svi_N was a number < x.shape[0] , i.e. less than the batch size.

  • batching multiple particles in small models - Can you point me to the documentation for this?

I’ll check out those other ideas as well. Thanks!

You can use Trace_ELBO with multiple particles:

elbo = Trace_ELBO(num_particles=100, vectorize_particles=True)

This really only helps with small models whose tensors are so small that PyTorch op overhead dominates. By using multiple particles you can saturate your CPU or GPU more easily. This works especially well on GPU.

t() error again when trying Trace_ELBO(num_particles=100, vectorize_particles=True). I think I’ll need to review Tensor shapes in Pyro — Pyro Tutorials 1.8.0 documentation and figure out how to handle shapes properly.

Full error:

t() expects a tensor with <= 2 dimensions, but self is 4D
RuntimeError                              Traceback (most recent call last)
/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    173             try:
--> 174                 ret = self.fn(*args, **kwargs)
    175             except (ValueError, RuntimeError) as e:

/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 

/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 

/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used

<command-3087964165514358> in forward(self, x, y)
     63         x = x.reshape(-1, 2).to(device)
---> 64         mu = self.mu(x).squeeze(-1).exp().clamp(min = .000001, max = 1e6)#.to(device)
     65         shape = self.shape(x).squeeze(-1).exp().clamp(min = 0.000001)#.to(device)

/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    140         for module in self:
--> 141             input = module(input)
    142         return input

/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/linear.py in forward(self, input)
    102     def forward(self, input: Tensor) -> Tensor:
--> 103         return F.linear(input, self.weight, self.bias)
    104 

/databricks/python/lib/python3.8/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1847         return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848     return torch._C._nn.linear(input, weight, bias)
   1849 

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<command-621960980498712> in <module>
     14 #   else:
     15   svi = SVI(model, guide, adam, loss=Trace_ELBO(num_particles=100, vectorize_particles=True))
---> 16   train_and_evaluate_SVI(svi=svi,  criterion = Trace_ELBO(num_particles=100, vectorize_particles=True), model = model, guide = guide, bs = BATCH_SIZE, ne = 10, lr=0.01)
     17   torch.save({"model": model.state_dict(), "guide": guide}, f = "/dbfs/dbfs/tmp/model_15.pyro")
     18   pyro.get_param_store().save('/dbfs/dbfs/tmp/Models/saved_params_15.pyro')

<command-1625998740881370> in train_and_evaluate_SVI(svi, criterion, model, guide, bs, ne, lr)
     19           if (epoch + 1) % 10 == 0: print('-' * 10)
     20 
---> 21           train_loss = train_one_epoch_SVI(svi, train_dataloader_iter, steps_per_epoch, epoch, device)
     22           #evaluate(self, guide, model, criterion, val_dataloader_iter, validation_steps, device, metric_agg_fn=None):
     23           if (epoch + 1) % 10 == 0: val_loss = evaluate(guide, model, criterion, val_dataloader_iter, validation_steps, device) #val_loss =

<command-1625998740881371> in train_one_epoch_SVI(svi, train_dataloader_iter, steps_per_epoch, epoch, device, metric_agg_fn)
     73       labels = pd_batch[y_name].to(device)#.double()# + .01
     74       #labels = labels.to(device)
---> 75       loss = svi.step(inputs, labels)
     76       #loss = 1.0
     77 

/databricks/python/lib/python3.8/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
    143         # get loss and compute gradients
    144         with poutine.trace(param_only=True) as param_capture:
--> 145             loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    146 
    147         params = set(

/databricks/python/lib/python3.8/site-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
    138         loss = 0.0
    139         # grab a trace from the generator
--> 140         for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
    141             loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
    142                 model_trace, guide_trace

/databricks/python/lib/python3.8/site-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, args, kwargs)
    177             if self.max_plate_nesting == float("inf"):
    178                 self._guess_max_plate_nesting(model, guide, args, kwargs)
--> 179             yield self._get_vectorized_trace(model, guide, args, kwargs)
    180         else:
    181             for i in range(self.num_particles):

/databricks/python/lib/python3.8/site-packages/pyro/infer/elbo.py in _get_vectorized_trace(self, model, guide, args, kwargs)
    154         and guide.
    155         """
--> 156         return self._get_trace(
    157             self._vectorized_num_particles(model),
    158             self._vectorized_num_particles(guide),

/databricks/python/lib/python3.8/site-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, args, kwargs)
     55         against it.
     56         """
---> 57         model_trace, guide_trace = get_importance_trace(
     58             "flat", self.max_plate_nesting, model, guide, args, kwargs
     59         )

/databricks/python/lib/python3.8/site-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     63         if detach:
     64             guide_trace.detach_()
---> 65         model_trace = poutine.trace(
     66             poutine.replay(model, trace=guide_trace), graph_type=graph_type
     67         ).get_trace(*args, **kwargs)

/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    196         Calls this poutine and returns its trace instead of the function's return value.
    197         """
--> 198         self(*args, **kwargs)
    199         return self.msngr.get_trace()

/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    178                 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
    179                 exc = exc.with_traceback(traceback)
--> 180                 raise exc from e
    181             self.msngr.trace.add_node(
    182                 "_RETURN", name="_RETURN", type="return", value=ret

/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    172             )
    173             try:
--> 174                 ret = self.fn(*args, **kwargs)
    175             except (ValueError, RuntimeError) as e:
    176                 exc_type, exc_value, traceback = sys.exc_info()

/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)
     13 
     14 

/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    424     def __call__(self, *args, **kwargs):
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 
    428     def __getattr__(self, name):

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

<command-3087964165514358> in forward(self, x, y)
     62 
     63         x = x.reshape(-1, 2).to(device)
---> 64         mu = self.mu(x).squeeze(-1).exp().clamp(min = .000001, max = 1e6)#.to(device)
     65         shape = self.shape(x).squeeze(-1).exp().clamp(min = 0.000001)#.to(device)
     66         theta = self.theta(x).squeeze(-1)#.to(device)

/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    424     def __call__(self, *args, **kwargs):
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 
    428     def __getattr__(self, name):

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
    139     def forward(self, input):
    140         for module in self:
--> 141             input = module(input)
    142         return input
    143 

/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
    424     def __call__(self, *args, **kwargs):
    425         with self._pyro_context:
--> 426             return super().__call__(*args, **kwargs)
    427 
    428     def __getattr__(self, name):

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1100         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1101                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102             return forward_call(*input, **kwargs)
   1103         # Do not call functions when jit is used
   1104         full_backward_hooks, non_full_backward_hooks = [], []

/databricks/python/lib/python3.8/site-packages/torch/nn/modules/linear.py in forward(self, input)
    101 
    102     def forward(self, input: Tensor) -> Tensor:
--> 103         return F.linear(input, self.weight, self.bias)
    104 
    105     def extra_repr(self) -> str:

/databricks/python/lib/python3.8/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1846     if has_torch_function_variadic(input, weight, bias):
   1847         return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848     return torch._C._nn.linear(input, weight, bias)
   1849 
   1850 

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 4D
        Trace Shapes:            
         Param Sites:            
     mu.mu_fc0.weight     4 2    
       mu.mu_fc0.bias       4    
        Sample Sites:            
mu.mu_fc1.weight dist 100 1 | 1 4
                value 100 1 | 1 4
  mu.mu_fc1.bias dist 100 1 | 1  
                value 100 1 | 1  

Again I suspect this is a shortcoming of torch.nn.Linear not properly supporting batching, and could be worked around with a hand implementation.

That is top on my (very long) list of things to learn and try :slight_smile:

Following up on this, I’ve written my implementation of nn.Linear and I’ve added some print statements to help me understand the inner workings of SVI.

class myLinear(nn.Module):
  def __init__(self, in_features, out_features, bias=True):
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features
    self.bias = bias
    self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))
    self.bias = torch.nn.Parameter(torch.randn(out_features))
       
  def forward(self, input):
    print('myLin weight', self.weight.shape, self.weight)
    print('myLin input', input.shape)
    x, y = input.shape
    if y != self.in_features:
        sys.exit(f'Wrong Input Features. Please use tensor with {self.in_features} Input Features')
    output = input @ self.weight.t() + self.bias
    return output

For simplicity, I’ll cut out the data pertaining to shape and theta in my model since they have the same structure as mu.

My input size is [2000, 2] (batch size = 2000), there is 1 hidden layer with 4 neurons, and final output is 1 value. The connection between the hidden layer and output is Bayesian. The connections from input to hidden are not Bayesian. Here’s what I’m seeing, with my comments added in with #.

# mu layers
myLin weight torch.Size([4, 2]) Parameter containing:
tensor([[ 1.7847, -1.0231],
        [ 0.3413,  0.6574],
        [ 0.7365,  0.0026],
        [-1.3468,  1.9001]], requires_grad=True)
myLin input torch.Size([2000, 2])
myLin weight torch.Size([1, 4]) tensor([[-2.1506, -1.0490,  1.1475, -0.6664]])
myLin input torch.Size([2000, 4])
# ... similar print outs for shape and theta

# now the guide? input again goes through mu layers
myLin weight torch.Size([4, 2]) Parameter containing:
tensor([[ 1.7847, -1.0231],
        [ 0.3413,  0.6574],
        [ 0.7365,  0.0026],
        [-1.3468,  1.9001]], requires_grad=True)
myLin input torch.Size([2000, 2])
myLin weight torch.Size([1, 4]) tensor([[-2.2078, -1.0301,  1.1036, -0.6417]], grad_fn=<ExpandBackward0>)
# again shape and theta have similar prints

These print values make sense to me. The input is sent through the model and the guide, so seeing the print outs for mu twice makes sense. The [4, 2] weight tensors are the same because that connection is not Bayesian. The [1, 4] weight tensors are different because that connection is Bayesian. However, there is a third round of print outs for mu and the other parameters. Again, first weights tensor is the same as before. Now though, the Bayesian weights are [100, 1, 1, 4]. 100 comes from Trace_ELBO(num_particles=100, vectorize_particles=True).

# mu again (this time num_particles affects the Bayesian pyro layer)
myLin input torch.Size([2000, 4])
myLin weight torch.Size([4, 2]) Parameter containing:
tensor([[ 1.7847, -1.0231],
        [ 0.3413,  0.6574],
        [ 0.7365,  0.0026],
        [-1.3468,  1.9001]], requires_grad=True)
myLin input torch.Size([2000, 2])
myLin weight torch.Size([100, 1, 1, 4]) 
tensor([[[[-2.0552, -1.0621,  1.0664, -0.8159]]],
        [[[-1.9348, -1.0332,  1.0054, -0.6918]]],
... omitted for brevity
        [[[-2.0671, -1.0431,  1.1718, -0.8851]]],
        [[[-2.0447, -1.0798,  1.2346, -0.7361]]]], grad_fn=<ExpandBackward0>)
myLin input torch.Size([2000, 4])

First actual question: why are there 3 calls to mu? To calculate ELBO, I expected the 100 particles to be used in the second calculation for the guide.

First attempt at a fix. I updated myLinear class to squeeze() weights to remove the 1D dimensions in [100, 1, 1, 4], so it is instead [100, 4].

class myLinear(nn.Module):
  ...     
  def forward(self, input):
    print('myLin weight', self.weight.shape, self.weight)
    print('myLin input', input.shape)
    x, y = input.shape
    if y != self.in_features:
        sys.exit(f'Wrong Input Features. Please use tensor with {self.in_features} Input Features')
    output = input @ self.weight.squeeze().t() + self.bias
    return output

This allows things to run until I get the error:

Shape mismatch inside plate('data') at site obs dim -1, 2000 vs 100

ValueError: Shape mismatch inside plate('data') at site obs dim -1, 2000 vs 100
           Trace Shapes:               
            Param Sites:               
        mu.mu_fc0.weight        4 2    
          mu.mu_fc0.bias          4    
     shape.sh_fc0.weight        4 2    
       shape.sh_fc0.bias          4    
     theta.th_fc0.weight        4 2    
       theta.th_fc0.bias          4    
           Sample Sites:               
   mu.mu_fc1.weight dist 100    1 | 1 4
                   value 100    1 | 1 4
     mu.mu_fc1.bias dist 100    1 | 1  
                   value 100    1 | 1  
shape.sh_fc1.weight dist 100    1 | 1 4
                   value 100    1 | 1 4
  shape.sh_fc1.bias dist 100    1 | 1  
                   value 100    1 | 1  
theta.th_fc1.weight dist 100    1 | 1 4
                   value 100    1 | 1 4
  theta.th_fc1.bias dist 100    1 | 1  
                   value 100    1 | 1  
               data dist          |    
                   value     2000 |    

Updating my forward method for BayesianRegression_LogGamma_shape_zeroInf_thetaFunc3, I get a little farther, but a new error. I added in .to_event(1) to the obs sample.

def forward(self, x, y=None):
        
        x = x.reshape(-1, 2).to(device)
        mu = self.mu(x).squeeze(-1).exp().clamp(min = .000001, max = 1e6)#.to(device)
        shape = self.shape(x).squeeze(-1).exp().clamp(min = 0.000001)#.to(device)
        theta = self.theta(x).squeeze(-1)#.to(device)
        
        if subsample_svi_N is None:
          with pyro.plate("data", x.shape[0], device = device):
              obs = pyro.sample("obs", GammaHurdle(concentration = shape, rate = shape / mu, theta = theta).to_event(1), obs=y)
        else:
          with pyro.plate("data", x.shape[0], subsample_size = subsample_svi_N, device = device) as ind:
              obs = pyro.sample("obs", GammaHurdle(concentration = shape.index_select(0, ind), 
                                                   rate = shape.index_select(0, ind) / mu.index_select(0, ind), 
                                                   theta = theta.index_select(0, ind)), 
                                obs=y.index_select(0, ind))
        
        return  torch.cat((mu, shape, theta), 0)         

The new error is shown below. I’m definitely messing up shapes somewhere, but I’m not clear where.

ValueError: Error while computing log_prob at site 'obs':
Value is not broadcastable with batch_shape+event_shape: torch.Size([2000]) vs torch.Size([100, 2000, 100]).
           Trace Shapes:                  
            Param Sites:                  
        mu.mu_fc0.weight        4 2       
          mu.mu_fc0.bias          4       
     shape.sh_fc0.weight        4 2       
       shape.sh_fc0.bias          4       
     theta.th_fc0.weight        4 2       
       theta.th_fc0.bias          4       
           Sample Sites:                  
   mu.mu_fc1.weight dist 100    1 |    1 4
                   value 100    1 |    1 4
                log_prob 100    1 |       
     mu.mu_fc1.bias dist 100    1 |    1  
                   value 100    1 |    1  
                log_prob 100    1 |       
shape.sh_fc1.weight dist 100    1 |    1 4
                   value 100    1 |    1 4
                log_prob 100    1 |       
  shape.sh_fc1.bias dist 100    1 |    1  
                   value 100    1 |    1  
                log_prob 100    1 |       
theta.th_fc1.weight dist 100    1 |    1 4
                   value 100    1 |    1 4
                log_prob 100    1 |       
  theta.th_fc1.bias dist 100    1 |    1  
                   value 100    1 |    1  
                log_prob 100    1 |       
                obs dist 100 2000 |  100  
                   value          | 2000