Enumeration and subsampling: expected all enumerated sample sites to share common poutine.scale

I am trying to add subsampling to a Pyro model that also uses enumeration, but I run into an error about differing scales that I do not know how to resolve.

My model has two plates (plate1 and plate2). I enumerate a discrete variable inside plate2 while subsampling plate1. When I call svi.step, I get the following error: ValueError: Expected all enumerated sample sites to share a common poutine.scale, but found 2 different scales.

Below I provide a toy example that reproduces the error in full detail.

When I examine the Pyro source code, the problem seems to lie in the _get_common_scale function in traceenum_elbo.py. Two different scales appear to be present when Pyro reweights the log-likelihood for subsampling: one scale is 1 and the other equals the ratio of the full data size to the subsample size (here 1000 / 200 = 5).
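
As a sanity check (a minimal standalone sketch of my own, not part of the model below; the site names are made up), tracing a tiny model shows the scale that a subsampled plate attaches to its sites:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def tiny_model():
    # a subsampled plate rescales its sites by full_size / subsample_size
    with pyro.plate("plate1", 1000, subsample=torch.arange(200), dim=-1):
        pyro.sample("phi", dist.Normal(0., 1.))
    # a site outside the subsampled plate keeps the default scale of 1
    pyro.sample("z", dist.Normal(0., 1.))

trace = poutine.trace(tiny_model).get_trace()
print(trace.nodes["phi"]["scale"])  # 5.0 (= 1000 / 200)
print(trace.nodes["z"]["scale"])    # 1.0

So the enumerated "periodic" site (scale 1) and the observed "data" site inside the subsampled plate1 (scale 5) presumably end up in the same group of factors that _get_common_scale checks.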

The closest post I could find on the forum is the following, but its solution (to use TraceEnum_ELBO) does not apply here, as I am already using it!

Any suggestions on how to resolve this error and get the model to work would be greatly appreciated!!

The code to reproduce the error is below (the torch_fourier_basis function might look confusing, but I do not believe it is the source of the bug):

import collections

import numpy as np
import torch

import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.distributions import constraints
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate

def torch_fourier_basis(phase, num_harmonics):
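    # Builds a Fourier design matrix whose last dimension holds
    # [1, cos(k * phase), sin(k * phase)] for k = 1..num_harmonics.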
    idx_harm = torch.concat(
        [torch.tensor([0.0]), torch.repeat_interleave(torch.arange(1, 1 + num_harmonics), 2)]
    )
    sin_cos_bool = torch.tensor([False] + [False, True] * num_harmonics)
    base_bool = torch.tensor([True] + [False] * (num_harmonics * 2))
    return torch.where(base_bool, torch.tensor(1.0, dtype=torch.float32),
                       torch.where(sin_cos_bool, 
                                   torch.sin(idx_harm * phase.unsqueeze(-1)),
                                   torch.cos(idx_harm * phase.unsqueeze(-1)),
                                  ))

@config_enumerate
def model_enum_minibatch(data, ind=torch.arange(0, 200)):
    # Plates initialization
    plate1 = pyro.plate("plate1", 1000, subsample=ind, dim=-1)
    plate2 = pyro.plate("plate2", 100, dim=-2)
    
    with plate2:
        x = pyro.sample("x", dist.Normal(0.5, 1.)).unsqueeze(-1)
        
        p = pyro.sample('p', dist.Beta(0.05, 0.95))
        periodic = pyro.sample('periodic', dist.Bernoulli(p), infer={"enumerate": "parallel"})
        
    with plate1 as ind:
        phi = pyro.sample("ϕxy", dist.Normal(0., 2*np.pi))

    y = torch_fourier_basis(phi.squeeze(), num_harmonics=1) 
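    # ElogS mixes the enumerated 'periodic' indicator into the likelihood: periodic rows
    # use the full sum (x * y).sum(-1), non-periodic rows only x times the constant basis
    # term. The "data" site below then depends on it inside the subsampled plate1, which
    # is presumably where the two different scales meet.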
    ElogS = torch.where(periodic==1, (x * y).sum(-1), (x[:, :, 0] * y[:, 0]))
    
    with plate1, plate2:
            pyro.sample("data", dist.GammaPoisson(0.5, 1.0 / (0.5 * torch.exp(ElogS))), obs=data[:, ind])
    
def clipped_sigmoid(x):
    finfo = torch.finfo(x.dtype)
    y = x / 1.
    z = torch.clamp(y, min=finfo.min+10*finfo.eps, max=finfo.max-10*finfo.eps)
    return torch.clamp(torch.sigmoid(z), min=finfo.eps, max=1.-finfo.eps)

@config_enumerate
def guide_enum_minibatch(data, ind=torch.arange(0, 200)):
    # Plates initialization
    plate1 = pyro.plate("plate1", 1000, subsample=ind, dim=-1)
    plate2 = pyro.plate("plate2", 100, dim=-2)
    
    x_locs = pyro.param("x_locs", torch.tensor(0.0)).unsqueeze(-1)
    x_scales = pyro.param("x_scales", torch.tensor(1.0), constraint=constraints.positive)
    
    with plate1 as ind:
        phi_locs = pyro.param("ϕxy_locs", torch.tensor(1.0))

    avg_p = torch.tensor(0.30 / (0.30 + 0.70))
    logit_avg = torch.log(avg_p / (1 - avg_p))
    logit_locs = pyro.param("logit_locs", torch.zeros((100, 1)))
    
    with plate2:
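        # x gets a mean-field Normal guide; p gets a point estimate (Delta at a learned
        # logit). 'periodic' is intentionally absent here, since it is enumerated out in the model.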
        x = pyro.sample("x", dist.Normal(x_locs, x_scales))
        p = pyro.sample('p', dist.Delta(clipped_sigmoid(logit_locs+(logit_avg*1.))))
        
    with plate1 as ind:
        phi = pyro.sample("ϕxy", dist.Normal(phi_locs, 1.0))

random_data = torch.tensor(np.random.poisson(lam=1, size=(100, 1000))).float()
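# random_data has shape (plate2 size, plate1 size) = (100, 1000), so plate1 indexes
# the columns and data[:, ind] picks out the current minibatch.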

pyro.clear_param_store()

num_steps = 3000
initial_lr = 0.03
final_lr = 0.005
gamma = final_lr / initial_lr
lrd = gamma ** (1 / num_steps)
adam = pyro.optim.ClippedAdam({'lr': initial_lr, 'lrd': lrd, 'betas': (0.80, 0.99)})

svi = SVI(model_enum_minibatch, guide_enum_minibatch, optim=adam,
          loss=TraceEnum_ELBO(num_particles=1))

losses = []
verbose = True
intermediate_output = []
plotting_data = collections.defaultdict(list)
for step in range(num_steps):
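    # Note: torch.randint draws indices with replacement; torch.randperm(1000)[:200]
    # would give a duplicate-free minibatch, though that is unrelated to the error.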
    loss = svi.step(random_data, ind=torch.randint(low=0, high=random_data.shape[1], size=(200,)))
    losses.append(loss)

And the complete trace of the error:

ValueError                                Traceback (most recent call last)
Input In [39], in <cell line: 5>()
      4 plotting_data = collections.defaultdict(list)
      5 for step in range(num_steps):
      6     #print(step)
----> 7     loss = svi.step(random_data, ind=torch.randint(low=0, high=random_data.shape[1], size=(200,)))
      8     losses.append(loss)
     10     if verbose:

File ~/anaconda3/envs/python39/lib/python3.9/site-packages/pyro/infer/svi.py:145, in SVI.step(self, *args, **kwargs)
    143 # get loss and compute gradients
    144 with poutine.trace(param_only=True) as param_capture:
--> 145     loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    147 params = set(
    148     site["value"].unconstrained() for site in param_capture.trace.nodes.values()
    149 )
    151 # actually perform gradient steps
    152 # torch.optim objects gets instantiated for any params that haven't been seen yet

File ~/anaconda3/envs/python39/lib/python3.9/site-packages/pyro/infer/traceenum_elbo.py:452, in TraceEnum_ELBO.loss_and_grads(self, model, guide, *args, **kwargs)
    450 elbo = 0.0
    451 for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
--> 452     elbo_particle = _compute_dice_elbo(model_trace, guide_trace)
    453     if is_identically_zero(elbo_particle):
    454         continue

File ~/anaconda3/envs/python39/lib/python3.9/site-packages/pyro/infer/traceenum_elbo.py:180, in _compute_dice_elbo(model_trace, guide_trace)
    178 def _compute_dice_elbo(model_trace, guide_trace):
    179     # Accumulate marginal model costs.
--> 180     marginal_costs, log_factors, ordering, sum_dims, scale = _compute_model_factors(
    181         model_trace, guide_trace
    182     )
    183     if log_factors:
    184         dim_to_size = {}

File ~/anaconda3/envs/python39/lib/python3.9/site-packages/pyro/infer/traceenum_elbo.py:174, in _compute_model_factors(model_trace, guide_trace)
    172         log_factors.setdefault(t, []).append(logprob)
    173         scales.append(site["scale"])
--> 174 scale = _get_common_scale(scales)
    175 return marginal_costs, log_factors, ordering, enum_dims, scale

File ~/anaconda3/envs/python39/lib/python3.9/contextlib.py:79, in ContextDecorator.__call__.<locals>.inner(*args, **kwds)
     76 @wraps(func)
     77 def inner(*args, **kwds):
     78     with self._recreate_cm():
---> 79         return func(*args, **kwds)

File ~/anaconda3/envs/python39/lib/python3.9/site-packages/pyro/infer/traceenum_elbo.py:43, in _get_common_scale(scales)
     41     scales_set.add(float(scale))
     42 if len(scales_set) != 1:
---> 43     raise ValueError(
     44         "Expected all enumerated sample sites to share a common poutine.scale, "
     45         "but found {} different scales.".format(len(scales_set))
     46     )
     47 return scales[1]

ValueError: Expected all enumerated sample sites to share a common poutine.scale, but found 2 different scales.

This is just not supported in Pyro yet. You cannot subsample a variable in an inner plate if it depends on an enumerated variable in an outer plate. Here is a simpler example that gives the same error:

import torch
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate

data = torch.tensor([0., 1., 10., 11., 12.])

K = 2  # Fixed number of components.

@config_enumerate
def model(data):
    # Global variables.

    locs = torch.tensor([1., 10.])
    assignment = pyro.sample('assignment', dist.Categorical(torch.ones(2)))
    with pyro.plate('data', len(data), subsample_size=2) as ind:
        # Local variables.
        # cannot subsample here
        pyro.sample('obs', dist.Normal(locs[assignment], 1.), obs=data[ind])

def guide(data):
    pass

elbo = TraceEnum_ELBO(max_plate_nesting=1)
loss = elbo.loss(model, guide, data)
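
For contrast, a quick sketch: if the subsampling is removed, the same model should run, because every site then keeps the default scale of 1:

@config_enumerate
def model_no_subsample(data):
    locs = torch.tensor([1., 10.])
    assignment = pyro.sample('assignment', dist.Categorical(torch.ones(2)))
    with pyro.plate('data', len(data)):  # no subsample_size, so the scale stays 1
        pyro.sample('obs', dist.Normal(locs[assignment], 1.), obs=data)

elbo = TraceEnum_ELBO(max_plate_nesting=1)
loss = elbo.loss(model_no_subsample, guide, data)  # no scale error here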

If you are willing to try out NumPyro, it has a new TraceEnum_ELBO that supports cases like this.

You can also open an issue with a feature request to support cases like this in Pyro.


It actually doesn't work in NumPyro either, sorry for misleading you.

Thank you for the answer @ordabayev! It's too bad this isn't supported yet in Pyro or NumPyro. I will open an issue with a feature request so that Pyro may support such cases in the future.

Oh, I'm having the same problem. I didn't find an issue in the pyro-ppl/pyro repository; should I open one?

Yeah, please go ahead.

I've been working on this feature in NumPyro's TraceEnum_ELBO (PR). It requires some changes in the funsor package and may later be ported to Pyro as well.

Oh thanks, that's great! Cheers!
(As I'm trying to scale up models previously tested on low-dimensional toy examples to larger datasets and higher-dimensional data, subsampling really helps; it saves a lot of time.)