Hello!
I implemented the SSVAE with help from the nice tutorial, and would like to extend it for image segmentation. So in practice I need to have a y
for each pixel, resulting in a NxCxHxW
tensor instead of a N x num_classes
tensor where N
is batch size, C
channels H
image height and W
image width. However, I don’t need z
to be a four dimensional tensor, which makes using plate notation a bit tricky: By referring to the -4
dimension as the batch dimension the dimensionality of z
is expanded to fit (at least I assume this is what happening under the hood). Please see the model and guide code below and the comments for a more concrete problem description. My question in this: How should I treat the different dimensions of z
and y
when y
is a segmentation mask?
Thanks for the great library and for your assistance!
/Sebastian
def guide(self, x: Tensor, y: Optional[Tensor] = None) -> None:
N, C, W, H = x.shape # same as y.shape
with pyro.plate("N", size=N, dim=-3):
with pyro.plate("pixel-c", size=C, dim=-2):
with pyro.plate("pixel-y", size=W, dim=-1):
if y is None:
pi = self.encode_y(x)
p_y = Bernoulli(pi).to_event(3)
y = pyro.sample("y", p_y)
# both loc and scale are N x self.z_dim
loc, scale = self.encode_z(x, y)
p_z = Normal(loc, scale).to_event(1)
# sampling results in Nx1xN x self.z_dim
# but we want it to be N x self.z_dim
z = pyro.sample("z", p_z)
# assert False, z.shape
def model(self, x: Tensor, y: Optional[Tensor] = None) -> Tensor:
pyro.module("ssvae", self)
N, C, W, H = y.shape
# this plate doesn't make sense for z if we think of it as 2D
with pyro.plate("N", size=N, dim=-3):
# Sample latent variable
z_dim = torch.Size((N, self.z_dim))
p_z = Normal(x.new_zeros(z_dim), 1.0).to_event(1)
# sampling results in Nx1xN x self.z_dim
# but we want it to be N x self.z_dim
z = pyro.sample("z", p_z)
with pyro.plate("pixel-c", size=C, dim=-2):
with pyro.plate("pixel-y", size=W, dim=-1):
# Sample class labels per pixel
y_size = torch.Size((N, C, W, H))
pi_y = 0.01 * x.new_ones(y_size)
p_y = Bernoulli(pi_y).to_event(3)
y_sample = pyro.sample("y", p_y, obs=y)
# Sample the observation
pi_x = self.decode_x(z, y_sample)
p_x = Bernoulli(pi_x).to_event(3)
return pyro.sample("x", p_x, obs=x)