For reference, here is what my model looks like:
def model(self, L, X, pheno):
    """Generative model over padded time series and per-sequence phenotype labels.

    A latent ``z_t`` is drawn at every step from ``self.transition``, which also
    threads a hidden state ``h`` through time. Each ``z_t`` emits two observation
    groups:

    - the feature values ``X[t, :, :13]``, under a Student-t likelihood;
    - the indicators ``X[t, :, 13:25]``, under a Bernoulli likelihood.

    Steps with ``t >= L`` are masked out of the likelihood. The latent of each
    sequence's last valid step (``t == L - 1``) is used to predict the
    multi-label phenotype targets.

    :param L: Length of each Time Series in batch, B
    :param X: Time Series, T_max x B x D
    :param pheno: Binary Multi-label class targets, B x n_pheno (25 in the
        original setup; the plate size is now read from ``pheno``)
    """
    pyro.module('dmm', self, update_module_params=True)
    B = pheno.size(0)
    T_max = X.size(0)
    n_pheno = pheno.size(-1)
    df = pyro.param('df', torch.tensor(2.0, device=X.device),
                    constraint=constraints.positive)
    mu_0 = pyro.param('mu_0', torch.zeros(1, self.latent_size,
                                          device=X.device))
    # Despite its name, this parameter lives in log space: the prior scale is
    # sigma_0.exp(), so the zero init corresponds to unit scale.
    sigma_0 = pyro.param('sigma_0', torch.zeros(1, self.latent_size,
                                                device=X.device))
    mu_0 = mu_0.expand(B, -1)
    sigma_0 = sigma_0.expand(B, -1)
    z_pheno = torch.zeros(B, self.latent_size, device=X.device)
    h_prev = self.h_0.expand(B, -1).to(X.device)
    # Steps t < L are observed, so each sequence's last valid step is L - 1.
    # (The previous `t == L` picked the masked step *after* the sequence end and
    # never fired for sequences with L == T_max, leaving their z_pheno at zero.)
    last_t = L.to(X.device) - 1
    with pyro.plate('X_plate', B, device=X.device):
        z_prev = pyro.sample('Z_0', dist.Normal(mu_0, sigma_0.exp()).to_event(1))
        for t in range(T_max):
            with poutine.mask(mask=(t < L)):
                h_t, z_mu, z_log_var = self.transition(z_prev, h_prev)
                # NOTE(review): exp(log_var) is passed as the Normal *scale*
                # (std), not exp(0.5 * log_var). That is harmless if the network
                # simply learns a log-std, but the name suggests a variance;
                # confirm the intended parameterization (same for x_log_var).
                z_dist = dist.Normal(z_mu, z_log_var.exp()).to_event(1)
                z_t = pyro.sample('Z_{}'.format(t + 1), z_dist)
                x_mu, x_log_var = self.emitter(z_t)
                x_dist = dist.StudentT(df, x_mu, x_log_var.exp()).to_event(1)
                pyro.sample('obs_X_{}'.format(t + 1), x_dist, obs=X[t, :, :13])
                missingness_p = self.missing(z_t)
                m_dist = dist.Bernoulli(missingness_p).to_event(1)
                pyro.sample('obs_I_{}'.format(t + 1), m_dist,
                            obs=X[t, :, 13:13 + 12])
            # Keep z_t for the sequences that end at this step. This is an
            # out-of-place select that broadcasts over the latent dimension,
            # rather than an in-place boolean-index write.
            is_last = (last_t == t).unsqueeze(-1)
            z_pheno = torch.where(is_last, z_t, z_pheno)
            h_prev = h_t
            z_prev = z_t
    # Explicit dims make the batch shape (B, n_pheno), matching `pheno`,
    # regardless of surrounding plate allocation.
    with pyro.plate('P_plate', n_pheno, dim=-1, device=pheno.device):
        with pyro.plate('PX_plate', B, dim=-2, device=X.device):
            pheno_p = self.pheno(z_pheno)
            pheno_dist = dist.Bernoulli(pheno_p)
            pyro.sample('obs_P', pheno_dist, obs=pheno)