- What tutorial are you running? Tensor Shapes in Pyro 0.2
- What version of Pyro are you using? 0.2.1+649e7a85
- Please link or paste relevant code, and steps to reproduce.
I am running example code from the aforementioned tutorial, specifically the section on writing parallelizable code, from which I have removed the assert statements and replaced them with printouts of the relevant shapes. Here is the code:
```python
import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions.util import broadcast_shape
import torch.distributions.constraints as constraints
from pyro import poutine
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate, Trace_ELBO
import sys

t = torch
width = 8
height = 10
sparse_pixels = torch.LongTensor([[3, 2], [3, 5], [3, 9], [7, 1]])
enumerated = None  # set to either True or False below
print("sparse_pixels: ", sparse_pixels.shape)  # 4,2

# We'll use this helper to check our models are correct.
def test_model(model, guide, loss):
    pyro.clear_param_store()
    loss.loss(model, guide)

@poutine.broadcast
def fun(observe):
    print("Fun: enumerated= ", enumerated)
    p_x = pyro.param("p_x", torch.tensor(0.1), constraint=constraints.unit_interval)
    p_y = pyro.param("p_y", torch.tensor(0.1), constraint=constraints.unit_interval)
    x_axis = pyro.iarange('x_axis', width, dim=-2)
    y_axis = pyro.iarange('y_axis', height, dim=-1)
    # Note that the shapes of these sites depend on whether Pyro is enumerating.
    with x_axis:
        x_active = pyro.sample("x_active", dist.Bernoulli(p_x))
        print("x_active shape: ", x_active.shape)
    with y_axis:
        y_active = pyro.sample("y_active", dist.Bernoulli(p_y))
        print("y_active shape: ", y_active.shape)
    # The first trick is to broadcast. This works with or without enumeration.
    p = 0.1 + 0.5 * x_active * y_active
    print("p shape: ", p.shape)
    dense_pixels = p.new_zeros(broadcast_shape(p.shape, (width, height)))
    # The second trick is to index using ellipsis slicing.
    # This allows Pyro to add arbitrary dimensions on the left.
    for x, y in sparse_pixels:
        dense_pixels[..., x, y] = 1
    print("dense_pixels shape: ", dense_pixels.shape)
    with x_axis, y_axis:
        if observe:
            pyro.sample("pixels", dist.Bernoulli(p), obs=dense_pixels)

def model4():
    fun(observe=True)

@config_enumerate(default="parallel")
def guide4():
    fun(observe=False)

# Test without enumeration.
# Shapes in guide and model are identical.
print("====================================")
print("Enumerated == True, max_iarange_nesting not set")
test_model(model4, guide4, Trace_ELBO())
#test_model(model4, guide4, TraceEnum_ELBO())

print("====================================")
print("Enumerated == True, max_iarange_nesting == 2")
test_model(model4, guide4, Trace_ELBO(max_iarange_nesting=2))  # not affected by max_iarange_nesting
#test_model(model4, guide4, TraceEnum_ELBO(max_iarange_nesting=2))
```
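To double-check my own understanding of the two tricks the code comments mention (broadcasting and ellipsis slicing), here is a Pyro-free sketch using plain numpy; the `(width, 1)` and `(height,)` shapes here are my stand-ins for what the `iarange` dims `-2` and `-1` produce, not anything Pyro itself returns:

```python
import numpy as np

width, height = 8, 10
sparse_pixels = [(3, 2), (3, 5), (3, 9), (7, 1)]

# Trick 1: broadcasting. x_active occupies dim -2 and y_active dim -1,
# so their product broadcasts out to a full (width, height) grid.
x_active = np.ones((width, 1))  # shape (8, 1)
y_active = np.ones(height)      # shape (10,)
p = 0.1 + 0.5 * x_active * y_active
print(p.shape)  # (8, 10)

# Trick 2: ellipsis indexing. Writing via [..., x, y] always targets the
# rightmost two dims, so the loop is unchanged if extra dims appear on the left.
dense_pixels = np.zeros(p.shape)
for x, y in sparse_pixels:
    dense_pixels[..., x, y] = 1
print(dense_pixels.sum())  # 4.0

# Simulate an enumeration dim added on the left: same indexing code still works.
p_enum = p[np.newaxis]  # shape (1, 8, 10)
dense_enum = np.zeros(p_enum.shape)
for x, y in sparse_pixels:
    dense_enum[..., x, y] = 1
print(dense_enum.shape)  # (1, 8, 10)
```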
I notice a strange behavior that the tutorial does not address.
The shapes of the various variables under Trace_ELBO are independent of the value of max_iarange_nesting, and the same is true under TraceEnum_ELBO. The tutorial uses both classes. What is the effect of the max_iarange_nesting argument? Could I see an example where the effect of max_iarange_nesting is demonstrated without changing the ELBO class?
If max_iarange_nesting has no effect, why not simply discuss using Trace_ELBO or TraceEnum_ELBO without mentioning max_iarange_nesting at all?