Writing parallelizable code

  • What tutorial are you running? Tensor Shapes in Pyro 0.2
  • What version of Pyro are you using? 0.2.1+649e7a85
  • Please link or paste relevant code, and steps to reproduce.

I am running example code from the aforementioned tutorial, specifically the section on writing parallelizable code, from which I have removed the assert statements and replaced them with printouts of the relevant shapes. Here is the code:

import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions.util import broadcast_shape
import torch.distributions.constraints as constraints
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, Trace_ELBO
import sys
t = torch



width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
enumerated = None  # set to either True or False below
print("sparse_pixels: ", sparse_pixels.shape)  # 4,2

# We'll use this helper to check our models are correct.
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)


@poutine.broadcast
def fun(observe):
    print("Fun: enumerated= ", enumerated)
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.iarange('x_axis', width, dim=-2)
    y_axis = pyro.iarange('y_axis', height, dim=-1)

    # Note that the shapes of these sites depend on whether Pyro is enumerating.
    with x_axis:
        x_active = pyro.sample("x_active", dist.Bernoulli(p_x))
        print("x_active shape: ", x_active.shape)
    with y_axis:
        y_active = pyro.sample("y_active", dist.Bernoulli(p_y))
        print("y_active shape: ", y_active.shape)

    # The first trick is to broadcast. This works with or without enumeration.
    p = 0.1 + 0.5 * x_active * y_active
    print("p shape: ", p.shape)
    dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))

    # The second trick is to index using ellipsis slicing.
    # This allows Pyro to add arbitrary dimensions on the left.
    for x, y in sparse_pixels:
        dense_pixels[..., x, y] = 1
        print("dense_pixels shape: ", dense_pixels.shape)

    with x_axis, y_axis:
        if observe:
            pyro.sample("pixels", dist.Bernoulli(p), obs=dense_pixels)

def model4():
    fun(observe=True)

@config_enumerate(default="parallel")
def guide4():
    fun(observe=False)

# Test without enumeration.
# Shapes in guide and model are identical
print("====================================")
print("Enumerated == True, max_iarange_nesting not set")
test_model(model4, guide4, Trace_ELBO())
#test_model(model4, guide4, TraceEnum_ELBO())
print("====================================")
print("Enumerated == True, max_iarange_nesting == 2")
test_model(model4, guide4, Trace_ELBO(max_iarange_nesting=2))  # not affected by max_iarange_nesting
#test_model(model4, guide4, TraceEnum_ELBO(max_iarange_nesting=2))

I notice a strange behavior which the tutorial does not address.
With Trace_ELBO, the shapes of the various variables are independent of the value of max_iarange_nesting, and the same is true with TraceEnum_ELBO. The tutorial uses both classes. What is the effect of the max_iarange_nesting argument? Could I see an example where the effect of max_iarange_nesting is demonstrated without changing the ELBO class?

If max_iarange_nesting has no effect, why not simply discuss Trace_ELBO and TraceEnum_ELBO without mentioning max_iarange_nesting?

Thanks,

Gordon

On the dev branch, which you seem to be using, we’ve added a feature that attempts to guess max_iarange_nesting (recently renamed to max_plate_nesting) when it’s not provided, because, as you’ve discovered, it can be rather confusing. In your example, max_iarange_nesting is indeed 2 (because of the line with x_axis, y_axis: ... in fun), so the two examples will behave the same with TraceEnum_ELBO.

Also, because Trace_ELBO does not perform enumeration, it does not use max_iarange_nesting except for some optional error checking logic, so the value will have no effect there.
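
To make the effect concrete, here is a rough sketch that keeps the ELBO class fixed and only varies max_iarange_nesting (it uses the same 0.2-era API as your code; toy_model and toy_guide are just illustrative names, and exact shapes may differ slightly across versions). TraceEnum_ELBO allocates enumeration dimensions to the left of the dims reserved for iaranges, so increasing max_iarange_nesting pushes the enumeration dim of the guide sample further to the left:

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import TraceEnum_ELBO, config_enumerate

def toy_model():
    with pyro.iarange("a_axis", 3, dim=-1):
        pyro.sample("a", dist.Bernoulli(0.5))

@config_enumerate(default="parallel")
@poutine.broadcast
def toy_guide():
    with pyro.iarange("a_axis", 3, dim=-1):
        a = pyro.sample("a", dist.Bernoulli(0.5))
        # The enumeration dim sits just left of the reserved iarange dims,
        # e.g. roughly (2, 1) with nesting=1 vs. (2, 1, 1, 1) with nesting=3.
        print("a.shape =", a.shape)

for nesting in (1, 3):
    pyro.clear_param_store()
    print("max_iarange_nesting =", nesting)
    TraceEnum_ELBO(max_iarange_nesting=nesting).loss(toy_model, toy_guide)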

Hi,

I am finally understanding enumeration (which turns out to be simple once I change my thinking process!). I appreciate the help. One question: does the @poutine.broadcast decorator work as before? I thought I read somewhere that when using iarange, broadcasting happens whether @poutine.broadcast is used or not (at least in the development version of Pyro). Specifically, in the code below, I get the same result whether @poutine.broadcast is used or not.

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
import sys

# We'll use this helper to check our models are correct.
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)

@config_enumerate(default="parallel")
@poutine.broadcast
def model3():
    print("enter model3")
    p = pyro.param("p", torch.arange(6.) / 6)
    locs = pyro.param("locs", torch.tensor([-1., 1.]))

    with pyro.iarange("c_iarange", 4):
        with pyro.iarange("d_iarange", 5):
            e = pyro.sample("e", dist.Normal(0., 1.))

    #                   enumerated|batch|event dims
    print("e.shape= ", e.shape)

guide = model3
test_model(model3, guide, TraceEnum_ELBO()) 

Another question: suppose I define a sample:

s = pyro.sample(....)

how can I retrieve its batch_shape and event_shape? From what I can tell, only the distribution object has access to these two quantities. Thanks.

Gordon

does the @poutine.broadcast decorator work as before? I thought I read somewhere that when using iarange, broadcasting happens whether @poutine.broadcast is used or not (at least in the development version of Pyro).

You’re correct - on dev we’ve moved that functionality inside pyro.plate, so the extra boilerplate of the @broadcast decorator is no longer necessary.
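
For example, here is a minimal sketch assuming the dev-branch pyro.plate API (model3_plate is just an illustrative name). No decorator is needed; the plates broadcast the sample site on their own, so this should print the same kind of batch shape as the iarange version above:

import pyro
import pyro.distributions as dist

def model3_plate():
    # No @poutine.broadcast needed: pyro.plate broadcasts sample sites automatically.
    with pyro.plate("c_plate", 4):
        with pyro.plate("d_plate", 5):
            e = pyro.sample("e", dist.Normal(0., 1.))
            print("e.shape =", e.shape)  # the plates fill in the batch dims, e.g. (5, 4)

model3_plate()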

how can I retrieve the batch_shape and the event_shape?

Samples are just regular torch.Tensors - they don’t have a separate batch_shape and event_shape of their own, as discussed in the Distribution Shapes section of the tensor shapes tutorial. Different inference algorithms may produce samples with different shapes and modify the shapes of distributions at sample sites. You can use pyro.poutine.trace on a model to get a data structure containing all of the distributions and samples that appear in an execution of the model:

def model(x):
    ...

trace = pyro.poutine.trace(model).get_trace(x)
for node in trace.nodes.values():  # an OrderedDict with one entry per site in the trace
    if node["type"] == "sample":  # skip the _INPUT and _RETURN bookkeeping entries
        print(node["name"])  # the name of a sample site
        print(node["fn"])  # the distribution at that sample site
        print(node["value"])  # the value returned from pyro.sample at that site