import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

# model() and guide() below are methods of my custom DMM class
def model(self, mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lens, annealing_factor=1.0):
    T = mini_batch.size(1)
    pyro.module('dmm', self)
    # initial latent state, broadcast to the batch size
    z_prev = self.z_0.expand(mini_batch.size(0), self.z_0.size(0))
    with pyro.plate("z_minibatch", len(mini_batch)):
        for t in pyro.markov(range(1, T + 1)):
            # parameters of the transition distribution p(z_t | z_{t-1})
            z_loc, z_scale = self.transitor(z_prev)
            # anneal the KL contribution of the latent sample
            with poutine.scale(scale=annealing_factor):
                z_dist = dist.Normal(
                    loc=z_loc,
                    scale=z_scale
                ).mask(
                    mini_batch_mask[:, t-1:t]
                ).to_event(1)
                z_t = pyro.sample(name=f'z_{t}', fn=z_dist)
            # compute the emission -> probability of win/loss
            emission_t = self.emitter(z_t)
            # sample the observation, masked to the valid time steps
            pyro.sample(
                name=f'obs_x_{t}',
                fn=dist.Bernoulli(probs=emission_t)
                    .mask(mini_batch_mask[:, t-1:t])
                    .to_event(1),
                obs=mini_batch[:, t-1, 0]
            )
            z_prev = z_t
def guide(self, mini_batch, mini_batch_reversed, mini_batch_mask,
          mini_batch_seq_lens, annealing_factor=1.0):
    T = mini_batch.size(1)
    pyro.module('dmm', self)
    # run the RNN over the time-reversed input so q(z_t) conditions on the future
    h_0_contig = self.h_0.expand(
        1, mini_batch.size(0), self.rnn.hidden_size
    ).contiguous().to(mini_batch.device)
    rnn_output, _ = self.rnn(mini_batch_reversed, h_0_contig)
    # undo the reversal/padding so index t-1 lines up with time step t
    rnn_output = pad_and_reverse(rnn_output, mini_batch_seq_lens)
    rnn_output = self.layer_norm(rnn_output)
    # initial latent state of the variational distribution
    z_prev = self.z_q_0.expand(mini_batch.size(0), self.z_q_0.size(0))
    with pyro.plate("z_minibatch", len(mini_batch)):
        for t in pyro.markov(range(1, T + 1)):
            # combine the previous latent with the RNN hidden state for time t
            z_loc, z_scale = self.combiner(z_prev, rnn_output[:, t-1, :])
            with poutine.scale(scale=annealing_factor):
                z_dist = dist.Normal(
                    loc=z_loc,
                    scale=z_scale
                ).mask(
                    mini_batch_mask[:, t-1:t]
                ).to_event(1)
                z_t = pyro.sample(name=f'z_{t}', fn=z_dist)
            z_prev = z_t
This is the core of my model, a custom deep Markov model (DMM). I am using pyro.optim.ClippedAdam and have tried clip_norm = 0.00001, 1.0, 10, and a number of values in between, but the gradient norms stay the same regardless and don't appear to change when I change clip_norm (see the gradient-norm check sketched at the end of this post). Below are my model initialization and the config.yaml I am using.
from torch import device as torch_device
from pyro.infer import SVI, TraceEnum_ELBO, config_enumerate
from pyro.optim import ClippedAdam

device = torch_device('cpu')
# instantiate the dmm
dmm = DMM(
    input_dim=config['model_params']['input_dim'],
    latent_dim=config['model_params']['latent_dim'],
    hidden_dim=config['model_params']['hidden_dim'],
    emission_dim=config['model_params']['emission_dim'],
    transition_dim=config['model_params']['transition_dim'],
    rnn_dim=config['model_params']['rnn_dim'],
    rnn_dropout_rate=config['model_params']['rnn_dropout_rate'],
).to(device)
# ClippedAdam reads clip_norm (along with lr, betas, lrd, weight_decay)
# from the optimization_params dict
adam = ClippedAdam(config['optimization_params'])
elbo = TraceEnum_ELBO()
dmm_guide = config_enumerate(
    dmm.guide,
    default="parallel",
    num_samples=config['evaluation_params']['tmc_num_samples'],
    expand=False
)
svi = SVI(dmm.model, dmm_guide, adam, loss=elbo)
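For reference, here is a minimal sketch of how annealing_factor can be computed per mini-batch and passed through svi.step, following the KL-annealing schedule from the Pyro DMM example (train_loader and the batch names are placeholders for my data pipeline):

tp = config['training_params']

def get_annealing_factor(epoch, which_batch, num_batches):
    # ramp the KL weight linearly from minimum_annealing_factor up to 1.0
    # over the first annealing_epochs epochs
    if tp['annealing_epochs'] > 0 and epoch < tp['annealing_epochs']:
        min_af = tp['minimum_annealing_factor']
        return min_af + (1.0 - min_af) * (
            float(which_batch + epoch * num_batches + 1)
            / float(tp['annealing_epochs'] * num_batches)
        )
    return 1.0

for epoch in range(tp['num_epochs']):
    for i, (mb, mb_reversed, mb_mask, mb_seq_lens) in enumerate(train_loader):
        af = get_annealing_factor(epoch, i, len(train_loader))
        loss = svi.step(mb, mb_reversed, mb_mask, mb_seq_lens, af)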
# config.yaml
model_params:
  input_dim: 51
  latent_dim: 64
  hidden_dim: 16
  emission_dim: 1
  transition_dim: 128
  rnn_dim: 256
  rnn_dropout_rate: 0.0
training_params:
  num_epochs: 10
  mini_batch_size: 128
  annealing_epochs: 100
  minimum_annealing_factor: 0.2
  checkpoint_freq: 0
  test_frequency: 1
  cuda: True
optimization_params:
  lr: 0.00003
  betas: [0.96, 0.999]
  clip_norm: 1.0
  lrd: 0.99996
  weight_decay: 1.0
evaluation_params:
  tmc_num_samples: 10
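And here is the gradient-norm check mentioned above: a rough sketch that populates gradients without taking an optimizer step, then compares each parameter's raw gradient norm against the norm after the same per-parameter clipping that ClippedAdam applies inside its step(). It assumes elbo, dmm, dmm_guide, and one prepared mini-batch are in scope:

import torch
import pyro

# compute the ELBO gradients but do not step the optimizer
elbo.loss_and_grads(dmm.model, dmm_guide, mini_batch, mini_batch_reversed,
                    mini_batch_mask, mini_batch_seq_lens)

clip_norm = config['optimization_params']['clip_norm']
for name, param in pyro.get_param_store().named_parameters():
    if param.grad is None:
        continue
    raw = float(param.grad.norm())
    # ClippedAdam clips each parameter's gradient in place before updating;
    # replicate that here to see the post-clip norm
    torch.nn.utils.clip_grad_norm_(param, max_norm=clip_norm)
    print(f"{name}: raw {raw:.4g} -> clipped {float(param.grad.norm()):.4g}")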