Funsor named indexing semantics

I had a question regarding funsor Tensor named indexing. Using names to match free variables with dimensions for function substitution makes sense, but positional tensor indexing seems to be orthogonal to name-based indexing.
Consider the following example:

pt_x = torch.randn(4, 3, 2)
x = funsor.Tensor(pt_x, OrderedDict(foo=Bint[4], bar=Bint[3]))

This instantiates a funsor of batch shape (4, 3), event shape (2,) and output shape ().

pt_x.shape == (4, 3, 2)
x.shape == (3, 2) # just the event dims
x[0].data.shape == (4, 3) # == pt_x[:, 0].shape, only indexes
                          # event dimensions
x('foo').shape == x.shape  # != x[0].shape
x('bar')  # fails without a readable error message
x == x('foo')  # x('foo') is not a slice; the entire tensor is returned

In spite of this, funsor tensors still follow PyTorch broadcasting semantics. So how exactly do you use named indexing to index into the batch dimensions or to broadcast? Or is this strictly for aligning/broadcasting variables by named dimension during substitution?

Hi @learningmath,
First, I think your code snippets may contain typos, or you may be using an old Python version (funsor requires Python 3.6+, where the keyword order in OrderedDict(foo=..., bar=...) is preserved). Here is some code that works on my machine:

import torch
import funsor
from collections import OrderedDict
from funsor.domains import Bint

funsor.set_backend("torch")  # let funsor.Tensor wrap torch tensors

pt_x = torch.randn(4, 3, 2)
x = funsor.Tensor(pt_x, OrderedDict(foo=Bint[4], bar=Bint[3]))
assert x.output.shape == (2,)
assert x.shape == (2,)  # just an alias for .output.shape
assert x.inputs["foo"].size == 4
assert x.inputs["bar"].size == 3
assert x.data is pt_x
assert x("foo") is x
assert x("foo", "bar") is x
assert x(foo="foo") is x
assert x(bar="bar") is x
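The asserts above reflect how funsor.Tensor lays out its data: the leading dimensions are the named inputs, in the order of the inputs OrderedDict, and whatever dimensions remain form the output shape that .shape reports. A minimal pure-Python sketch of that bookkeeping follows; the helper split_shape is hypothetical and not part of funsor.

```python
from collections import OrderedDict


def split_shape(data_shape, input_sizes):
    """Hypothetical helper: split a raw data shape into inputs and output shape.

    Assumes, as in funsor.Tensor, that the named inputs occupy the leading
    dimensions of the data in the order given by `input_sizes`.
    """
    n = len(input_sizes)
    leading = tuple(data_shape[:n])
    expected = tuple(input_sizes.values())
    if leading != expected:
        raise ValueError(f"input sizes {expected} do not match leading dims {leading}")
    return tuple(data_shape[n:])  # the remaining dims are the output shape


inputs = OrderedDict(foo=4, bar=3)
output_shape = split_shape((4, 3, 2), inputs)
print(output_shape)  # (2,)
```

So a (4, 3, 2) array with inputs foo and bar has output shape (2,), which is why x.shape is (2,) rather than (3, 2).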

Now, to address your question:

how exactly do you use named indexing to index into the batch dimensions or broadcast

Note that x(*args) substitutes into x positionally, and x(**kwargs) substitutes into x by name. In funsor code we prefer the named version x(**kwargs) because it is safer and invariant to transposing.
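To see why the named form is safer, here is a toy model of the positional rule, assuming (as described above) that positional arguments are paired with the inputs in their current order. The helper positional_to_named is hypothetical and only models the ordering, not funsor's implementation.

```python
from collections import OrderedDict


def positional_to_named(input_names, args):
    """Hypothetical helper: pair positional substitution args with input names.

    Models the rule that x(a, b) means x(foo=a, bar=b) when the inputs are
    ordered (foo, bar).
    """
    if len(args) > len(input_names):
        raise TypeError("too many positional arguments")
    return dict(zip(input_names, args))


inputs = OrderedDict(foo=4, bar=3)
print(positional_to_named(list(inputs), ("baz",)))  # {'foo': 'baz'}

# After a transpose of the inputs, the same positional call hits a different
# input, whereas x(foo="baz") would still target foo.
transposed = OrderedDict(bar=3, foo=4)
print(positional_to_named(list(transposed), ("baz",)))  # {'bar': 'baz'}
```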

Your first substitution was

x("foo") == x(Variable("foo", Bint[4]))
         == x(foo="foo")
         == x(foo=Variable("foo", Bint[4]))

and is equivalent to a renaming. But since you’re renaming the “foo” input to “foo”, you end up with the same old x. You could alternatively have renamed it to “baz”, which would have changed the name but preserved the underlying torch.Tensor:

assert x(foo="baz") is not x
assert x(foo="baz").data is pt_x
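A renaming only relabels keys of the inputs, so the raw data object can be reused as-is. The following is a toy stand-in for a funsor.Tensor written in plain Python; the class ToyNamed and its rename method are hypothetical and ignore cases such as renaming onto a name that is already taken.

```python
from collections import OrderedDict


class ToyNamed:
    """Hypothetical stand-in for a funsor.Tensor: raw data plus named inputs."""

    def __init__(self, data, inputs):
        self.data = data
        self.inputs = OrderedDict(inputs)

    def rename(self, **renames):
        # If every input maps to itself, nothing changes: return self.
        if all(renames.get(k, k) == k for k in self.inputs):
            return self
        # Otherwise relabel the keys but keep the very same data object.
        new_inputs = OrderedDict((renames.get(k, k), v) for k, v in self.inputs.items())
        return ToyNamed(self.data, new_inputs)


raw = [[[0, 0] for _ in range(3)] for _ in range(4)]  # stands in for pt_x, shape (4, 3, 2)
x = ToyNamed(raw, OrderedDict(foo=4, bar=3))
assert x.rename(foo="foo") is x
y = x.rename(foo="baz")
assert y is not x and y.data is raw
print(list(y.inputs))  # ['baz', 'bar']
```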

The PyTorch analogue of renaming is roughly indexing with reshaped torch.arange tensors to accomplish a permutation, i.e. to change the positions of dimensions, e.g.

i = torch.arange(4).reshape(1, 4, 1)
j = torch.arange(3).reshape(3, 1, 1)
k = torch.arange(2).reshape(1, 1, 2)
pt_y = pt_x[i, j, k]
assert pt_y.shape == (3, 4, 2)
assert (pt_y == pt_x.transpose(0, 1)).all()

Now to index into a funsor.Tensor, you’ll substitute an actual number, e.g.

x1 = x(foo=1)
assert set(x1.inputs) == {"bar"}
assert x1.data.shape == (3, 2)
x2 = x(bar=2)
assert set(x2.inputs) == {"foo"}
assert x2.data.shape == (4, 2)
x12 = x(foo=1, bar=2)
assert not x12.inputs
assert x12.data.shape == (2,)
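Substituting an integer behaves like ordinary integer indexing along whichever dimension the name points to, and that name drops out of the inputs. Here is a pure-Python sketch on nested lists; the helpers shape_of, take and subs_int are hypothetical, and the data is filled with 100*foo + 10*bar + c so the selected entries are easy to read.

```python
def shape_of(data):
    """Shape of a rectangular nested list."""
    shape = []
    while isinstance(data, list):
        shape.append(len(data))
        data = data[0]
    return tuple(shape)


def take(data, axis, index):
    """Select `index` along `axis` of a nested list, dropping that axis."""
    if axis == 0:
        return data[index]
    return [take(sub, axis - 1, index) for sub in data]


def subs_int(data, names, **values):
    """Hypothetical sketch of x(name=int): fix named inputs to integers.

    `names` lists the input names of the leading dims of `data`; the
    returned names are the inputs that remain free.
    """
    names = list(names)
    for name, index in values.items():
        axis = names.index(name)  # the name tells us which dim to index
        data = take(data, axis, index)
        names.pop(axis)
    return data, names


pt_x = [[[100 * a + 10 * b + c for c in range(2)] for b in range(3)] for a in range(4)]
x1, n1 = subs_int(pt_x, ["foo", "bar"], foo=1)
x2, n2 = subs_int(pt_x, ["foo", "bar"], bar=2)
x12, n12 = subs_int(pt_x, ["foo", "bar"], foo=1, bar=2)
print(shape_of(x1), n1)    # (3, 2) ['bar']
print(shape_of(x2), n2)    # (4, 2) ['foo']
print(shape_of(x12), n12)  # (2,) []
print(x12)                 # [120, 121]
```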

To slice, you’ll need to substitute something that selects a range of indices. Funsor does have a symbolic funsor.Slice funsor, or you can use advanced indexing via a funsor.Tensor:

y1 = x(foo=funsor.Slice("fo", 0, 2))
assert set(y1.inputs) == {"fo", "bar"}
assert x.inputs["foo"].size == 4
assert y1.inputs["fo"].size == 2

y2 = x(foo=funsor.Tensor(torch.arange(2), OrderedDict(fo=Bint[2]), 4))
assert y1.inputs == y2.inputs
assert y1.output == y2.output
assert (y1 == y2).data.all()
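In the funsor.Tensor version, torch.arange(2) carries a single new input fo of size 2, and the trailing 4 is the dtype, stating that its values are indices into a size-4 dimension. Substituting it replaces the foo dimension with fo. Below is a pure-Python sketch of that replacement; gather and subs_index are hypothetical helpers, and the sketch keeps the new input in the position of the old one, which is an assumption about ordering rather than a statement about funsor.

```python
def gather(data, axis, indices):
    """Advanced-indexing sketch: replace `axis` by a new axis of len(indices)."""
    if axis == 0:
        return [data[i] for i in indices]
    return [gather(sub, axis - 1, indices) for sub in data]


def subs_index(data, names, name, new_name, indices):
    """Hypothetical sketch of x(name=<index tensor with one input new_name>)."""
    names = list(names)
    axis = names.index(name)
    names[axis] = new_name  # the substituted input is replaced by the new one
    return gather(data, axis, list(indices)), names


pt_x = [[[100 * a + 10 * b + c for c in range(2)] for b in range(3)] for a in range(4)]
# A slice 0:2 of the foo dimension, expressed as the indices it covers ...
y1, n1 = subs_index(pt_x, ["foo", "bar"], "foo", "fo", range(4)[0:2])
# ... versus an explicit index list, like torch.arange(2).
y2, n2 = subs_index(pt_x, ["foo", "bar"], "foo", "fo", [0, 1])
assert n1 == n2 == ["fo", "bar"]
assert y1 == y2
print(len(y1), n1)  # 2 ['fo', 'bar']
```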

You should similarly be able to use “advanced indexing” by substituting other Funsors into x.
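One case worth spelling out is substituting index funsors that share an input name, e.g. x(foo=ia, bar=ib) where ia and ib both have a single input "i". Assuming, as the named semantics suggest, that same-named inputs are aligned rather than combined as an outer product, the result pairs entries elementwise, like pt_x[ia, ib] in PyTorch. The pure-Python sketch below illustrates that pairing; subs_shared is hypothetical and only handles the case where every input is substituted.

```python
def subs_shared(data, names, index_name, **index_lists):
    """Hypothetical sketch: substitute index lists that all share one new input.

    Every replaced input is driven by the same new input `index_name`, so the
    selected entries are paired elementwise rather than forming a product.
    """
    if set(index_lists) != set(names):
        raise ValueError("this sketch requires substituting every input")
    n = len(index_lists[names[0]])
    if any(len(v) != n for v in index_lists.values()):
        raise ValueError("index lists sharing one input must have equal length")
    out = []
    for t in range(n):
        entry = data
        for name in names:  # walk the leading dims in input order
            entry = entry[index_lists[name][t]]
        out.append(entry)
    return out, [index_name]


pt_x = [[[100 * a + 10 * b + c for c in range(2)] for b in range(3)] for a in range(4)]
z, zn = subs_shared(pt_x, ["foo", "bar"], "i", foo=[0, 3, 1], bar=[2, 0, 1])
print(zn)  # ['i']
print(z)   # [[20, 21], [300, 301], [110, 111]]
```

Giving the two index funsors different input names instead would keep both names as separate inputs, which is how named dimensions broadcast against each other.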
