# SSVAE image segmentation

Hello!

I implemented the SSVAE with help from the nice tutorial, and would like to extend it to image segmentation. In practice this means I need a `y` for each pixel, giving an `N x C x H x W` tensor instead of an `N x num_classes` tensor, where `N` is the batch size, `C` the number of channels, `H` the image height and `W` the image width. However, I don't need `z` to be a four-dimensional tensor, which makes using plate notation a bit tricky: by referring to the `-4` dimension as the batch dimension, the dimensionality of `z` is expanded to fit (at least I assume this is what is happening under the hood). Please see the model and guide code below, and the comments in it, for a more concrete description of the problem. My question is this: how should I treat the different dimensions of `z` and `y` when `y` is a segmentation mask?

Thanks for the great library and for your assistance!

/Sebastian

```python
from typing import Optional

import pyro
import torch
from pyro.distributions import Bernoulli, Normal
from torch import Tensor


def guide(self, x: Tensor, y: Optional[Tensor] = None) -> None:
    N, C, W, H = x.shape  # same as y.shape
    with pyro.plate("N", size=N, dim=-3):
        with pyro.plate("pixel-c", size=C, dim=-2):
            with pyro.plate("pixel-y", size=W, dim=-1):
                if y is None:
                    pi = self.encode_y(x)
                    p_y = Bernoulli(pi).to_event(3)
                    y = pyro.sample("y", p_y)

        # both loc and scale are N x self.z_dim
        loc, scale = self.encode_z(x, y)
        p_z = Normal(loc, scale).to_event(1)

        # sampling results in Nx1xN x self.z_dim
        # but we want it to be N x self.z_dim
        z = pyro.sample("z", p_z)
        # assert False, z.shape

def model(self, x: Tensor, y: Optional[Tensor] = None) -> Tensor:
    pyro.module("ssvae", self)
    N, C, W, H = x.shape  # read shapes from x, since y is None for unlabeled data
    # this plate doesn't make sense for z if we think of it as 2D
    with pyro.plate("N", size=N, dim=-3):
        # Sample latent variable
        z_dim = torch.Size((N, self.z_dim))
        p_z = Normal(x.new_zeros(z_dim), 1.0).to_event(1)
        # sampling results in Nx1xN x self.z_dim
        # but we want it to be N x self.z_dim
        z = pyro.sample("z", p_z)
        with pyro.plate("pixel-c", size=C, dim=-2):
            with pyro.plate("pixel-y", size=W, dim=-1):
                # Sample class labels per pixel
                y_size = torch.Size((N, C, W, H))
                pi_y = 0.01 * x.new_ones(y_size)
                p_y = Bernoulli(pi_y).to_event(3)
                y_sample = pyro.sample("y", p_y, obs=y)

        # Sample the observation
        pi_x = self.decode_x(z, y_sample)
        p_x = Bernoulli(pi_x).to_event(3)
        return pyro.sample("x", p_x, obs=x)
```

I think the simplest thing to do here would be to delete your `pixel-` plates and treat `C`, `W` and `H` as event dimensions, which would let you set `dim=-1` for `N`. The pixel plates aren't really buying you anything: you aren't subsampling along them, and in the unlabeled case they can't give you much ELBO variance reduction, because the observations `x` depend jointly on all the pixel labels `y`.

Thank you for your response!

Great! I tried that initially, but found a post here on the forum suggesting the pixel plates, so I changed direction. So the `dim` argument counts from the first independent (batch) dimension, i.e. from the right of the batch shape, without counting the event dimensions?

Yes, that’s right. See our tutorials on tensor shapes and enumeration for more background.

Thank you very much for your help! I know I'm not the first one to be pointed to the tensor shapes docs, but the community really appreciates you taking the time to answer these kinds of questions.

Thanks again, and have a nice weekend!
