I’ve been reading through minipyro, since it seems to be recommended as a gentler introduction to the inner workings of Pyro than reading Pyro’s actual source.
`plate` is defined like this:
```python
class PlateMessenger(Messenger):
    def __init__(self, fn, size, dim):
        assert dim < 0
        self.size = size
        self.dim = dim
        super(PlateMessenger, self).__init__(fn)

    def process_message(self, msg):
        if msg["type"] == "sample":
            batch_shape = msg["fn"].batch_shape
            if len(batch_shape) < -self.dim or batch_shape[self.dim] != self.size:
                batch_shape = [1] * (-self.dim - len(batch_shape)) + list(batch_shape)
                batch_shape[self.dim] = self.size
                msg["fn"] = msg["fn"].expand(torch.Size(batch_shape))

    def __iter__(self):
        return range(self.size)

...

# boilerplate to match the syntax of actual pyro.plate:
def plate(name, size, dim=None):
    if dim is None:
        raise NotImplementedError("minipyro.plate requires a dim arg")
    return PlateMessenger(fn=None, size=size, dim=dim)
```
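To see what `process_message` does to a sample site's batch shape, the list arithmetic can be traced without torch. This is a sketch that assumes only the shape matters; `plate_batch_shape` is a hypothetical helper written for illustration, whereas minipyro does this inline before calling `msg["fn"].expand(...)`.

```python
# Torch-free sketch of the shape arithmetic in PlateMessenger.process_message.
# `plate_batch_shape` is a hypothetical helper, not part of minipyro.
def plate_batch_shape(batch_shape, size, dim):
    assert dim < 0
    batch_shape = list(batch_shape)
    if len(batch_shape) < -dim or batch_shape[dim] != size:
        # Left-pad with singleton dims so that index `dim` exists...
        batch_shape = [1] * (-dim - len(batch_shape)) + batch_shape
        # ...then set the plate's dimension to the plate size.
        batch_shape[dim] = size
    return tuple(batch_shape)


print(plate_batch_shape((), 5, -1))    # (5,)
print(plate_batch_shape((3,), 5, -2))  # (5, 3)
print(plate_batch_shape((5,), 5, -1))  # (5,)  already correct, left unchanged
```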
When I try to use it as an iterator like this:
```python
for i in plate('hi', 5, dim=-1):
    print(i)
```
I get this:
```
TypeError: iter() returned non-iterator of type 'range'
```
Am I confused about something? Is this just an oversight? Is `plate` not meant to work as an iterator in minipyro?
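As far as I can tell, the `TypeError` comes from a general Python rule rather than anything specific to minipyro: a `for` loop calls `iter()` on the object, and `iter()` requires `__iter__` to return an iterator (an object with `__next__`). A `range` is iterable but is not itself an iterator. A minimal reproduction, using a hypothetical stand-in class `BrokenPlate` instead of the real `PlateMessenger`:

```python
# Minimal reproduction without minipyro. `BrokenPlate` is a hypothetical
# stand-in whose __iter__ returns a range, like PlateMessenger.__iter__.
class BrokenPlate:
    def __init__(self, size):
        self.size = size

    def __iter__(self):
        # range has __iter__ (it is iterable) but no __next__ (it is not
        # an iterator), so iter() rejects this return value.
        return range(self.size)


r = range(3)
print(hasattr(r, "__iter__"), hasattr(r, "__next__"))  # True False

try:
    for i in BrokenPlate(5):
        print(i)
except TypeError as e:
    print(e)  # iter() returned non-iterator of type 'range'
```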
Maybe it should have been this?

```python
def __iter__(self):
    yield from range(self.size)
```
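If that diagnosis is right, the `yield from` version should work, because a function containing `yield` returns a generator, and a generator is an iterator. Returning `iter(range(self.size))` looks like an equivalent alternative. A sketch of both, using hypothetical stand-in classes rather than the real `PlateMessenger`:

```python
# Two ways to make __iter__ return a real iterator.
# `FixedPlateA` and `FixedPlateB` are hypothetical stand-ins, not minipyro classes.
class FixedPlateA:
    def __init__(self, size):
        self.size = size

    def __iter__(self):
        # Generator function: calling it returns a generator (an iterator).
        yield from range(self.size)


class FixedPlateB:
    def __init__(self, size):
        self.size = size

    def __iter__(self):
        # Ask the range for its own iterator object.
        return iter(range(self.size))


print(list(FixedPlateA(5)))  # [0, 1, 2, 3, 4]
print(list(FixedPlateB(5)))  # [0, 1, 2, 3, 4]
```

Either form should make `for i in plate('hi', 5, dim=-1)` print 0 through 4, assuming nothing else in `PlateMessenger` interferes with iteration.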