Hi,
I’m fairly new to Pyro and have been working on an HMM for protein structure prediction. When finished, it will have 4 observations per (discrete) timestep.
The observations are as follows:
One continuous observation containing two values, sampled from a custom distribution I have implemented (bivariate von Mises). It looks like tensor([[phi, psi]]), so it has shape (1, 2). The values are in radians.
and
Three discrete observations, each sampled from its own categorical distribution.
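For readers unfamiliar with the continuous observation, here is a minimal plain-Python sketch of an unnormalized bivariate von Mises log-density. It assumes the sine variant, which matches the parameter names (mu, nu, k1, k2, lam) used in the model, but the post does not show BVM.BivariateVonMises, so the actual implementation may differ. The function name bvm_sine_unnormalized_log_prob is hypothetical and the normalizing constant is deliberately left out.

```python
import math


def bvm_sine_unnormalized_log_prob(phi, psi, mu, nu, k1, k2, lam):
    """Unnormalized log-density of a sine-variant bivariate von Mises.

    Assumption: this is the sine variant; the distribution in the post
    may be implemented differently. All angles are in radians, and the
    normalizing constant is omitted.
    """
    d_phi = phi - mu
    d_psi = psi - nu
    return (
        k1 * math.cos(d_phi)
        + k2 * math.cos(d_psi)
        + lam * math.sin(d_phi) * math.sin(d_psi)
    )


# One (phi, psi) pair, as in a single row of the observation tensor.
phi, psi = 0.5, -1.0
value = bvm_sine_unnormalized_log_prob(phi, psi, mu=0.5, nu=-1.0, k1=70.0, k2=70.0, lam=0.3)
print(value)
```

Because the expression only depends on the angles through sines and cosines, shifting phi or psi by a full turn leaves the value unchanged, which is why the observations can be kept in radians without further wrapping for this formula.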
It works fine when I only have categorical observations, but I run into major problems when I add the continuous observation to the same model.
What am I doing wrong?
def model_0(seq_AA, seq_DSSP, seq_dihedral, lengths, args, batch_size=None, include_prior=True):
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = seq_AA.shape
    data_dim_AA = 20  # for amino acids
    data_dim_dihedral = 2
    with poutine.mask(mask=include_prior):
        probs_x = pyro.sample("probs_x",
                              dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1)
                                  .to_event(1))
        probs_y_AA = pyro.sample("probs_y_AA",
                                 dist.Beta(0.1, 0.9)
                                     .expand([args.hidden_dim, data_dim_AA])
                                     .to_event(2))
        probs_y_means = pyro.sample("probs_y_means",
                                    dist.VonMises(torch.tensor([0.]), torch.tensor(90.))
                                        .expand([args.hidden_dim, data_dim_dihedral])
                                        .to_event(2))
        probs_y_lam = pyro.sample("probs_y_lam",
                                  dist.Normal(torch.tensor([0.]), torch.tensor(1.))
                                      .expand([args.hidden_dim, 1])
                                      .to_event(2))
        probs_y_k1k2 = pyro.sample("probs_y_k1k2",
                                   dist.Gamma(torch.tensor([70.]), torch.tensor(1.))
                                       .expand([args.hidden_dim, 2])
                                       .to_event(2))
    for i in pyro.plate("sequences", len(lengths), batch_size):
        length = int(lengths[i].item())
        sequence_AA = seq_AA[i, :length]
        sequence_dihedral = seq_dihedral[i, :length]
        state_x = 0
        for t in pyro.markov(range(length)):
            state_x = pyro.sample("state_x_{}_{}".format(i, t),
                                  dist.Categorical(Vindex(probs_x)[..., state_x, :]),
                                  infer={"enumerate": "sequential"})
            pyro.sample("y_AA_{}_{}".format(i, t),
                        dist.Categorical(Vindex(probs_y_AA)[..., state_x.squeeze(-1), :]),
                        obs=sequence_AA[t])
            pyro.sample("y_dihedral_{}_{}".format(i, t),
                        BVM.BivariateVonMises(
                            mu=Vindex(probs_y_means)[..., state_x.squeeze(-1), 0],
                            nu=Vindex(probs_y_means)[..., state_x.squeeze(-1), 1],
                            k1=Vindex(probs_y_k1k2)[..., state_x.squeeze(-1), 0],
                            k2=Vindex(probs_y_k1k2)[..., state_x.squeeze(-1), 1],
                            lam=Vindex(probs_y_lam)[..., state_x.squeeze(-1), :]),
                        obs=(sequence_dihedral[t]).reshape([1, 2]))
Edit 1:
I am also getting the error “found vars in model but not in guide” even though I’m using enumeration, which confuses me a bit. The error refers to the samples from probs_x.
My code is taken from the hmm example.
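As context for that warning: the hmm example builds its guide by exposing only the sites whose name starts with "probs_" (an AutoDelta over a poutine.block of the model) and leaves the discrete states to enumeration. The guide used in this post is not shown, so the following is only a plain-Python sketch of that kind of name filter applied to the site names of model_0. The names expose_probs_sites and site_names are hypothetical and exist only for illustration; no Pyro call is made here.

```python
def expose_probs_sites(msg):
    """Name-based filter in the style of the hmm example's expose_fn.

    Assumption: the guide in this post uses the same "probs_" prefix
    rule as the example; the actual guide is not shown.
    """
    return msg["name"].startswith("probs_")


# Site names as they are written in model_0, for sequence i=0 and
# timesteps t=0 and t=1.
site_names = [
    "probs_x", "probs_y_AA", "probs_y_means", "probs_y_lam", "probs_y_k1k2",
    "state_x_0_0", "y_AA_0_0", "y_dihedral_0_0",
    "state_x_0_1", "y_AA_0_1", "y_dihedral_0_1",
]

# Sites the filter would hand to the guide.
in_guide = [name for name in site_names if expose_probs_sites({"name": name})]
# Sites the filter would hide from the guide (enumerated or observed in the model).
left_to_model = [name for name in site_names if name not in in_guide]

print(in_guide)
print(left_to_model)
```

Under this assumed filter, all five prior sites would be expected in the guide, while the state_x sites would not, so a guide built with a different filter or from a different model is one thing worth checking when this warning appears.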
Edit 2:
Sorry, I forgot to mention that the final model will have 1 continuous observation and 3 discrete observations, but the current model only has 1 continuous and 1 discrete observation.