SSVAE image segmentation

Hello!

I implemented the SSVAE with help from the nice tutorial and would like to extend it to image segmentation. In practice this means I need a y for each pixel, i.e. an N x C x H x W tensor instead of an N x num_classes tensor, where N is the batch size, C the number of channels, H the image height and W the image width. However, I don't need z to be a four-dimensional tensor, which makes using plate notation a bit tricky: by referring to the -4 dimension as the batch dimension, the dimensionality of z is expanded to fit (at least I assume this is what is happening under the hood). Please see the model and guide code below and the comments for a more concrete problem description. My question is this: how should I treat the different dimensions of z and y when y is a segmentation mask?

Thanks for the great library and for your assistance!

/Sebastian

```python
def guide(self, x: Tensor, y: Optional[Tensor] = None) -> None:
    N, C, W, H = x.shape  # same as y.shape
    with pyro.plate("N", size=N, dim=-3):
        with pyro.plate("pixel-c", size=C, dim=-2):
            with pyro.plate("pixel-y", size=W, dim=-1):
                if y is None:
                    pi = self.encode_y(x)
                    p_y = Bernoulli(pi).to_event(3)
                    y = pyro.sample("y", p_y)

        # both loc and scale are N x self.z_dim
        loc, scale = self.encode_z(x, y)
        p_z = Normal(loc, scale).to_event(1)

        # sampling results in Nx1xN x self.z_dim
        # but we want it to be N x self.z_dim
        z = pyro.sample("z", p_z)
        # assert False, z.shape

def model(self, x: Tensor, y: Optional[Tensor] = None) -> Tensor:
    pyro.module("ssvae", self)
    N, C, W, H = y.shape
    # this plate doesn't make sense for z if we think of it as 2D
    with pyro.plate("N", size=N, dim=-3):
        # Sample latent variable
        z_dim = torch.Size((N, self.z_dim))
        p_z = Normal(x.new_zeros(z_dim), 1.0).to_event(1)
        # sampling results in Nx1xN x self.z_dim
        # but we want it to be N x self.z_dim
        z = pyro.sample("z", p_z)
        with pyro.plate("pixel-c", size=C, dim=-2):
            with pyro.plate("pixel-y", size=W, dim=-1):

                # Sample class labels per pixel
                y_size = torch.Size((N, C, W, H))
                pi_y = 0.01 * x.new_ones(y_size)
                p_y = Bernoulli(pi_y).to_event(3)
                y_sample = pyro.sample("y", p_y, obs=y)

        # Sample the observation
        pi_x = self.decode_x(z, y_sample)
        p_x = Bernoulli(pi_x).to_event(3)
        return pyro.sample("x", p_x, obs=x)
```

I think the simplest thing to do here would be to delete your pixel- plates and treat C, W and H as event dimensions, which lets you set dim=-1 for the N plate. The pixel plates aren't really buying you anything: you aren't subsampling along them, and you can't get much ELBO variance reduction out of them in the unlabeled case, since the observations x depend on all pixel labels y jointly.

Thank you for your response!

Great! I tried that initially but found a post here on the forum suggesting the pixel plates, so I changed direction. So the dim argument counts from the first independent (batch) dimension, i.e. the event dimensions are not counted?

Yes, that’s right. See our tutorials on tensor shapes and enumeration for more background.

Thank you very much for your help! I know I'm not the first one to be pointed to the tensor shapes docs, and the community really appreciates you taking the time to answer these kinds of questions.

Thanks again, and have a nice weekend!
