Extension of LDA tutorial for uncorrelated topic betas using MVN

Hi there,
I’m playing with the ProdLDA tutorial and I’m having some real trouble trying to implement an extension (pyro 1.7.0, torch 1.10.0). The goal is to try to put a prior of zero covariance over the beta weights in the model, so I was thinking this could be done by sampling the weights from a multivariate normal distribution with a diagonal covariance matrix.

I thought that adding

self.beta = PyroModule[nn.Linear](num_topics, vocab_size)
self.beta.weight=PyroSample(pyro.distributions.MultivariateNormal(
                                torch.zeros((num_topics)), 
                                covariance_matrix=torch.eye(num_topics))
                            .expand((vocab_size,)).to_event(1))

into the Decoder class would accomplish this. I also tried moving the weight assignment into the model under the topics plate would help, but I either can’t get the right shape, or alternatively I get this error with the code as shown:

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D.

Thanks for your thoughts!
mt

Full model below:

from pyro.nn import PyroSample
from pyro.nn import PyroModule

    class Encoder(nn.Module):
        # Base class for the encoder net, used in the guide
        def __init__(self, vocab_size, num_topics, hidden, dropout):
            super().__init__()
            self.drop = nn.Dropout(dropout)  # to avoid component collapse
            self.fc1 = nn.Linear(vocab_size, hidden)
            self.fc2 = nn.Linear(hidden, hidden)
            self.fcmu = nn.Linear(hidden, num_topics)
            self.fclv = nn.Linear(hidden, num_topics)
            # NB: here we set `affine=False` to reduce the number of learning parameters
            # See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
            # for the effect of this flag in BatchNorm1d
            self.bnmu = nn.BatchNorm1d(num_topics, affine=False)  # to avoid component collapse
            self.bnlv = nn.BatchNorm1d(num_topics, affine=False)  # to avoid component collapse

        def forward(self, inputs):
            h = F.softplus(self.fc1(inputs))
            h = F.softplus(self.fc2(h))
            h = self.drop(h)
            # μ and Σ are the outputs
            logtheta_loc = self.bnmu(self.fcmu(h))
            logtheta_logvar = self.bnlv(self.fclv(h))
            logtheta_scale = (0.5 * logtheta_logvar).exp()  # Enforces positivity
            return logtheta_loc, logtheta_scale

    """
    class LinearMVN(nn.Linear, PyroModule):
        def __init__(self, in_size, out_size,bias=False):
            super().__init__(in_size, out_size,bias)
            self.weight = PyroSample(lambda self: pyro.distributions.MultivariateNormal(torch.zeros((in_size)), covariance_matrix=torch.eye(in_size)).expand((out_size,)).to_event(1))
            
        def forward(self, input_):
            return(self.bias + input_ @ self.weight)"""

    class Decoder(nn.Module):
        # Base class for the decoder net, used in the model
        def __init__(self, vocab_size, num_topics, dropout):
            super().__init__()
            #Replace beta with MVN
            #self.beta = nn.Linear(num_topics, vocab_size, bias=False)
            #self.beta = LinearMVN(num_topics, vocab_size, bias=False)
            self.beta = PyroModule[nn.Linear](num_topics, vocab_size)
            self.beta.weight=PyroSample(pyro.distributions.MultivariateNormal(
                                            torch.zeros((num_topics)), 
                                            covariance_matrix=torch.eye(num_topics))
                                        .expand((vocab_size,)).to_event(1))

            self.bn = nn.BatchNorm1d(vocab_size, affine=False)
            self.drop = nn.Dropout(dropout)

        def forward(self, inputs):
            inputs = self.drop(inputs)
            # the output is σ(βθ)
            return F.softmax(self.bn(self.beta(inputs)), dim=1)


    class ProdLDA(nn.Module):
        def __init__(self, vocab_size, num_topics, hidden, dropout):
            super().__init__()
            self.vocab_size = vocab_size
            self.num_topics = num_topics
            self.encoder = Encoder(vocab_size, num_topics, hidden, dropout)
            self.decoder = Decoder(vocab_size, num_topics, dropout)

        def model(self, docs):
            pyro.module("decoder", self.decoder)
            with pyro.plate("documents", docs.shape[0]):
                # Dirichlet prior 𝑝(𝜃|𝛼) is replaced by a logistic-normal distribution
                logtheta_loc = docs.new_zeros((docs.shape[0], self.num_topics))
                logtheta_scale = docs.new_ones((docs.shape[0], self.num_topics))
                logtheta = pyro.sample(
                    "logtheta", dist.Normal(logtheta_loc, logtheta_scale).to_event(1))
                theta = F.softmax(logtheta, -1)
                # conditional distribution of 𝑤𝑛 is defined as
                # 𝑤𝑛|𝛽,𝜃 ~ Categorical(𝜎(𝛽𝜃))
                #This gives the same error
                #self.decoder.beta.weight=PyroSample(pyro.distributions.MultivariateNormal(
                #                                        torch.zeros((self.num_topics)), 
                #                                        covariance_matrix=torch.eye(self.num_topics))
                #                                    .expand((self.vocab_size,)))
                #print(self.decoder.beta.weight.shape)
                count_param = self.decoder(theta)
                # Currently, PyTorch Multinomial requires `total_count` to be homogeneous.
                # Because the numbers of words across documents can vary,
                # we will use the maximum count accross documents here.
                # This does not affect the result because Multinomial.log_prob does
                # not require `total_count` to evaluate the log probability.
                total_count = int(docs.sum(-1).max())
                pyro.sample(
                    'obs',
                    dist.Multinomial(total_count, count_param),
                    obs=docs
                )

        def guide(self, docs):
            pyro.module("encoder", self.encoder)
            with pyro.plate("documents", docs.shape[0]):
                # Dirichlet prior 𝑝(𝜃|𝛼) is replaced by a logistic-normal distribution,
                # where μ and Σ are the encoder network outputs
                logtheta_loc, logtheta_scale = self.encoder(docs)
                logtheta = pyro.sample(
                    "logtheta", dist.Normal(logtheta_loc, logtheta_scale).to_event(1))
        
        def beta(self):
            # beta matrix elements are the weights of the FC layer on the decoder
            return self.decoder.beta.weight.cpu().detach().T

Hi @mtvector, torch.nn.Linear expects its weight tensor to be 2D. You may need to replace that module with a custom one that allows broadcasted weights.