I am trying to add subsampling to a Pyro model that also uses enumeration, but I run into an error about differing scales that I do not know how to resolve.
My model has two plates (plate1 and plate2). I am trying to enumerate over a variable inside plate2 while subsampling over plate1. When I call svi.step, I get the following error:
ValueError: Expected all enumerated sample sites to share a common poutine.scale, but found 2 different scales.
Below is a toy example that reproduces the error in full.
When I examine the Pyro source code, the problem seems to lie in the _get_common_scale function in traceenum_elbo.py. I think two scales end up being present because Pyro reweights the log-likelihood to account for the subsample size: one scale equals 1.0, while the other equals the ratio of the full plate size to the subsample size (1000 / 200 = 5.0 here).
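To check this, one can trace the model once and print the poutine scale attached to each sample site (a quick sketch of mine, using the toy model and random_data defined below). Sites inside the subsampled plate1 report 5.0, while the enumerated site periodic reports 1.0:

# Diagnostic sketch; assumes model_enum_minibatch and random_data from the example below.
tr = poutine.trace(model_enum_minibatch).get_trace(
    random_data, ind=torch.arange(0, 200)
)
for name, site in tr.nodes.items():
    if site["type"] == "sample":
        print(name, site["scale"])  # e.g. "periodic 1.0" vs. "ϕxy 5.0"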
The closest forum post I could find was the following; however, its suggested solution (use TraceEnum_ELBO) does not apply here, as I am already using TraceEnum_ELBO.
Any suggestions on how to resolve this error and get the model to work would be greatly appreciated!
The code to reproduce the error is below (the torch_fourier_basis function might look confusing, but I do not believe it is the source of the bug):
import collections

import numpy as np
import torch

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.distributions import constraints
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
def torch_fourier_basis(phase, num_harmonics):
    # Harmonic index of each basis column: [0, 1, 1, 2, 2, ...].
    idx_harm = torch.concat(
        [torch.tensor([0.0]), torch.repeat_interleave(torch.arange(1, 1 + num_harmonics), 2)]
    )
    # Masks selecting the sin columns and the constant (intercept) column.
    sin_cos_bool = torch.tensor([False] + [False, True] * num_harmonics)
    base_bool = torch.tensor([True] + [False] * (num_harmonics * 2))
    return torch.where(
        base_bool,
        torch.tensor(1.0, dtype=torch.float32),
        torch.where(
            sin_cos_bool,
            torch.sin(idx_harm * phase.unsqueeze(-1)),
            torch.cos(idx_harm * phase.unsqueeze(-1)),
        ),
    )
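(For reference, a quick illustration of what this returns: for num_harmonics=1 the basis columns are [1, cos(phase), sin(phase)], with one row per entry of phase. demo_phase is just an illustrative name of mine.)

# Illustration only: one row per phase entry, columns [1, cos, sin].
demo_phase = torch.tensor([0.0, np.pi / 2])
print(torch_fourier_basis(demo_phase, num_harmonics=1).shape)  # torch.Size([2, 3])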
@config_enumerate
def model_enum_minibatch(data, ind=torch.arange(0, 200)):
    # Plates initialization: subsample over plate1, enumerate inside plate2.
    plate1 = pyro.plate("plate1", 1000, subsample=ind, dim=-1)
    plate2 = pyro.plate("plate2", 100, dim=-2)
    with plate2:
        x = pyro.sample("x", dist.Normal(0.5, 1.)).unsqueeze(-1)
        p = pyro.sample("p", dist.Beta(0.05, 0.95))
        # The enumerated site; note it lives outside the subsampled plate1.
        periodic = pyro.sample("periodic", dist.Bernoulli(p), infer={"enumerate": "parallel"})
    with plate1 as ind:
        phi = pyro.sample("ϕxy", dist.Normal(0., 2 * np.pi))
    y = torch_fourier_basis(phi.squeeze(), num_harmonics=1)
    ElogS = torch.where(periodic == 1, (x * y).sum(-1), (x[:, :, 0] * y[:, 0]))
    with plate1, plate2:
        pyro.sample("data", dist.GammaPoisson(0.5, 1.0 / (0.5 * torch.exp(ElogS))), obs=data[:, ind])
def clipped_sigmoid(x):
    # Sigmoid clamped away from exactly 0 and 1 for numerical stability.
    finfo = torch.finfo(x.dtype)
    y = x / 1.
    z = torch.clamp(y, min=finfo.min + 10 * finfo.eps, max=finfo.max - 10 * finfo.eps)
    return torch.clamp(torch.sigmoid(z), min=finfo.eps, max=1. - finfo.eps)
@config_enumerate
def guide_enum_minibatch(data, ind=torch.arange(0, 200)):
    # Plates initialization, mirroring the model.
    plate1 = pyro.plate("plate1", 1000, subsample=ind, dim=-1)
    plate2 = pyro.plate("plate2", 100, dim=-2)
    x_locs = pyro.param("x_locs", torch.tensor(0.0)).unsqueeze(-1)
    x_scales = pyro.param("x_scales", torch.tensor(1.0), constraint=constraints.positive)
    with plate1 as ind:
        phi_locs = pyro.param("ϕxy_locs", torch.tensor(1.0))
    avg_p = torch.tensor(0.30 / (0.30 + 0.70))
    logit_avg = torch.log(avg_p / (1 - avg_p))
    logit_locs = pyro.param("logit_locs", torch.zeros((100, 1)))
    with plate2:
        x = pyro.sample("x", dist.Normal(x_locs, x_scales))
        p = pyro.sample("p", dist.Delta(clipped_sigmoid(logit_locs + (logit_avg * 1.))))
    with plate1 as ind:
        phi = pyro.sample("ϕxy", dist.Normal(phi_locs, 1.0))
random_data = torch.tensor(np.random.poisson(lam=1, size=(100, 1000))).float()

pyro.clear_param_store()
num_steps = 3000
initial_lr = 0.03
final_lr = 0.005
gamma = final_lr / initial_lr  # learning rate decays from initial_lr to final_lr
lrd = gamma ** (1 / num_steps)
adam = pyro.optim.ClippedAdam({"lr": initial_lr, "lrd": lrd, "betas": (0.80, 0.99)})
svi = SVI(model_enum_minibatch, guide_enum_minibatch, optim=adam,
          loss=TraceEnum_ELBO(num_particles=1))

losses = []
verbose = True
intermediate_output = []
plotting_data = collections.defaultdict(list)
for step in range(num_steps):
    # Draw a fresh random subsample of plate1 indices at each step.
    loss = svi.step(random_data, ind=torch.randint(low=0, high=random_data.shape[1], size=(200,)))
    losses.append(loss)
And the complete trace of the error:
ValueError Traceback (most recent call last)
Input In [39], in <cell line: 5>()
4 plotting_data = collections.defaultdict(list)
5 for step in range(num_steps):
6 #print(step)
----> 7 loss = svi.step(random_data, ind=torch.randint(low=0, high=random_data.shape[1], size=(200,)))
8 losses.append(loss)
10 if verbose:
File ~/anaconda3/envs/python39/lib/python3.9/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
143 # get loss and compute gradients
144 with poutine.trace(param_only=True) as param_capture:
--> 145 loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
147 params = set(
148 site["value"].unconstrained() for site in param_capture.trace.nodes.values()
149 )
151 # actually perform gradient steps
152 # torch.optim objects gets instantiated for any params that haven't been seen yet
File ~/anaconda3/envs/python39/lib/python3.9/site-packages/pyro/infer/traceenum_elbo.py:452, in TraceEnum_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
450 elbo = 0.0
451 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
--> 452 elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
453 if is_identically_zero(elbo_particle):
454 continue
File ~/anaconda3/envs/python39/lib/python3.9/site-packages/pyro/infer/traceenum_elbo.py:180, in _compute_dice_elbo(model_trace, guide_trace)
178 def _compute_dice_elbo(model_trace, guide_trace):
179 # Accumulate marginal model costs.
--> 180 marginal_costs, log_factors, ordering, sum_dims, scale = _compute_model_factors(
181 model_trace, guide_trace
182 )
183 if log_factors:
184 dim_to_size = {}
File ~/anaconda3/envs/python39/lib/python3.9/site-packages/pyro/infer/traceenum_elbo.py:174, in _compute_model_factors(model_trace, guide_trace)
172 log_factors.setdefault(t, []).append(logprob)
173 scales.append(site["scale"])
--> 174 scale = _get_common_scale(scales)
175 return marginal_costs, log_factors, ordering, enum_dims, scale
File ~/anaconda3/envs/python39/lib/python3.9/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
76 @wraps(func)
77 def inner(*args, **kwds):
78 with self._recreate_cm():
---> 79 return func(*args, **kwds)
File ~/anaconda3/envs/python39/lib/python3.9/site-packages/pyro/infer/traceenum_elbo.py:43, in _get_common_scale(scales)
41 scales_set.add(float(scale))
42 if len(scales_set) != 1:
---> 43 raise ValueError(
44 "Expected all enumerated sample sites to share a common poutine.scale, "
45 "but found {} different scales.".format(len(scales_set))
46 )
47 return scales[1]
ValueError: Expected all enumerated sample sites to share a common poutine.scale, but found 2 different scales.
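For what it's worth, I think the error can be reproduced without any of the Fourier machinery. A minimal sketch (tiny_model and tiny_guide are my own hypothetical names, assuming I have understood the failure mode correctly): enumerating a variable outside a subsampled plate and consuming it inside that plate seems to hit the same check, because the enumerated site carries scale 1.0 while the likelihood site carries scale 10 / 5 = 2.0.

@config_enumerate
def tiny_model(data, ind):
    outer = pyro.plate("outer", 10, subsample=ind, dim=-1)
    b = pyro.sample("b", dist.Bernoulli(0.3))  # enumerated site, poutine scale 1.0
    with outer:  # subsampled plate, poutine scale 10 / 5 = 2.0
        pyro.sample("obs", dist.Normal(b, 1.0), obs=data[ind])

def tiny_guide(data, ind):
    pass  # "b" is enumerated in the model, so the guide can stay empty

tiny_svi = SVI(tiny_model, tiny_guide, optim=pyro.optim.Adam({"lr": 0.01}),
               loss=TraceEnum_ELBO(max_plate_nesting=1))
tiny_svi.step(torch.randn(10), ind=torch.arange(5))  # raises the same ValueError, I believe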