Hi there,
I’m playing with the ProdLDA tutorial (pyro 1.7.0, torch 1.10.0) and I’m having real trouble implementing an extension. The goal is to put a prior with zero covariance between the beta weights in the model, so I was thinking this could be done by sampling the weights from a multivariate normal distribution with a diagonal covariance matrix.
I thought that adding
self.beta = PyroModule[nn.Linear](num_topics, vocab_size)
self.beta.weight = PyroSample(pyro.distributions.MultivariateNormal(
    torch.zeros(num_topics),
    covariance_matrix=torch.eye(num_topics))
    .expand((vocab_size,)).to_event(1))
to the Decoder class would accomplish this. I also tried moving the weight assignment into the model, under the plate, in case that would help, but I either can’t get the shapes right, or, with the code as shown, I get this error:
RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D
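For what it’s worth, the prior on its own seems to produce the shape I’d expect for an nn.Linear weight when I check it in isolation (the sizes below are made up, purely for the shape check):

import torch
import pyro.distributions as dist

num_topics, vocab_size = 20, 1000  # made-up sizes, only for this check
prior = dist.MultivariateNormal(
    torch.zeros(num_topics),
    covariance_matrix=torch.eye(num_topics),
).expand((vocab_size,)).to_event(1)
print(prior.batch_shape, prior.event_shape)  # torch.Size([]) torch.Size([1000, 20])
print(prior.sample().shape)  # torch.Size([1000, 20]), same as nn.Linear(num_topics, vocab_size).weight.shape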
Thanks for your thoughts!
mt
Full model below:
import torch
import torch.nn as nn
import torch.nn.functional as F

import pyro
import pyro.distributions as dist
from pyro.nn import PyroSample
from pyro.nn import PyroModule
class Encoder(nn.Module):
    # Base class for the encoder net, used in the guide
    def __init__(self, vocab_size, num_topics, hidden, dropout):
        super().__init__()
        self.drop = nn.Dropout(dropout)  # to avoid component collapse
        self.fc1 = nn.Linear(vocab_size, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fcmu = nn.Linear(hidden, num_topics)
        self.fclv = nn.Linear(hidden, num_topics)
        # NB: here we set `affine=False` to reduce the number of learning parameters
        # See https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html
        # for the effect of this flag in BatchNorm1d
        self.bnmu = nn.BatchNorm1d(num_topics, affine=False)  # to avoid component collapse
        self.bnlv = nn.BatchNorm1d(num_topics, affine=False)  # to avoid component collapse

    def forward(self, inputs):
        h = F.softplus(self.fc1(inputs))
        h = F.softplus(self.fc2(h))
        h = self.drop(h)
        # μ and Σ are the outputs
        logtheta_loc = self.bnmu(self.fcmu(h))
        logtheta_logvar = self.bnlv(self.fclv(h))
        logtheta_scale = (0.5 * logtheta_logvar).exp()  # Enforces positivity
        return logtheta_loc, logtheta_scale
"""
class LinearMVN(nn.Linear, PyroModule):
def __init__(self, in_size, out_size,bias=False):
super().__init__(in_size, out_size,bias)
self.weight = PyroSample(lambda self: pyro.distributions.MultivariateNormal(torch.zeros((in_size)), covariance_matrix=torch.eye(in_size)).expand((out_size,)).to_event(1))
def forward(self, input_):
return(self.bias + input_ @ self.weight)"""
class Decoder(nn.Module):
    # Base class for the decoder net, used in the model
    def __init__(self, vocab_size, num_topics, dropout):
        super().__init__()
        # Replace beta with a layer whose weights get an MVN prior (diagonal covariance)
        # self.beta = nn.Linear(num_topics, vocab_size, bias=False)
        # self.beta = LinearMVN(num_topics, vocab_size, bias=False)
        self.beta = PyroModule[nn.Linear](num_topics, vocab_size)
        # a draw from this prior has shape (vocab_size, num_topics),
        # i.e. the same shape as nn.Linear(num_topics, vocab_size).weight
        self.beta.weight = PyroSample(pyro.distributions.MultivariateNormal(
            torch.zeros(num_topics),
            covariance_matrix=torch.eye(num_topics))
            .expand((vocab_size,)).to_event(1))
        self.bn = nn.BatchNorm1d(vocab_size, affine=False)
        self.drop = nn.Dropout(dropout)

    def forward(self, inputs):
        inputs = self.drop(inputs)
        # the output is σ(βθ)
        return F.softmax(self.bn(self.beta(inputs)), dim=1)
class ProdLDA(nn.Module):
    def __init__(self, vocab_size, num_topics, hidden, dropout):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_topics = num_topics
        self.encoder = Encoder(vocab_size, num_topics, hidden, dropout)
        self.decoder = Decoder(vocab_size, num_topics, dropout)

    def model(self, docs):
        pyro.module("decoder", self.decoder)
        with pyro.plate("documents", docs.shape[0]):
            # Dirichlet prior 𝑝(𝜃|𝛼) is replaced by a logistic-normal distribution
            logtheta_loc = docs.new_zeros((docs.shape[0], self.num_topics))
            logtheta_scale = docs.new_ones((docs.shape[0], self.num_topics))
            logtheta = pyro.sample(
                "logtheta", dist.Normal(logtheta_loc, logtheta_scale).to_event(1))
            theta = F.softmax(logtheta, -1)
            # conditional distribution of 𝑤𝑛 is defined as
            # 𝑤𝑛|𝛽,𝜃 ~ Categorical(𝜎(𝛽𝜃))
            # This gives the same error:
            # self.decoder.beta.weight = PyroSample(pyro.distributions.MultivariateNormal(
            #     torch.zeros(self.num_topics),
            #     covariance_matrix=torch.eye(self.num_topics))
            #     .expand((self.vocab_size,)))
            # print(self.decoder.beta.weight.shape)
            count_param = self.decoder(theta)
            # Currently, PyTorch Multinomial requires `total_count` to be homogeneous.
            # Because the numbers of words across documents can vary,
            # we will use the maximum count across documents here.
            # This does not affect the result because Multinomial.log_prob does
            # not require `total_count` to evaluate the log probability.
            total_count = int(docs.sum(-1).max())
            pyro.sample(
                'obs',
                dist.Multinomial(total_count, count_param),
                obs=docs
            )

    def guide(self, docs):
        pyro.module("encoder", self.encoder)
        with pyro.plate("documents", docs.shape[0]):
            # Dirichlet prior 𝑝(𝜃|𝛼) is replaced by a logistic-normal distribution,
            # where μ and Σ are the encoder network outputs
            logtheta_loc, logtheta_scale = self.encoder(docs)
            logtheta = pyro.sample(
                "logtheta", dist.Normal(logtheta_loc, logtheta_scale).to_event(1))

    def beta(self):
        # beta matrix elements are the weights of the FC layer on the decoder
        return self.decoder.beta.weight.cpu().detach().T
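In case it’s useful, here is a minimal way I can trigger the same error outside the full training loop (the sizes and hyperparameters below are made up, purely to reproduce it):

import torch

# made-up sizes/hyperparameters, only to reproduce the error
docs = torch.randint(0, 5, (8, 1000)).float()  # 8 fake documents over a 1000-word vocabulary
prodlda = ProdLDA(vocab_size=1000, num_topics=20, hidden=100, dropout=0.2)
prodlda.model(docs)  # this is where I hit the RuntimeError above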