Hi @eb8680_2, sorry for bringing this up again.
I tried developing a model where, during the training phase, I pass a vector of observations and an obs_mask, but I got disappointing results.
What I noticed is that when I include the obs_mask argument, for some reason training becomes much harder, even when obs_mask contains only True values.
Changing

    with pyro.plate("data", x.shape[0]):
        obs = pyro.sample("obs", dist.Normal(mean_output, sigma_output).to_event(1), obs=y)

to

    with pyro.plate("data", x.shape[0]):
        obs = pyro.sample("obs", dist.Normal(mean_output, sigma_output).to_event(1), obs=y, obs_mask=obs_mask)
slows down training and gives worse results, even with

    obs_mask = torch.bernoulli(torch.ones(N_batch)) > 0  # bool tensor full of True values
Furthermore, it looks like the loss is computed in a different way: I sometimes get strange loss values (the loss oscillates between -5e21 and -5e24, whereas before it decreased from roughly +200 to -1300), and I have to pick lower learning rates to avoid divergence.
Any clue on why this happens?
Here's my code, in case it's useful for understanding the problem:
    import torch
    import torch.nn as nn
    import pyro
    import pyro.distributions as dist
    from pyro.nn import PyroModule

    class BayesianRegression(PyroModule):
        def __init__(self, in_features, out_features, latent_features, hidden_features=4):
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.latent_features = latent_features
            self.hidden_features = hidden_features
            self.linear_input_hidden = nn.Linear(in_features, hidden_features)
            self.linear_hidden_latent = nn.Linear(hidden_features, latent_features)
            self.linear_latent_hidden = nn.Linear(latent_features, hidden_features)
            self.linear_hidden_hidden2 = nn.Linear(hidden_features, hidden_features * 2)
            self.linear_hidden2_hidden3 = nn.Linear(hidden_features * 2, hidden_features)
            self.linear_hidden_output = nn.Linear(hidden_features, out_features)
            self.linear_input_sigma = nn.Linear(latent_features, out_features)
            self.relu = nn.ReLU()
            self.softplus = nn.Softplus()

        def forward(self, x, y=None, latent_obs=None, obs_mask=None):
            sigma_latent = pyro.sample("sigma_latent", dist.Uniform(0., 1.))
            # input -> hidden -> latent mean
            mean_latent_ = self.relu(self.linear_input_hidden(x))
            mean_latent = self.linear_hidden_latent(mean_latent_)
            with pyro.plate("latent_data", x.shape[0]):
                latent = pyro.sample("latent_obs", dist.Normal(mean_latent, sigma_latent).to_event(1), obs=latent_obs)
            # latent -> output sigma and output mean
            sigma_output = self.softplus(self.linear_input_sigma(latent))
            mean_output_ = self.relu(self.linear_latent_hidden(latent))
            mean_output_ = self.relu(self.linear_hidden_hidden2(mean_output_))
            mean_output_ = self.relu(self.linear_hidden2_hidden3(mean_output_))
            mean_output = self.linear_hidden_output(mean_output_)
            with pyro.plate("data", x.shape[0]):
                obs = pyro.sample("obs", dist.Normal(mean_output, sigma_output).to_event(1), obs=y)  # , obs_mask=obs_mask)
Basically, I have a NN that computes some latent variables from the input (self.linear_input_hidden, self.linear_hidden_latent).
The latent variables are also observed (latent = pyro.sample(…, obs=latent_obs)).
The latents then pass through two more NNs, one for the output mean (self.linear_latent_hidden … self.linear_hidden_output) and one for the output sigma (self.linear_input_sigma).