- What tutorial are you running? Tensor Shapes in Pyro 0.2
- What version of Pyro are you using? 0.2.1+649e7a85
- Please link or paste relevant code, and steps to reproduce.
I am running example code from the aforementioned tutorial, specifically the section on writing parallelizable code, from which I have removed the assert statements and replaced them with printouts of the relevant shapes. Here is the code:
```python
import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions.util import broadcast_shape
import torch.distributions.constraints as constraints
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, Trace_ELBO
import sys

t = torch
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
enumerated = None  # set to either True or False below
print("sparse_pixels: ", sparse_pixels.shape)  # 4,2

# We'll use this helper to check our models are correct.
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)

@poutine.broadcast
def fun(observe):
    print("Fun: enumerated= ", enumerated)
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.iarange('x_axis', width, dim=-2)
    y_axis = pyro.iarange('y_axis', height, dim=-1)
    # Note that the shapes of these sites depend on whether Pyro is enumerating.
    with x_axis:
        x_active = pyro.sample("x_active", dist.Bernoulli(p_x))
        print("x_active shape: ", x_active.shape)
    with y_axis:
        y_active = pyro.sample("y_active", dist.Bernoulli(p_y))
        print("y_active shape: ", y_active.shape)
    # The first trick is to broadcast. This works with or without enumeration.
    p = 0.1 + 0.5 * x_active * y_active
    print("p shape: ", p.shape)
    dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
    # The second trick is to index using ellipsis slicing.
    # This allows Pyro to add arbitrary dimensions on the left.
    for x, y in sparse_pixels:
        dense_pixels[..., x, y] = 1
    print("dense_pixels shape: ", dense_pixels.shape)
    with x_axis, y_axis:
        if observe:
            pyro.sample("pixels", dist.Bernoulli(p), obs=dense_pixels)

def model4():
    fun(observe=True)

@config_enumerate(default="parallel")
def guide4():
    fun(observe=False)

# Test without enumeration.
# Shapes in guide and model are identical.
print("====================================")
print("Enumerated == True, max_iarange_nesting not set")
test_model(model4, guide4, Trace_ELBO())
#test_model(model4, guide4, TraceEnum_ELBO())

print("====================================")
print("Enumerated == True, max_iarange_nesting == 2")
test_model(model4, guide4, Trace_ELBO(max_iarange_nesting=2))  # not affected by max_iarange_nesting
#test_model(model4, guide4, TraceEnum_ELBO(max_iarange_nesting=2))
```
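To double-check my own understanding of the two tricks the code comments mention (broadcasting and ellipsis slicing), here is a Pyro-free sketch using plain numpy; the `(width, 1)` and `(height,)` shapes here are my stand-ins for what the `iarange` dims `-2` and `-1` produce, not anything Pyro itself returns:

```python
import numpy as np

width, height = 8, 10
sparse_pixels = [(3, 2), (3, 5), (3, 9), (7, 1)]

# Trick 1: broadcasting. x_active occupies dim -2 and y_active dim -1,
# so their product broadcasts out to a full (width, height) grid.
x_active = np.ones((width, 1))  # shape (8, 1)
y_active = np.ones(height)      # shape (10,)
p = 0.1 + 0.5 * x_active * y_active
print(p.shape)  # (8, 10)

# Trick 2: ellipsis indexing. Writing via [..., x, y] always targets the
# rightmost two dims, so the loop is unchanged if extra dims appear on the left.
dense_pixels = np.zeros(p.shape)
for x, y in sparse_pixels:
    dense_pixels[..., x, y] = 1
print(dense_pixels.sum())  # 4.0

# Simulate an enumeration dim added on the left: same indexing code still works.
p_enum = p[np.newaxis]  # shape (1, 8, 10)
dense_enum = np.zeros(p_enum.shape)
for x, y in sparse_pixels:
    dense_enum[..., x, y] = 1
print(dense_enum.shape)  # (1, 8, 10)
```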
I notice a strange behavior that the tutorial does not address.
The shapes of the various variables under Trace_ELBO are independent of the value of max_iarange_nesting, and the same is true under TraceEnum_ELBO. The tutorial uses both classes. What is the effect of the max_iarange_nesting argument? Could I see an example where the effect of max_iarange_nesting is demonstrated without changing the ELBO class?
If max_iarange_nesting has no effect, why not simply discuss using Trace_ELBO or TraceEnum_ELBO without mentioning max_iarange_nesting at all?