Hello,
I’m working on a dynamical variational autoencoder (a VRNN model) with pyro, which closely follows the deep markov model example. I’m wondering how to use TraceTMC_ELBO
correctly. In the dmm example, the guide is simply wrapped with config_enumerate
. But it raises an error when I do the same in my code. So, did I miss something?
def model(self, x, annealing_factor=1.0):
    """Generative model p(x_{1:T}, z_{1:T}) for the VRNN.

    Args:
        x: observed batch; assumed shape (batch, seq_len, input_dim) — TODO confirm.
        annealing_factor: KL-annealing weight applied to the latent sample sites.
    """
    b, l, _ = x.shape
    pyro.module("vrnn", self)
    # Initial deterministic recurrent state, one row per batch element.
    h = x.new_zeros([b, self.hidden_dim])
    feature_x = self.feature_extra_x(x)
    with pyro.plate("data", b):
        # pyro.markov declares first-order temporal dependence so that
        # enumeration-based losses (TraceEnum_ELBO / TraceTMC_ELBO) can
        # recycle enumeration dims instead of allocating a new dim per
        # time step. It is a no-op under plain Trace_ELBO, so ordinary
        # training behavior is unchanged.
        for t in pyro.markov(range(1, l + 1)):
            z_loc, z_scale = self.theta_norm_1(h)
            # Scale only the latent site's ELBO contribution by the
            # annealing factor.
            with poutine.scale(None, annealing_factor):
                z_t = pyro.sample(f"z_{t}", dist.Normal(z_loc, z_scale).to_event(1))
            z_t = self.feature_extra_z(z_t)
            # Concatenate on the LAST dim, not dim=1: TMC/enumeration
            # prepends extra batch dims to z_t, under which dim=1 would
            # index the wrong axis. For the plain 2-D case this is identical.
            z_with_h = torch.cat([z_t, h], dim=-1)
            x_loc, x_scale = self.theta_norm_2(z_with_h)
            pyro.sample(f"obs_{t}", dist.Normal(x_loc, x_scale).to_event(1), obs=x[:, t - 1, :])
            # NOTE(review): feeding the sampled z_t back into the RNN cell is
            # the likely source of the TraceTMC_ELBO error — with K posterior
            # samples z_t gains a leading enum dim that torch.cat/RNNCell do
            # not broadcast against h. The networks would need to broadcast
            # over arbitrary leading batch dims for TMC to work — confirm.
            x_with_z = torch.cat([feature_x[:, t - 1, :], z_t], dim=-1)
            h = self.rnn(x_with_z, h)
def guide(self, x, annealing_factor=1.0):
    """Variational guide q(z_{1:T} | x_{1:T}) for the VRNN.

    Mirrors the model's temporal structure; the sampled z_t is fed back
    into the recurrent state. For TraceTMC_ELBO the sample sites must
    additionally carry infer={"enumerate": "parallel", "num_samples": K},
    e.g. via config_enumerate(guide, "parallel", num_samples=K) —
    presumably what the DMM example's wrapping provides; verify.

    Args:
        x: observed batch; assumed shape (batch, seq_len, input_dim) — TODO confirm.
        annealing_factor: KL-annealing weight applied to the latent sample sites.
    """
    b, l, _ = x.shape
    h = x.new_zeros([b, self.hidden_dim])
    feature_x = self.feature_extra_x(x)
    with pyro.plate("data", b):
        # First-order temporal dependence: lets enumeration-based ELBOs
        # recycle dims across time steps. No-op under plain Trace_ELBO.
        for t in pyro.markov(range(1, l + 1)):
            xt = feature_x[:, t - 1, :]
            h_with_x = torch.cat([h, xt], dim=-1)
            z_loc, z_scale = self.phi_norm(h_with_x)
            with poutine.scale(None, annealing_factor):
                z_t = pyro.sample(f"z_{t}", dist.Normal(z_loc, z_scale).to_event(1))
            z_t = self.feature_extra_z(z_t)
            # Last-dim concat (identical to dim=1 for 2-D tensors) so the
            # code stays correct when TMC adds leading sample dims to z_t.
            # NOTE(review): as in the model, the RNN cell consuming the
            # sampled z_t will not broadcast over those extra dims — this
            # state feedback is the probable cause of the reported error.
            x_with_z = torch.cat([xt, z_t], dim=-1)
            h = self.rnn(x_with_z, h)