Hi all,
I’ve found that introducing a sampling site with a zero-dimensional event shape changes the behaviour of my network.
Is this expected behaviour?
My expectation was that introducing this sampling site (even though other sites depend on it downstream) would give the same results as a network without the site at all. After all, how can a sample with no dimensions affect anything?
For example, in pseudocode:

```python
import torch
import pyro
import pyro.distributions as dist

# as part of some larger network; f, y_net, size and y are defined elsewhere
def net_zero(x):
    a, b = f(x)  # where a and b are scalars
    z_dist = dist.Normal(a, b).expand([0]).to_event(1)  # event shape (0,)
    return z_dist

def net_nonzero(x):
    a, b = f(x)  # where a and b are scalars
    z_dist = dist.Normal(a, b).expand([4]).to_event(1)  # event shape (4,)
    return z_dist

...

with pyro.plate("data", size, subsample=x):
    z0 = pyro.sample("z0", net_zero(x))  # this sample has zero dims
    z = pyro.sample("z", net_nonzero(x))  # this sample has non-zero dims
    # concatenate along the event dimension; other arguments omitted (....)
    y = pyro.sample("y", y_net(torch.cat([z, z0], dim=-1)), obs=y)
```
With this setup, `z0.shape` is `torch.Size([batch_dim, 0])` and `z0` prints as an empty tensor (`tensor([])`).
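As a quick sanity check of the expectation above, here is a minimal sketch using plain `torch.distributions` rather than Pyro. Assumptions: `Independent(..., 1)` stands in for Pyro's `.to_event(1)`, and `batch_dim = 3` and the zero/one parameter values are arbitrary illustrative choices. It shows that a zero-size event contributes exactly 0 to the log-probability and that concatenating it onto `z` leaves `z` unchanged, so the site's own terms add nothing to the objective.

```python
import torch
from torch.distributions import Independent, Normal

batch_dim = 3
a = torch.zeros(batch_dim, 1)  # illustrative per-datapoint location
b = torch.ones(batch_dim, 1)   # illustrative per-datapoint scale

# Zero-size event: batch_shape (3,), event_shape (0,)
z0_dist = Independent(Normal(a, b).expand([batch_dim, 0]), 1)
# Non-empty event: batch_shape (3,), event_shape (4,)
z_dist = Independent(Normal(a, b).expand([batch_dim, 4]), 1)

z0 = z0_dist.sample()
z = z_dist.sample()

print(z0.shape)              # torch.Size([3, 0])
print(z0_dist.log_prob(z0))  # summing over an empty event gives 0 per datapoint

# Concatenating the empty tensor along the event dim leaves z unchanged
print(torch.equal(torch.cat([z, z0], dim=-1), z))  # True
```

If the two runs still differ, the difference therefore has to come from somewhere other than the zero-size site's log-probability or its contribution to `y_net`'s input.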
If I remove z0 and feed only z into y_net, I get different results (both the learned parameters and the outputs), even though z0 has a zero-dimensional event shape. I’m optimizing with SVI and the standard ELBO; the only change between the two runs is commenting out the z0 lines.
If I’m not missing something obvious, I’m happy to provide minimal code to reproduce this, but I’m relatively new to Pyro, so I wanted to check first that I’m not just being daft.
(I’m using Python 3.7.6, torch 1.6.0 and Pyro 1.4.0.)
Many thanks!
M