I am trying to use the guide from the DMM example to generate a latent representation for each data point. The latent representation is then used for some downstream tasks. We observed that the performance of those tasks varied a lot depending on whether the latents were generated one sample at a time or in batches of size greater than one.
To demonstrate this, I used the original DMM example with the polyphonic dataset and generated two sets of latents for the evaluation data: one with batch_size=10, and one with the data passed one sequence at a time to the sample statement. I then computed the L2 distance between these two representations using torch.dist; the distance is non-zero, in the range of 700-800 after 100 epochs of training. To reduce the effect of the random number generator state, I set the manual seed just before each call to the function that computes the latents.
So, I added the following two functions:
def transform_batch(self, sequences, seq_lengths=None, mini_batch_size=10):
    """Sample guide latents for every sequence, one mini-batch at a time.

    The data is split into consecutive mini-batches of at most
    ``mini_batch_size`` sequences. Each mini-batch is prepared with
    ``poly.get_mini_batch`` and passed to ``self.get_latents``; the
    per-mini-batch results are zero-padded along the time axis to the
    longest sequence in the whole data set and stacked.

    Args:
        sequences: the full collection of sequences; ``len(sequences)`` is
            taken as the number of data points.
        seq_lengths: per-sequence lengths. Required despite the ``None``
            default, because it sizes the time axis of the result.
        mini_batch_size: maximum number of sequences per mini-batch.

    Returns:
        A float ``numpy`` array of shape
        ``(len(sequences), max(seq_lengths), self.z_dim)``. Time steps at or
        beyond a sequence's own length are zero, whatever the mini-batch
        size.

    Raises:
        ValueError: if ``seq_lengths`` is ``None`` or ``mini_batch_size`` is
            not positive.
    """
    if seq_lengths is None:
        raise ValueError("seq_lengths is required to size and pad the latents")
    if mini_batch_size < 1:
        raise ValueError("mini_batch_size must be a positive integer")

    n_data = len(sequences)
    max_len = int(np.max(seq_lengths)) if n_data else 0
    data_indices = np.arange(n_data)

    # Collect the padded chunks and concatenate once at the end instead of
    # re-copying the growing result with np.vstack on every iteration.
    padded_chunks = []
    for start in range(0, n_data, mini_batch_size):
        mini_batch_indices = data_indices[start:start + mini_batch_size]
        # grab a fully prepped mini-batch using the helper function in the data loader
        # NOTE(review): if poly.get_mini_batch reorders the sequences inside a
        # mini-batch (e.g. sorts them by length), the rows produced here follow
        # that order rather than the order of mini_batch_indices -- confirm.
        mini_batch, mini_batch_reversed, mini_batch_mask, mini_batch_seq_lengths = \
            poly.get_mini_batch(mini_batch_indices, sequences, seq_lengths)
        batch_size, t_max = mini_batch.shape[0], mini_batch.shape[1]

        z_temp = self.get_latents(mini_batch, mini_batch_reversed,
                                  mini_batch_mask, mini_batch_seq_lengths)
        # get_latents squeezes singleton axes (batch of one, a single time
        # step, ...). Reshaping to the known (time, batch, z_dim) layout
        # restores them unambiguously; guessing from the rank alone does not.
        z_temp = np.asarray(z_temp).reshape(t_max, batch_size, self.z_dim)
        z_temp = z_temp.transpose([1, 0, 2])

        z_padded = np.zeros((batch_size, max_len, self.z_dim), dtype=float)
        z_padded[:, :t_max, :] = z_temp
        # get_latents also draws values for padded time steps of the shorter
        # sequences in a mini-batch. Zero them so that padding looks the same
        # whether a sequence was processed alone or next to longer ones.
        for row, length in enumerate(mini_batch_seq_lengths):
            z_padded[row, int(length):, :] = 0.0
        padded_chunks.append(z_padded)

    if not padded_chunks:
        return np.zeros((0, max_len, self.z_dim), dtype=float)
    return np.concatenate(padded_chunks, axis=0)
# the guide q(z_{1:T} | x_{1:T}) (i.e. the variational distribution)
def get_latents(self, mini_batch, mini_batch_reversed, mini_batch_mask,
                mini_batch_seq_lengths, annealing_factor=1.0):
    """Draw one sample of the latents z_1..z_T from the guide for a mini-batch.

    The observations are pushed through ``self.rnn`` in reversed time order,
    and the latents are then sampled one time step at a time, each step
    conditioned on the previous sample and the RNN hidden state.

    The RNN is switched to evaluation mode for the duration of the call and
    its previous mode is restored afterwards, also if an exception is raised.

    Note that the latents are random draws (``pyro.sample``), not the means
    of the variational distribution, so repeated calls differ unless the
    random number generator is reset and used in exactly the same way.

    Args:
        mini_batch: observations; only its first two sizes (batch, time) are
            read here.
        mini_batch_reversed: input handed to ``self.rnn``.
        mini_batch_mask: mask indexed as ``[:, t - 1:t]`` per time step and
            applied to the sampling distribution.
        mini_batch_seq_lengths: lengths handed to ``poly.pad_and_reverse``.
        annealing_factor: scale applied around the sample statements.

    Returns:
        ``np.squeeze`` of the per-time-step samples stacked along a leading
        time axis, i.e. ``(T_max, batch, z_dim)`` with every size-1 axis
        removed (a mini-batch of one sequence yields ``(T_max, z_dim)``).
    """
    # this is the number of time steps we need to process in the mini-batch
    T_max = mini_batch.size(1)
    batch_size = mini_batch.size(0)

    # register all PyTorch (sub)modules with pyro
    pyro.module("dmm", self)

    # Remember the caller's mode so that it can be restored on exit instead
    # of unconditionally forcing the RNN into training mode.
    was_training = self.rnn.training
    self.rnn.eval()
    z_all = []
    try:
        # Every sample is detached before being returned, so no autograd
        # graph is needed while sampling.
        with torch.no_grad():
            # if on gpu we need the fully broadcast view of the rnn initial state
            # to be in contiguous gpu memory
            h_0_contig = self.h_0.expand(1, batch_size, self.rnn.hidden_size).contiguous()
            # push the observed x's through the rnn;
            # rnn_output contains the hidden state at each time step
            rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
            # reverse the time-ordering in the hidden state and un-pack it
            rnn_output = poly.pad_and_reverse(rnn_output, mini_batch_seq_lengths)
            # set z_prev = z_q_0 to setup the recursive conditioning in q(z_t |...)
            z_prev = self.z_q_0.expand(batch_size, self.z_q_0.size(0))

            # we enclose all the sample statements in the guide in a iarange.
            # this marks that each datapoint is conditionally independent of the others.
            with pyro.iarange("z_minibatch", len(mini_batch)):
                # sample the latents z one time step at a time
                for t in range(1, T_max + 1):
                    # the next two lines assemble the distribution q(z_t | z_{t-1}, x_{t:T})
                    z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t - 1, :])

                    # if we are using normalizing flows, we apply the sequence of
                    # transformations parameterized by self.iafs to the base
                    # distribution to yield the distribution used for q(z_t|...)
                    if len(self.iafs) > 0:
                        z_dist = TransformedDistribution(dist.Normal(z_loc, z_scale), self.iafs)
                    else:
                        z_dist = dist.Normal(z_loc, z_scale)
                    assert z_dist.event_shape == ()
                    assert z_dist.batch_shape == (len(mini_batch), self.z_q_0.size(0))

                    # sample z_t from the distribution z_dist; the mask is passed
                    # to the distribution but a value is still drawn for every
                    # row at every time step
                    with pyro.poutine.scale(scale=annealing_factor):
                        z_t = pyro.sample("z_%d" % t,
                                          z_dist.mask(mini_batch_mask[:, t - 1:t])
                                          .independent(1))
                    # the latent sampled at this time step will be conditioned
                    # upon in the next time step so keep track of it
                    z_prev = z_t
                    z_all.append(self.to_numpy(z_t.clone().detach()))
    finally:
        self.rnn.train(was_training)

    # Kept for backward compatibility: squeeze drops ALL singleton axes, so
    # callers must not rely on the rank alone to tell which axis was removed.
    return np.squeeze(np.array(z_all))
And the modified do_evaluation looks like this:
# helper function for doing evaluation
def do_evaluation():
    """Compute per-time-step validation/test losses and compare latents.

    Besides returning the normalised validation and test losses, this prints
    the L2 distance between the validation latents produced with mini-batches
    of 10 sequences and those produced one sequence at a time.

    Returns:
        The pair ``(val_nll, test_nll)``.
    """
    # evaluation mode for the RNN (i.e. drop-out off, if applicable)
    dmm.rnn.eval()

    val_loss = elbo.evaluate_loss(val_batch, val_batch_reversed,
                                  val_batch_mask, val_seq_lengths)
    val_nll = val_loss / np.sum(val_seq_lengths)

    test_loss = elbo.evaluate_loss(test_batch, test_batch_reversed,
                                   test_batch_mask, test_seq_lengths)
    test_nll = test_loss / np.sum(test_seq_lengths)

    # back to training mode (i.e. drop-out on, if applicable)
    dmm.rnn.train()

    # NOTE(review): re-seeding makes each call reproducible on its own, but
    # the two calls request random numbers in a different order (all
    # sequences of a mini-batch per time step vs. one sequence at a time), so
    # presumably the sampled latents are not expected to match element-wise
    # even with the same seed -- confirm before reading the distance as a bug.
    latents_per_setting = []
    for batch_size in (10, 1):
        torch.manual_seed(42)
        latents = dmm.transform_batch(val_data_sequences, val_seq_lengths,
                                      mini_batch_size=batch_size)
        latents_per_setting.append(torch.tensor(latents).type(torch.Tensor))

    z_batch_1, z_batch_2 = latents_per_setting
    distance = torch.dist(z_batch_1, z_batch_2, p=2)
    print("Distance between the two latent representations:{}".format(distance))

    return val_nll, test_nll
Does this seem to be a bug or am I missing something?