Hi @learningmath,
First, I think your code snippets may contain typos, or you may be using an old Python version (funsor requires Python 3.6+). Here is some code that works on my machine:
```python
from collections import OrderedDict

import torch
import funsor
from funsor.domains import Bint

funsor.set_backend("torch")  # select the PyTorch backend

pt_x = torch.randn(4, 3, 2)
x = funsor.Tensor(pt_x, OrderedDict(foo=Bint[4], bar=Bint[3]))
assert x.output.shape == (2,)
assert x.shape == (2,)  # just an alias for .output.shape
assert x.inputs["foo"].size == 4
assert x.inputs["bar"].size == 3
assert x.data is pt_x
assert x("foo") is x
assert x("foo", "bar") is x
assert x(foo="foo") is x
assert x(bar="bar") is x
```
Now to address your question:

> how exactly do you use named indexing to index into the batch dimensions or broadcast?
Note that `x(*args)` substitutes into `x` positionally, whereas `x(**kwargs)` substitutes into `x` by name. In funsor code we prefer the named version `x(**kwargs)` because it is safer: it does not depend on the order of `x.inputs`, so it is invariant to transposing.
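Your first substitution was

```python
from funsor import Variable

# Four equivalent spellings of the same substitution:
assert x("foo") is x(Variable("foo", Bint[4]))
assert x("foo") is x(foo="foo")
assert x("foo") is x(foo=Variable("foo", Bint[4]))
```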
That substitution is equivalent to a renaming. But since you're renaming the `"foo"` input to `"foo"`, you end up with the same old `x`. You could instead have renamed it to `"baz"`, which would change the name but preserve the underlying `torch.Tensor`:
```python
assert x(foo="baz") is not x
assert x(foo="baz").data is x.data
```
The closest PyTorch equivalent of renaming is probably indexing with reshaped `torch.arange` tensors to accomplish a permutation, i.e. to change the positions of dimensions, e.g.
```python
i = torch.arange(4).reshape(1, 4, 1)
j = torch.arange(3).reshape(3, 1, 1)
k = torch.arange(2).reshape(1, 1, 2)
pt_y = pt_x[i, j, k]
assert pt_y.shape == (3, 4, 2)
assert (pt_y == pt_x.transpose(0, 1)).all()
```
Now, to index into a `funsor.Tensor`, you'll substitute an actual number, e.g.
```python
x1 = x(foo=1)
assert set(x1.inputs) == {"bar"}
assert x1.data.shape == (3, 2)

x2 = x(bar=2)
assert set(x2.inputs) == {"foo"}
assert x2.data.shape == (4, 2)

x12 = x(foo=1, bar=2)
assert not x12.inputs
assert x12.data.shape == (2,)
```
To slice, you'll need to create a slice object. Funsor does have a symbolic `funsor.Slice` funsor; alternatively, you can use advanced indexing via a `funsor.Tensor`:
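```python
# Symbolic slice foo[0:2] with a new input "fo"; dtype=4 makes the slice's
# output Bint[4], matching the domain of x.inputs["foo"].
y1 = x(foo=funsor.Slice("fo", 0, 2, dtype=4))
assert set(y1.inputs) == {"fo", "bar"}
assert x.inputs["foo"].size == 4
assert y1.inputs["fo"].size == 2

# The same slice via advanced indexing: an index tensor with a new input
# "fo" of size 2 whose values are bounded integers in range(4).
y2 = x(foo=funsor.Tensor(torch.arange(2), OrderedDict(fo=Bint[2]), 4))
assert y1.inputs == y2.inputs
assert y1.output == y2.output
assert (y1 == y2).data.all()
```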
You should similarly be able to use "advanced indexing" by substituting other funsors into `x`.