The next step in the project I’ve posted about several times is to improve performance. I’m trying to use JitTrace_ELBO() in place of Trace_ELBO() to see whether I get a significant speedup.
Model:
import torch
import torch.nn as nn
from collections import OrderedDict

import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample


class BayesianRegression_LogGamma_shape_zeroInf_thetaFunc3(PyroModule):
    @name_count  # helper decorator defined elsewhere in my code
    def __init__(self, in_features, mu_l, sh_l, th_l, out_features=1):
        super().__init__()
        # parameter names list
        self.parameter_names = []

        # mu head: MLP with ReLUs between hidden layers; Normal prior on the last layer
        layers = []
        for i in range(len(mu_l) - 1):
            layers.append(('mu_fc' + str(i), nn.Linear(mu_l[i], mu_l[i + 1])))
            if i != (len(mu_l) - 2):
                layers.append(('mu_ReLU' + str(i), nn.ReLU()))
        mu = OrderedDict(layers)
        self.mu = nn.Sequential(mu)
        for name, param in self.mu.named_parameters():
            self.parameter_names.append(name)
        pyro.nn.module.to_pyro_module_(self.mu)
        for m in self.mu.modules():
            if m._pyro_name == 'mu_fc' + str(len(mu_l) - 2):
                for name, value in list(m.named_parameters(recurse=False)):
                    setattr(m, name, PyroSample(prior=dist.Normal(0., 1.).expand(value.shape).to_event(value.dim())))

        # shape head: same structure; Laplace prior on the last layer
        layers = []
        for i in range(len(sh_l) - 1):
            layers.append(('sh_fc' + str(i), nn.Linear(sh_l[i], sh_l[i + 1])))
            if i != (len(sh_l) - 2):
                layers.append(('sh_ReLU' + str(i), nn.ReLU()))
        shape = OrderedDict(layers)
        self.shape = nn.Sequential(shape)
        for name, param in self.shape.named_parameters():
            self.parameter_names.append(name)
        pyro.nn.module.to_pyro_module_(self.shape)
        for m in self.shape.modules():
            if m._pyro_name == 'sh_fc' + str(len(sh_l) - 2):
                for name, value in list(m.named_parameters(recurse=False)):
                    setattr(m, name, PyroSample(prior=dist.Laplace(0., 1.).expand(value.shape).to_event(value.dim())))

        # theta head: same structure plus a final Sigmoid; Laplace prior on the last linear layer
        layers = []
        for i in range(len(th_l) - 1):
            layers.append(('th_fc' + str(i), nn.Linear(th_l[i], th_l[i + 1])))
            if i != (len(th_l) - 2):
                layers.append(('th_ReLU' + str(i), nn.ReLU()))
        layers.append(('theta_Sigmoid', nn.Sigmoid()))
        theta = OrderedDict(layers)
        self.theta = nn.Sequential(theta)
        for name, param in self.theta.named_parameters():
            self.parameter_names.append(name)
        pyro.nn.module.to_pyro_module_(self.theta)
        for m in self.theta.modules():
            if m._pyro_name == 'th_fc' + str(len(th_l) - 2):
                for name, value in list(m.named_parameters(recurse=False)):
                    setattr(m, name, PyroSample(prior=dist.Laplace(0., 1.).expand(value.shape).to_event(value.dim())))

    def forward(self, x, y=None):
        x = x.reshape(-1, 2)
        mu = self.mu(x).squeeze(-1).exp().clamp(min=1e-6)
        shape = self.shape(x).squeeze(-1).exp().clamp(min=1e-6)
        theta = self.theta(x).squeeze(-1)
        # will need to add GPU device
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", GammaHurdle(concentration=shape, rate=shape / mu, theta=theta), obs=y)
        return torch.cat((mu, shape, theta), 0)
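For reference, the full forward pass works on its own (training works too, see below). A quick smoke test on dummy data, with the same layer sizes I use in the training call later, looks like this (it draws prior samples for the last layers and needs the GammaHurdle class defined next):

m = BayesianRegression_LogGamma_shape_zeroInf_thetaFunc3(2, mu_l=[2, 4, 1], sh_l=[2, 4, 1], th_l=[2, 4, 1])
out = m(torch.zeros(16, 2))
print(out.shape)  # torch.Size([48]): mu, shape, theta concatenated, 16 entries each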
GammaHurdle (previously discussed here):
from numbers import Number

from torch.distributions.utils import broadcast_all


class GammaHurdle(dist.torch_distribution.TorchDistribution):
    # theta: probability of a zero
    arg_constraints = {'concentration': torch.distributions.constraints.positive,
                       'rate': torch.distributions.constraints.positive,
                       'theta': torch.distributions.constraints.interval(0., 1.)}
    support = torch.distributions.constraints.greater_than_eq(0.)
    has_rsample = True

    def __init__(self, concentration, rate, theta, validate_args=None):
        self.concentration, self.rate, self.theta = broadcast_all(concentration, rate, theta)
        if isinstance(concentration, Number) and isinstance(rate, Number):
            batch_shape = torch.Size()
        else:
            batch_shape = self.concentration.size()
        super().__init__(batch_shape, validate_args=validate_args)

    def log_prob(self, value):
        value = torch.as_tensor(value, dtype=self.rate.dtype, device=device)  # `device` is a notebook global (self.rate.device may be safer)
        if self._validate_args:
            self._validate_sample(value)
        # score zeros under theta and positive values under (1 - theta) * Gamma
        safe_value = torch.where(value > 0., value, torch.tensor(1.))
        return torch.where(value > 0.,
                           (torch.log(1. - self.theta) + self.concentration * torch.log(self.rate) +
                            (self.concentration - 1.) * torch.log(safe_value) -
                            self.rate * safe_value - torch.lgamma(self.concentration)),
                           torch.log(self.theta))

    def rsample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        mask = torch.rand(shape) < self.theta.expand(shape)  # which entries are zeros
        value = torch._standard_gamma(self.concentration.expand(shape)) / self.rate.expand(shape)
        value[mask] = 0.
        value.detach().clamp_(min=torch.finfo(value.dtype).tiny)  # do not record in autograd graph
        return value

    def expand(self, batch_shape, _instance=None):
        new = self._get_checked_instance(GammaHurdle, _instance)
        batch_shape = torch.Size(batch_shape)
        new.concentration = self.concentration.expand(batch_shape)
        new.rate = self.rate.expand(batch_shape)
        new.theta = self.theta.expand(batch_shape)
        super(GammaHurdle, new).__init__(batch_shape, validate_args=False)
        new._validate_args = self._validate_args
        return new
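As a standalone check, something like this exercises rsample and log_prob directly (arbitrary parameter values; note that log_prob reads the notebook-level device global):

device = torch.device("cpu")  # log_prob reads this notebook-level global

d = GammaHurdle(concentration=torch.tensor([2., 2.]),
                rate=torch.tensor([1., 1.]),
                theta=torch.tensor([0.3, 0.3]))
s = d.rsample(torch.Size([5]))       # (5, 2); ~30% of entries clamped to tiny in place of zero
print(s.shape, d.log_prob(s).shape)  # torch.Size([5, 2]) torch.Size([5, 2])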
Training works:
from pyro.infer import SVI, Trace_ELBO, JitTrace_ELBO, Predictive
from pyro.infer.autoguide import AutoDiagonalNormal

model = BayesianRegression_LogGamma_shape_zeroInf_thetaFunc3(2, mu_l=[2, 4, 1], sh_l=[2, 4, 1], th_l=[2, 4, 1], out_features=1)
guide = AutoDiagonalNormal(model)
adam = pyro.optim.ClippedAdam({"lr": 0.01, 'betas': (.95, .999), 'weight_decay': .2, 'clip_norm': 5.})
if USE_JIT:
    guide(torch.zeros([2000, 2]))  # do any lazy initialization before compiling
    svi = SVI(model, guide, adam, loss=JitTrace_ELBO())
train_and_evaluate_SVI(svi=svi, criterion=Trace_ELBO(), model=model, guide=guide, bs=BATCH_SIZE, ne=10, lr=0.01)
But I run into an error when trying to use Predictive.
rs = ["obs", "_RETURN"]
predictive = Predictive(model=model, guide=guide, num_samples=500, return_sites=rs)
with converter_val.make_torch_dataloader(batch_size=5e5) as val_dataloader:
    with torch.no_grad():
        val_dataloader_iter = iter(val_dataloader)
        pd_batch = next(val_dataloader_iter)
        pd_batch['features'] = torch.transpose(torch.stack([pd_batch[x] for x in x_feat]), 0, 1)
        inputs = pd_batch['features'].to(device)
        labels = pd_batch[y_name].to(device)
        samples = predictive(inputs)
Error:
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
RuntimeError Traceback (most recent call last)
/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12 return fn(*args, **kwargs)
13
/databricks/python/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
27 with self.__class__():
---> 28 return func(*args, **kwargs)
29 return cast(F, decorate_context)
/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
11 with context:
---> 12 return fn(*args, **kwargs)
13
/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
425 with self._pyro_context:
--> 426 return super().__call__(*args, **kwargs)
427
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
<command-3087964165514358> in forward(self, x, y)
63 x = x.reshape(-1, 2).to(device)
---> 64 mu = self.mu(x).squeeze(-1).exp().clamp(min = .000001, max = 1e6)#.to(device)
65 shape = self.shape(x).squeeze(-1).exp().clamp(min = 0.000001)#.to(device)
/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
425 with self._pyro_context:
--> 426 return super().__call__(*args, **kwargs)
427
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
140 for module in self:
--> 141 input = module(input)
142 return input
/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
425 with self._pyro_context:
--> 426 return super().__call__(*args, **kwargs)
427
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/linear.py in forward(self, input)
102 def forward(self, input: Tensor) -> Tensor:
--> 103 return F.linear(input, self.weight, self.bias)
104
/databricks/python/lib/python3.8/site-packages/torch/nn/functional.py in linear(input, weight, bias)
1847 return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848 return torch._C._nn.linear(input, weight, bias)
1849
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
<command-3621413792290015> in <module>
14 inputs = pd_batch['features'].to(device)
15 labels = pd_batch[y_name].to(device)
---> 16 samples = predictive(inputs)
17 # pred_summary = summary(samples)
18 # param_summary = parameters_summary_spark(pred_summary, inputs, labels)
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
/databricks/python/lib/python3.8/site-packages/pyro/infer/predictive.py in forward(self, *args, **kwargs)
271 model_kwargs=kwargs,
272 )
--> 273 return _predictive(
274 self.model,
275 posterior_samples,
/databricks/python/lib/python3.8/site-packages/pyro/infer/predictive.py in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
125
126 if not parallel:
--> 127 return _predictive_sequential(
128 model,
129 posterior_samples,
/databricks/python/lib/python3.8/site-packages/pyro/infer/predictive.py in _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes, return_trace)
46 ]
47 for i in range(num_samples):
---> 48 trace = poutine.trace(poutine.condition(model, samples[i])).get_trace(
49 *model_args, **model_kwargs
50 )
/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
196 Calls this poutine and returns its trace instead of the function's return value.
197 """
--> 198 self(*args, **kwargs)
199 return self.msngr.get_trace()
/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
178 exc = exc_type(u"{}\n{}".format(exc_value, shapes))
179 exc = exc.with_traceback(traceback)
--> 180 raise exc from e
181 self.msngr.trace.add_node(
182 "_RETURN", name="_RETURN", type="return", value=ret
/databricks/python/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
172 )
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
176 exc_type, exc_value, traceback = sys.exc_info()
/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
/databricks/python/lib/python3.8/site-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
26 def decorate_context(*args, **kwargs):
27 with self.__class__():
---> 28 return func(*args, **kwargs)
29 return cast(F, decorate_context)
30
/databricks/python/lib/python3.8/site-packages/pyro/poutine/messenger.py in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
13
14
/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
424 def __call__(self, *args, **kwargs):
425 with self._pyro_context:
--> 426 return super().__call__(*args, **kwargs)
427
428 def __getattr__(self, name):
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
<command-3087964165514358> in forward(self, x, y)
62
63 x = x.reshape(-1, 2).to(device)
---> 64 mu = self.mu(x).squeeze(-1).exp().clamp(min = .000001, max = 1e6)#.to(device)
65 shape = self.shape(x).squeeze(-1).exp().clamp(min = 0.000001)#.to(device)
66 theta = self.theta(x).squeeze(-1)#.to(device)
/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
424 def __call__(self, *args, **kwargs):
425 with self._pyro_context:
--> 426 return super().__call__(*args, **kwargs)
427
428 def __getattr__(self, name):
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/container.py in forward(self, input)
139 def forward(self, input):
140 for module in self:
--> 141 input = module(input)
142 return input
143
/databricks/python/lib/python3.8/site-packages/pyro/nn/module.py in __call__(self, *args, **kwargs)
424 def __call__(self, *args, **kwargs):
425 with self._pyro_context:
--> 426 return super().__call__(*args, **kwargs)
427
428 def __getattr__(self, name):
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1100 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1101 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1102 return forward_call(*input, **kwargs)
1103 # Do not call functions when jit is used
1104 full_backward_hooks, non_full_backward_hooks = [], []
/databricks/python/lib/python3.8/site-packages/torch/nn/modules/linear.py in forward(self, input)
101
102 def forward(self, input: Tensor) -> Tensor:
--> 103 return F.linear(input, self.weight, self.bias)
104
105 def extra_repr(self) -> str:
/databricks/python/lib/python3.8/site-packages/torch/nn/functional.py in linear(input, weight, bias)
1846 if has_torch_function_variadic(input, weight, bias):
1847 return handle_torch_function(linear, (input, weight, bias), input, weight, bias=bias)
-> 1848 return torch._C._nn.linear(input, weight, bias)
1849
1850
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
Trace Shapes:
Param Sites:
mu.mu_fc0.weight 4 2
mu.mu_fc0.bias 4
Sample Sites:
mu.mu_fc1.weight dist | 1 4
value 1 | 1 4
mu.mu_fc1.bias dist | 1
value 1 | 1
I’m trying to figure out where t() is even being called, but I must be missing it. I can also successfully run model.mu(inputs).squeeze(-1).exp() without error.
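The only transpose I can find in the traceback is inside F.linear itself, which calls weight.t() on the (normally 2-D) weight. The Trace Shapes above suggest the conditioned weight for mu_fc1 arrives with an extra leading sample dimension, i.e. (1, 1, 4) instead of (1, 4), which would reproduce the error. A minimal sketch of that hypothesis:

import torch
import torch.nn.functional as F

x = torch.zeros(8, 4)
w2 = torch.randn(1, 4)  # an ordinary nn.Linear weight: fine
w3 = w2.unsqueeze(0)    # the same weight with a leading sample dim, as in the Trace Shapes
F.linear(x, w2)         # works
try:
    F.linear(x, w3)
except RuntimeError as e:
    print(e)  # t() expects a tensor with <= 2 dimensions, but self is 3D

If that’s right, the posterior samples that Predictive conditions the model on are picking up an extra batch dimension that nn.Linear can’t consume, though I don’t see why this bites here and not during training.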