Reasoning through sample_shape, batch_shape, and event_shape for IID

I have created n IID samples in a straightforward way: I give the distribution a 0-dimensional (torch.Size([])) parameter and pass “sample_shape” to that distribution's sample method, which returns the torch.Size([n]) IID tensor I want.

When I tried passing “sample_shape” as a keyword argument to pyro.sample, it generated the expected tensor, but inference then failed with an error ("unexpected keyword ‘sample_shape’ in logmass").

I think I’ve gotten around this by giving the distribution a parameter with the desired data shape and using pyro.iarange: my n independent samples have a “batch_shape” of (n,) and an empty “event_shape” (torch.Size([])). I think that is right, but even if it is, it is less intuitive than having pyro.sample accept “sample_shape” as a legitimate argument. Why doesn’t it?
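The two routes described in the question can be compared directly. This is a minimal sketch using torch.distributions (Pyro's dist.Normal is built on torch.distributions.Normal); the value n = 4 is arbitrary. Both routes return a tensor of shape (n,), but only the second one records n in the distribution's batch_shape.

```python
import torch
from torch.distributions import Normal

n = 4

# Route 1: scalar parameters; the IID dimension comes only from sample_shape.
scalar = Normal(torch.tensor(0.0), torch.tensor(1.0))
x1 = scalar.sample(torch.Size([n]))

# Route 2: parameters already have the data shape, so n becomes batch_shape.
batched = Normal(torch.zeros(n), torch.ones(n))
x2 = batched.sample()

print(x1.shape, scalar.batch_shape, scalar.event_shape)
print(x2.shape, batched.batch_shape, batched.event_shape)
```

In route 1 the distribution object itself knows nothing about n, which is exactly the information a sample_shape kwarg on pyro.sample would leave out.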

Well, sample_shape is a little ambiguous: Pyro needs to know which part of that shape is independent (part of batch_shape) and which part is dependent (part of event_shape). This information is used internally in many ways; for example, it determines which dimensions log_prob sums over.
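A small torch.distributions sketch of that ambiguity: the same values of shape (5, 3) can come from a distribution whose 3 dimension is batch (independent) or event (dependent). The sample shape is sample_shape + batch_shape + event_shape in both cases, but log_prob returns sample_shape + batch_shape, so the two interpretations score the data differently. The sizes 5 and 3 are arbitrary.

```python
import torch
from torch.distributions import Independent, Normal

base = Normal(torch.zeros(3), torch.ones(3))  # batch_shape (3,), event_shape ()
joint = Independent(base, 1)                  # batch_shape (),   event_shape (3,)

sample_shape = torch.Size([5])
x = base.sample(sample_shape)  # shape (5, 3) = sample + batch + event

print(base.batch_shape, base.event_shape)
print(joint.batch_shape, joint.event_shape)
print(base.log_prob(x).shape)   # one log-density per independent element
print(joint.log_prob(x).shape)  # one log-density per 3-vector event
```

A bare sample_shape=(5, 3) could not say which of these two readings is intended.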

If you’re OK with the samples being treated as dependent, you can write:

x = pyro.sample("x", dist.Normal(0., 1.).expand([n]).independent(1))
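A rough torch.distributions analogue of `.expand([n]).independent(1)` (newer Pyro spells the second call `.to_event(1)`) is `Independent(Normal(...).expand([n]), 1)`. This sketch, with an arbitrary n = 6, shows that "dependent" here only changes bookkeeping: the n values become one event, and its log-density is the sum of the n elementwise log-densities.

```python
import torch
from torch.distributions import Independent, Normal

n = 6
# Expand a scalar Normal to batch_shape (n,), then move that dim into event_shape.
expanded = Normal(torch.tensor(0.0), torch.tensor(1.0)).expand(torch.Size([n]))
joint = Independent(expanded, 1)

x = joint.sample()
lp_joint = joint.log_prob(x)           # a single scalar for the whole n-vector
lp_elementwise = expanded.log_prob(x)  # n separate log-densities

print(x.shape, lp_joint.shape, lp_elementwise.shape)
print(torch.allclose(lp_joint, lp_elementwise.sum()))
```

What is lost is the per-element structure: an inference algorithm sees one event and cannot, for instance, subsample or rescale individual elements.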

If the samples are really independent, then yes Pyro really does require a pyro.iarange context (renamed to pyro.plate in dev and the upcoming 0.3 release).

with pyro.plate("data", n):
    x = pyro.sample("x", dist.Normal(0., 1.))

FWIW, I find pyro.plate contexts more maintainable as models grow, since you can paste code into larger plates to batch it in multiple ways. Early on, Pyro did lots of configuration via kwargs to sample statements (scaling, masking, batch_size, …), but these thwarted modularity; instead, we can now write

with pyro.poutine.scale(scale=p), pyro.poutine.mask(mask=m), pyro.plate("data", n):
    x = custom_x_sampler()

where custom_x_sampler() might be a function with multiple sample statements. By moving these effects from kwargs to context managers, we gain modularity.

Plate as in plate notation? I like that.

But it seems that

with pyro.iarange("data", 10):
    x = pyro.sample("x", dist.Normal(0., 1.))

does not give me 10 IID standard normals?

Sorry, iarange and plate will give you multiple samples on the dev branch and in the upcoming 0.3 release, but in the current 0.2.1 release you’ll need to add pyro.poutine.broadcast for this to work:

# in pyro 0.2.1 release
with pyro.poutine.broadcast(), pyro.iarange("data", 10):
    x = pyro.sample("x", dist.Normal(0., 1.))

Note that broadcast is added automatically in SVI and HMC. (The idea of combining broadcast + iarange = plate came out of a fruitful conversation with Jan-Willem at PROBPROG.)
