I’ve been reading through minipyro, since it’s recommended as a gentler introduction to Pyro’s inner workings than reading Pyro’s own source.
`plate` is defined like this:
```python
class PlateMessenger(Messenger):
    def __init__(self, fn, size, dim):
        assert dim < 0
        self.size = size
        self.dim = dim
        super(PlateMessenger, self).__init__(fn)

    def process_message(self, msg):
        if msg["type"] == "sample":
            batch_shape = msg["fn"].batch_shape
            if len(batch_shape) < -self.dim or batch_shape[self.dim] != self.size:
                batch_shape = [1] * (-self.dim - len(batch_shape)) + list(batch_shape)
                batch_shape[self.dim] = self.size
                msg["fn"] = msg["fn"].expand(torch.Size(batch_shape))

    def __iter__(self):
        return range(self.size)

...

# boilerplate to match the syntax of actual pyro.plate:
def plate(name, size, dim=None):
    if dim is None:
        raise NotImplementedError("minipyro.plate requires a dim arg")
    return PlateMessenger(fn=None, size=size, dim=dim)
```
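The shape arithmetic in `process_message` can be checked without torch. Below is a minimal sketch using a hypothetical helper, `expanded_batch_shape` (not part of minipyro), that mirrors only the shape computation on plain tuples. It does not call `.expand()`, so it will not flag cases torch would reject, such as expanding an existing non-singleton dim of size 2 to size 5.

```python
def expanded_batch_shape(batch_shape, size, dim):
    """Hypothetical helper: mirrors the batch_shape logic in
    PlateMessenger.process_message, returning the shape the
    distribution would be expanded to (or the original shape if
    no expansion is needed)."""
    assert dim < 0
    batch_shape = list(batch_shape)
    if len(batch_shape) < -dim or batch_shape[dim] != size:
        # Left-pad with singleton dims so that index `dim` exists;
        # when no padding is needed, [1] * (non-positive) is just [].
        batch_shape = [1] * (-dim - len(batch_shape)) + batch_shape
        # Put the plate size at the plate's dim.
        batch_shape[dim] = size
    return tuple(batch_shape)


# A scalar distribution inside plate(size=5, dim=-1) becomes batch shape (5,).
print(expanded_batch_shape((), 5, -1))
# A (3,)-batched distribution inside plate(size=5, dim=-2) becomes (5, 3).
print(expanded_batch_shape((3,), 5, -2))
```

Because padding is always on the left, a negative `dim` counts from the rightmost batch dim, which is why `PlateMessenger.__init__` asserts `dim < 0`.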
When I try to use it as an iterator like this:
```python
for i in plate('hi', 5, dim=-1):
    print(i)
```
I get this:
```
TypeError: iter() returned non-iterator of type 'range'
```
Am I confused about something? Is this just an oversight, or is `plate` not meant to work as an iterator in minipyro?
Maybe it should have been this?
```python
def __iter__(self):
    yield from range(self.size)
```
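The error comes from Python's iterator protocol: `__iter__` must return an *iterator* (an object with `__next__`), and `range` is only an *iterable*. A minimal standalone sketch with hypothetical class names (not minipyro code) reproduces the error and shows two fixes: wrapping with `iter(...)`, or making `__iter__` a generator as proposed above.

```python
class ReturnsRange:
    """Hypothetical stand-in for the excerpt's __iter__."""
    def __init__(self, size):
        self.size = size

    def __iter__(self):
        # range has no __next__, so iter()/for-loops reject it.
        return range(self.size)


class ReturnsIterator:
    def __init__(self, size):
        self.size = size

    def __iter__(self):
        # iter(range(...)) produces a range_iterator, which is a real iterator.
        return iter(range(self.size))


class YieldsFrom:
    def __init__(self, size):
        self.size = size

    def __iter__(self):
        # A generator function: calling __iter__ returns a generator object.
        yield from range(self.size)


try:
    for i in ReturnsRange(3):
        pass
except TypeError as e:
    print(e)  # iter() returned non-iterator of type 'range'

print(list(ReturnsIterator(3)))  # [0, 1, 2]
print(list(YieldsFrom(3)))  # [0, 1, 2]
```

Either fix makes the loop run. Note that in the excerpt, `__iter__` only produces indices and never enters the messenger, so iterating would not apply the plate's `process_message` to sample sites inside the loop body; only `with plate(...)` would.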