Here is the model and guide, the SVI training converges to 0 but the results of the inference are nonsense and don’t make sense in the context of the SVI loss, which I am finding somewhat puzzling. Are there likely to be issues training such a large model with for instance 16 hidden dimensions, making some of the matrices 16x16 with so many steps from the parameters to the observations?
def PAHMM_model(sequences, args):
K = args.hidden_dim
num_sequences, lengths, data_dim = sequences.shape
lengths = lengths//2
data1 = torch.reshape(sequences, (num_sequences, lengths, 2))[:, :, 0]
data2 = torch.reshape(sequences, (num_sequences, lengths, 2))[:, :, 1]
print(data1)
print(data2)
#lay out the parameters of the various distributions
sigma = pyro.param("sigma", torch.tensor([0.01],device=device),constraint=constraints.unit_interval)
muy = pyro.param("muy", torch.rand(K,device=device),constraint=constraints.simplex)
tmuxgivy = pyro.param("tmuxgivy", torch.eye(K,device=device)*0.9 + torch.rand((K,K),device=device),constraint=constraints.simplex)
thetay = pyro.param("thetay", torch.rand(K,device=device)*0.1,constraint=constraints.unit_interval)
epsilonyx = pyro.param("epsilonyx",torch.rand((K,K),device=device),constraint=constraints.unit_interval)
muovtheta = torch.sum(torch.divide(muy, thetay))
#give the initial distributions for one halpotype only
def inhidsdist():
mu_prob = torch.div(torch.div(muy,thetay),muovtheta)
return mu_prob
#returns KxK matrix for the probabilities of the initial hidden states
def inemdis():
return tmuxgivy
#swermod is deprecated since we can take advantange of the simplex restriction hence just use sigma in its stead
#this is the transition model for the haplotypes, y and does not include any of the observed genotypes
def trans():
temp = torch.outer(thetay, muy)
diagonal = torch.ones(args.hidden_dim, device = device)-thetay
temptrans = temp + torch.diag(diagonal)
return temptrans
#torch implementation of emission probabilities, in order to reduce memory consumption this has to done with vectorization
def empromod():
#this makes the off-diagonal elements of the matrix
temphap = torch.tile(torch.reshape(torch.transpose(tmuxgivy,0,1),(1,K,1,K)),(K,1,K,1))
#now deal with the diagonals, use a mask!
mask_idx = torch.tile(torch.reshape(torch.eye(K,dtype=torch.bool,device=device),(1,1,K,K)),(K,K,1,1))
diagonal = torch.tile(torch.reshape(torch.einsum('ij,ki -> jki',epsilonyx,torch.transpose(tmuxgivy,0,1)),(K,K,K,1)),(1,1,1,K))
temphap1 = torch.where(mask_idx,diagonal,temphap)
#do the same trick with the other more restrictive diagonal
mask_idx_inner = torch.tile(torch.reshape(torch.eye(K,dtype=torch.bool,device=device),(K,K,1,1)),(1,1,K,K))
mask_idx_diagonal = torch.where(mask_idx,mask_idx_inner,torch.zeros((K,K,K,K),dtype=bool,device=device))
inner_diagonal = torch.ones((K,K,K,K),device=device) - torch.tile(torch.reshape(torch.permute(epsilonyx,(1,0)),(1,K,K,1)),(K,1,1,K))
inner_diagonal_final = torch.where(~mask_idx_diagonal,torch.zeros((K,K,K,K),device=device),inner_diagonal)
temphap2 = temphap1 + inner_diagonal_final
emprob=torch.permute(temphap2,(0,2,3,1))
return emprob
#transition function that goes from y(t-1), s(t-1) and gives y(t) bacially by turning the conditional probabilities
#into a 2D matrix
def transition():
return trans()
#function that gives x(t) in terms of y(t), s(t), x(t-1), y(t-1)
def emission():
return empromod()
probs_x = transition()
probs_y = emission()
probs_s = sigma
with pyro.plate("sequences", num_sequences) as batch:
prob_init = inhidsdist()
s = pyro.sample("s_0", dist.Bernoulli(probs_s),
infer={"enumerate": "parallel"},
)
x1 = pyro.sample("x1_0", dist.Categorical(prob_init), infer={"enumerate": "parallel"},
)
x2 = pyro.sample("x2_0", dist.Categorical(prob_init), infer={"enumerate": "parallel"},
)
x1_list = []
x2_list = []
x1_list.append(x1)
x2_list.append(x2)
y1 = pyro.sample("y1_0", dist.Categorical(inemdis()[x1]), obs=data1[batch, 0],
)
y2 = pyro.sample("y2_0", dist.Categorical(inemdis()[x2]), obs=data2[batch, 0],
)
for t in pyro.markov(range(1,lengths)):
x1_help = ((1 - s) * x1 + s * x2).long()
x2_help = ((1 - s) * x2 + s * x1).long()
probs_x1_t = probs_x[x1_help]
# pyro.sample returns a tensor since it inherits from torch distribution
x1 = pyro.sample(
"x1_{}".format(t),
dist.Categorical(probs_x1_t),
infer={"enumerate": "parallel"},
)
probs_x2_t = probs_x[x2_help]
x2 = pyro.sample(
"x2_{}".format(t),
dist.Categorical(probs_x2_t),
infer={"enumerate": "parallel"},
)
y1_help = ((1 - s) * y1 + s * y2).long()
y2_help = ((1 - s) * y2 + s * y1).long()
s = pyro.sample("s_{}".format(t), dist.Bernoulli(probs_s),
infer={"enumerate": "parallel"},
)
probs_y1_t = probs_y[y1_help, x1_help, x1]
y1 = pyro.sample(
"y1_{}".format(t),
dist.Categorical(probs_y1_t),
obs=data1[batch, t],
)
probs_y2_t = probs_y[y2_help, x2_help, x2]
y2 = pyro.sample(
"y2_{}".format(t),
dist.Categorical(probs_y2_t),
obs=data2[batch, t],
)
x1_list.append(x1)
x2_list.append(x2)
return x1_list, x2_list
def PAHMM_guide(sequences,args):
K = args.hidden_dim
num_sequences, lengths, data_dim = sequences.shape
lengths = lengths // 2
# lay out the parameters of the various distributions
sigma = pyro.param("sigma", torch.tensor([0.01], device=device), constraint=constraints.unit_interval)
muy = pyro.param("muy", torch.rand(K, device=device), constraint=constraints.simplex)
tmuxgivy = pyro.param("tmuxgivy", torch.eye(K, device=device) * 0.9 + torch.rand((K, K), device=device),
constraint=constraints.simplex)
thetay = pyro.param("thetay", torch.rand(K, device=device) * 0.1, constraint=constraints.unit_interval)
epsilonyx = pyro.param("epsilonyx", torch.rand((K, K), device=device), constraint=constraints.unit_interval)
muovtheta = torch.sum(torch.divide(muy, thetay))
# give the initial distributions for one halpotype only
def inhidsdist():
mu_prob = torch.div(torch.div(muy, thetay), muovtheta)
return mu_prob
# returns KxK matrix for the probabilities of the initial hidden states
def inemdis():
return tmuxgivy
# swermod is deprecated since we can take advantange of the simplex restriction hence just use sigma in its stead
# this is the transition model for the haplotypes, y and does not include any of the observed genotypes
def trans():
temp = torch.outer(thetay, muy)
diagonal = torch.ones(args.hidden_dim, device = device) - thetay
temptrans = temp + torch.diag(diagonal)
return temptrans
# torch implementation of emission probabilities, in order to reduce memory consumption this has to done with vectorization
def empromod():
# this makes the off-diagonal elements of the matrix
temphap = torch.tile(torch.reshape(torch.transpose(tmuxgivy, 0, 1), (1, K, 1, K)), (K, 1, K, 1))
# now deal with the diagonals, use a mask!
mask_idx = torch.tile(torch.reshape(torch.eye(K, dtype=torch.bool, device=device), (1, 1, K, K)), (K, K, 1, 1))
diagonal = torch.tile(
torch.reshape(torch.einsum('ij,ki -> jki', epsilonyx, torch.transpose(tmuxgivy, 0, 1)), (K, K, K, 1)),
(1, 1, 1, K))
temphap1 = torch.where(mask_idx, diagonal, temphap)
# do the same trick with the other more restrictive diagonal
mask_idx_inner = torch.tile(torch.reshape(torch.eye(K, dtype=torch.bool, device=device), (K, K, 1, 1)),
(1, 1, K, K))
mask_idx_diagonal = torch.where(mask_idx, mask_idx_inner, torch.zeros((K, K, K, K), dtype=bool, device=device))
inner_diagonal = torch.ones((K, K, K, K), device=device) - torch.tile(
torch.reshape(torch.permute(epsilonyx, (1, 0)), (1, K, K, 1)), (K, 1, 1, K))
inner_diagonal_final = torch.where(~mask_idx_diagonal, torch.zeros((K, K, K, K), device=device), inner_diagonal)
temphap2 = temphap1 + inner_diagonal_final
emprob = torch.permute(temphap2, (0, 2, 3, 1))
return emprob
# transition function that goes from y(t-1), s(t-1) and gives y(t) basically by turning the conditional probabilities
# into a 2D matrix
def transition():
return trans()
# function that gives x(t) in terms of y(t), s(t), x(t-1), y(t-1)
def emission():
return empromod()
probs_x = transition()
probs_y = emission()
probs_s = sigma
with pyro.plate("sequences", num_sequences) as batch:
prob_init = inhidsdist()
s = pyro.sample("s_0", dist.Bernoulli(probs_s),
infer={"enumerate": "parallel"},
)
x1 = pyro.sample("x1_0", dist.Categorical(prob_init),
infer={"enumerate": "parallel"},
)
x2 = pyro.sample("x2_0", dist.Categorical(prob_init),
infer={"enumerate": "parallel"},
)
x1_list = []
x2_list = []
x1_list.append(x1)
x2_list.append(x2)
for t in pyro.markov(range(1, lengths)):
x1_help = ((1 - s) * x1 + s * x2).long()
x2_help = ((1 - s) * x2 + s * x1).long()
probs_x1_t = probs_x[x1_help]
# pyro.sample returns a tensor since it inherits from torch distribution
x1 = pyro.sample(
"x1_{}".format(t),
dist.Categorical(probs_x1_t),
infer={"enumerate": "parallel"},
)
probs_x2_t = probs_x[x2_help]
x2 = pyro.sample(
"x2_{}".format(t),
dist.Categorical(probs_x2_t),
infer={"enumerate": "parallel"},
)
s = pyro.sample("s_{}".format(t), dist.Bernoulli(probs_s),
infer={"enumerate": "parallel"},
)