Poutine.scale and ValueError:...invalid log_prob shape

I’m trying to do my own subsampling logic without pyro.plate, using poutine.scale instead. But when I add such contexts around the relevant sample statements in my model, I get strange ValueError:...invalid log_prob shape errors. As far as I can tell, this is a bug in Pyro, or at least something that’s documented poorly enough to act like one.
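
For concreteness, here’s a stripped-down sketch of the pattern I mean (placeholder names; my real model is bigger):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

def model(batch, full_size):
    scale = float(full_size) / len(batch)
    with poutine.scale(scale=scale):
        # batched latent, but the batch dimension is never declared via pyro.plate
        x = pyro.sample("x", dist.Normal(torch.zeros(len(batch)), 1.0))
        pyro.sample("obs", dist.Normal(x, 1.0), obs=batch)

Running SVI on something shaped like this is where I hit the error.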

Great. On the Pyro team, our practice is also to implement subsampling outside of pyro.plate. However, you’ll still need a (non-subsampling) pyro.plate to declare to Pyro that you’re using a batch dimension in some context. The following should all be equivalent:

Version 1: use pyro.plate for subsampling:

def model(full_data, batch_size):
    full_size = len(full_data)
    # ind is a tensor of batch_size random indices into the full dataset
    with pyro.plate("data", full_size, subsample_size=batch_size) as ind:
        batch = full_data[ind]
        ...

Version 2: (recommended) use pyro.plate for scaling but external code for subsampling:

def model(full_size, batch):
    # Pyro rescales log_prob by full_size / len(batch) automatically
    with pyro.plate("data", full_size, subsample=batch):
        ...

Version 3: use pyro.plate for batch shaping, poutine.scale for scaling, external code for subsampling:

def model(full_size, batch):
    scale = float(full_size) / len(batch)
    # declare the batch dimension with a plate; apply the scaling manually
    with pyro.plate("data", len(batch)), poutine.scale(scale=scale):
        ...

Wow, fast answer.

So I’d have to use that same structure on the guide side too? And what if I’m using a blocked pyro.condition on the guide side?

My guesses at the answers are: “yes”, and “don’t worry, because the plate stuff only matters for SVI, not for mere conditioning”.

Yes, the plates should agree between model and guide.
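
For example, something like this (a minimal sketch with made-up names, passing the subsample indices ind in explicitly):

import torch
import pyro
import pyro.distributions as dist

def model(full_size, batch, ind):
    with pyro.plate("data", full_size, subsample=ind):
        # one latent per datapoint; the plate broadcasts it to len(ind)
        z = pyro.sample("z", dist.Normal(0., 1.))
        pyro.sample("obs", dist.Normal(z, 1.), obs=batch)

def guide(full_size, batch, ind):
    loc = pyro.param("loc", torch.zeros(full_size))
    with pyro.plate("data", full_size, subsample=ind):
        # same plate name and full size as in the model
        pyro.sample("z", dist.Normal(loc[ind], 1.))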

I think having to do everything inside plates is going to be a problem for me, though.

(Aside: in my code, p indexes precincts.)

As far as I can tell, there are actually two important differences between with pyro.plate('x',P) as p: pyro.sample(f"x_{p}",...) and for p in pyro.plate('x',P): pyro.sample(f"x_{p}",...). In the former, the plate acts as a context manager; as you point out, that’s the desired behavior when I’m scaling. But also, in the former, pyro.sample is called only once, with p as a tensor of all possible values, while in the latter it’s called P separate times, with p as a scalar.

I realize that the former offers the possibility of huge efficiency gains if done right. But for right now, I think I need to do the latter, because in my guide I need to do separate calculations for the Hessian for each p.

Eventually, I’d like to figure out how to use the former plan. But in the meantime, is there any way I can use poutine.scale with the latter plan? Do I have to say with pyro.plate('x',P) as p_tensor: for p in p_tensor: pyro.sample(f"x_{p}",...)?
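
Concretely, I’m imagining something like this for the latter plan (just a sketch; the names are mine, and I don’t know if it’s the sanctioned pattern):

import pyro
import pyro.distributions as dist
from pyro import poutine

def model(full_size, batch):
    scale = float(full_size) / len(batch)
    # sequential plate: pyro.sample runs once per p, with p a Python int
    for p in pyro.plate("x_loop", len(batch)):
        with poutine.scale(scale=scale):
            pyro.sample(f"x_{p}", dist.Normal(0., 1.))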

@fritzo In the case of the recommended version, I should just pass the batch to the plate, and then continue using it without further modifications? As in:

def model(dataset_total_length, x_data, y_data):
    # prediction_mean and scale are computed from x_data upstream (omitted here)
    with pyro.plate('map', dataset_total_length, subsample=x_data):
        pyro.sample('observations', Normal(prediction_mean, scale), obs=y_data)

Thanks

@lqrz Yes, you can continue using x_data without further modifications. The subsample= arg to pyro.plate overrides any subsampling behavior with the user-provided subsample.
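
For completeness, here is a minimal end-to-end sketch along those lines (my own variable names, with a toy linear predictor standing in for your prediction_mean):

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import Adam

def model(dataset_total_length, x_data, y_data):
    w = pyro.sample("w", dist.Normal(0., 1.))
    b = pyro.sample("b", dist.Normal(0., 1.))
    prediction_mean = w * x_data + b
    with pyro.plate("map", dataset_total_length, subsample=x_data):
        # log_prob is rescaled by dataset_total_length / len(x_data)
        pyro.sample("observations", dist.Normal(prediction_mean, 1.), obs=y_data)

guide = AutoNormal(model)
svi = SVI(model, guide, Adam({"lr": 1e-2}), Trace_ELBO())
x_batch, y_batch = torch.randn(32), torch.randn(32)
loss = svi.step(1000, x_batch, y_batch)  # 1000 = full dataset size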
