SVI converges in complex discrete model but infer_discrete results are nonsense

Here are the model and guide. The SVI loss converges to 0, but the results of inference are nonsense and inconsistent with such a low loss, which I find puzzling. Are there likely to be issues training such a large model, for instance with 16 hidden dimensions (making some of the matrices 16x16) and so many steps from the parameters to the observations?

def PAHMM_model(sequences, args):
    K = args.hidden_dim
    num_sequences, lengths, data_dim = sequences.shape
    lengths = lengths//2
    data1 = torch.reshape(sequences, (num_sequences, lengths, 2))[:, :, 0]
    data2 = torch.reshape(sequences, (num_sequences, lengths, 2))[:, :, 1]

    print(data1)

    print(data2)




    #lay out the parameters of the various distributions
    sigma = pyro.param("sigma", torch.tensor([0.01],device=device),constraint=constraints.unit_interval)

    muy = pyro.param("muy", torch.rand(K,device=device),constraint=constraints.simplex)

    tmuxgivy = pyro.param("tmuxgivy", torch.eye(K,device=device)*0.9 + torch.rand((K,K),device=device),constraint=constraints.simplex)

    thetay = pyro.param("thetay", torch.rand(K,device=device)*0.1,constraint=constraints.unit_interval)

    epsilonyx = pyro.param("epsilonyx",torch.rand((K,K),device=device),constraint=constraints.unit_interval)

    muovtheta = torch.sum(torch.divide(muy, thetay))



    #give the initial distributions for one haplotype only
    def inhidsdist():
        mu_prob = torch.div(torch.div(muy,thetay),muovtheta)
        return mu_prob
        #returns a length-K vector of probabilities for the initial hidden states

    def inemdis():
        return tmuxgivy

    #swermod is deprecated since we can take advantage of the simplex restriction hence just use sigma in its stead

    #this is the transition model for the haplotypes, y and does not include any of the observed genotypes
    def trans():
        temp = torch.outer(thetay, muy)
        diagonal = torch.ones(args.hidden_dim, device = device)-thetay
        temptrans = temp + torch.diag(diagonal)
        return temptrans

    #torch implementation of emission probabilities, in order to reduce memory consumption this has to be done with vectorization
    def empromod():
        #this makes the off-diagonal elements of the matrix
        temphap = torch.tile(torch.reshape(torch.transpose(tmuxgivy,0,1),(1,K,1,K)),(K,1,K,1))
        #now deal with the diagonals, use a mask!
        mask_idx = torch.tile(torch.reshape(torch.eye(K,dtype=torch.bool,device=device),(1,1,K,K)),(K,K,1,1))
        diagonal = torch.tile(torch.reshape(torch.einsum('ij,ki -> jki',epsilonyx,torch.transpose(tmuxgivy,0,1)),(K,K,K,1)),(1,1,1,K))
        temphap1 = torch.where(mask_idx,diagonal,temphap)
        #do the same trick with the other more restrictive diagonal
        mask_idx_inner = torch.tile(torch.reshape(torch.eye(K,dtype=torch.bool,device=device),(K,K,1,1)),(1,1,K,K))
        mask_idx_diagonal = torch.where(mask_idx,mask_idx_inner,torch.zeros((K,K,K,K),dtype=bool,device=device))
        inner_diagonal = torch.ones((K,K,K,K),device=device) - torch.tile(torch.reshape(torch.permute(epsilonyx,(1,0)),(1,K,K,1)),(K,1,1,K))
        inner_diagonal_final = torch.where(~mask_idx_diagonal,torch.zeros((K,K,K,K),device=device),inner_diagonal)
        temphap2 = temphap1 + inner_diagonal_final
        emprob=torch.permute(temphap2,(0,2,3,1))
        return emprob




    #transition function that goes from y(t-1), s(t-1) and gives y(t) basically by turning the conditional probabilities
    #into a 2D matrix
    def transition():
        return trans()

    #function that gives x(t) in terms of y(t), s(t), x(t-1), y(t-1)
    def emission():
        return empromod()

    probs_x = transition()

    probs_y = emission()

    probs_s = sigma



    with pyro.plate("sequences", num_sequences) as batch:
        prob_init = inhidsdist()

        s = pyro.sample("s_0", dist.Bernoulli(probs_s),
                        infer={"enumerate": "parallel"},
                        )

        x1 = pyro.sample("x1_0", dist.Categorical(prob_init), infer={"enumerate": "parallel"},
                         )
        x2 = pyro.sample("x2_0", dist.Categorical(prob_init), infer={"enumerate": "parallel"},
                         )
        x1_list = []
        x2_list = []
        x1_list.append(x1)
        x2_list.append(x2)

        y1 = pyro.sample("y1_0", dist.Categorical(inemdis()[x1]), obs=data1[batch, 0],
                         )
        y2 = pyro.sample("y2_0", dist.Categorical(inemdis()[x2]), obs=data2[batch, 0],
                         )
        for t in pyro.markov(range(1,lengths)):
            x1_help = ((1 - s) * x1 + s * x2).long()
            x2_help = ((1 - s) * x2 + s * x1).long()
            probs_x1_t = probs_x[x1_help]
            # pyro.sample returns a tensor since it inherits from torch distribution
            x1 = pyro.sample(
                "x1_{}".format(t),
                dist.Categorical(probs_x1_t),
                infer={"enumerate": "parallel"},
            )
            probs_x2_t = probs_x[x2_help]
            x2 = pyro.sample(
                "x2_{}".format(t),
                dist.Categorical(probs_x2_t),
                infer={"enumerate": "parallel"},
            )
            y1_help = ((1 - s) * y1 + s * y2).long()
            y2_help = ((1 - s) * y2 + s * y1).long()
            s = pyro.sample("s_{}".format(t), dist.Bernoulli(probs_s),
                            infer={"enumerate": "parallel"},
            )
            probs_y1_t = probs_y[y1_help, x1_help, x1]
            y1 = pyro.sample(
                "y1_{}".format(t),
                dist.Categorical(probs_y1_t),
                obs=data1[batch, t],
            )
            probs_y2_t = probs_y[y2_help, x2_help, x2]
            y2 = pyro.sample(
                "y2_{}".format(t),
                dist.Categorical(probs_y2_t),
                obs=data2[batch, t],
            )
            x1_list.append(x1)
            x2_list.append(x2)
    return x1_list, x2_list

def PAHMM_guide(sequences,args):
    K = args.hidden_dim
    num_sequences, lengths, data_dim = sequences.shape
    lengths = lengths // 2

    # lay out the parameters of the various distributions
    sigma = pyro.param("sigma", torch.tensor([0.01], device=device), constraint=constraints.unit_interval)

    muy = pyro.param("muy", torch.rand(K, device=device), constraint=constraints.simplex)

    tmuxgivy = pyro.param("tmuxgivy", torch.eye(K, device=device) * 0.9 + torch.rand((K, K), device=device),
                          constraint=constraints.simplex)

    thetay = pyro.param("thetay", torch.rand(K, device=device) * 0.1, constraint=constraints.unit_interval)

    epsilonyx = pyro.param("epsilonyx", torch.rand((K, K), device=device), constraint=constraints.unit_interval)

    muovtheta = torch.sum(torch.divide(muy, thetay))

    # give the initial distributions for one haplotype only
    def inhidsdist():
        mu_prob = torch.div(torch.div(muy, thetay), muovtheta)
        return mu_prob
        # returns a length-K vector of probabilities for the initial hidden states

    def inemdis():
        return tmuxgivy

    # swermod is deprecated since we can take advantage of the simplex restriction hence just use sigma in its stead

    # this is the transition model for the haplotypes, y and does not include any of the observed genotypes
    def trans():
        temp = torch.outer(thetay, muy)
        diagonal = torch.ones(args.hidden_dim, device = device) - thetay
        temptrans = temp + torch.diag(diagonal)
        return temptrans

    # torch implementation of emission probabilities, in order to reduce memory consumption this has to be done with vectorization
    def empromod():
        # this makes the off-diagonal elements of the matrix
        temphap = torch.tile(torch.reshape(torch.transpose(tmuxgivy, 0, 1), (1, K, 1, K)), (K, 1, K, 1))
        # now deal with the diagonals, use a mask!
        mask_idx = torch.tile(torch.reshape(torch.eye(K, dtype=torch.bool, device=device), (1, 1, K, K)), (K, K, 1, 1))
        diagonal = torch.tile(
            torch.reshape(torch.einsum('ij,ki -> jki', epsilonyx, torch.transpose(tmuxgivy, 0, 1)), (K, K, K, 1)),
            (1, 1, 1, K))
        temphap1 = torch.where(mask_idx, diagonal, temphap)
        # do the same trick with the other more restrictive diagonal
        mask_idx_inner = torch.tile(torch.reshape(torch.eye(K, dtype=torch.bool, device=device), (K, K, 1, 1)),
                                    (1, 1, K, K))
        mask_idx_diagonal = torch.where(mask_idx, mask_idx_inner, torch.zeros((K, K, K, K), dtype=bool, device=device))
        inner_diagonal = torch.ones((K, K, K, K), device=device) - torch.tile(
            torch.reshape(torch.permute(epsilonyx, (1, 0)), (1, K, K, 1)), (K, 1, 1, K))
        inner_diagonal_final = torch.where(~mask_idx_diagonal, torch.zeros((K, K, K, K), device=device), inner_diagonal)
        temphap2 = temphap1 + inner_diagonal_final
        emprob = torch.permute(temphap2, (0, 2, 3, 1))
        return emprob

    # transition function that goes from y(t-1), s(t-1) and gives y(t) basically by turning the conditional probabilities
    # into a 2D matrix
    def transition():
        return trans()

    # function that gives x(t) in terms of y(t), s(t), x(t-1), y(t-1)
    def emission():
        return empromod()

    probs_x = transition()

    probs_y = emission()

    probs_s = sigma

    with pyro.plate("sequences", num_sequences) as batch:
        prob_init = inhidsdist()

        s = pyro.sample("s_0", dist.Bernoulli(probs_s),
                        infer={"enumerate": "parallel"},
                        )

        x1 = pyro.sample("x1_0", dist.Categorical(prob_init),
                         infer={"enumerate": "parallel"},
                         )
        x2 = pyro.sample("x2_0", dist.Categorical(prob_init),
                         infer={"enumerate": "parallel"},
                         )
        x1_list = []
        x2_list = []
        x1_list.append(x1)
        x2_list.append(x2)

        for t in pyro.markov(range(1, lengths)):
            x1_help = ((1 - s) * x1 + s * x2).long()
            x2_help = ((1 - s) * x2 + s * x1).long()
            probs_x1_t = probs_x[x1_help]
            # pyro.sample returns a tensor since it inherits from torch distribution
            x1 = pyro.sample(
                "x1_{}".format(t),
                dist.Categorical(probs_x1_t),
                infer={"enumerate": "parallel"},
            )
            probs_x2_t = probs_x[x2_help]
            x2 = pyro.sample(
                "x2_{}".format(t),
                dist.Categorical(probs_x2_t),
                infer={"enumerate": "parallel"},
            )
            s = pyro.sample("s_{}".format(t), dist.Bernoulli(probs_s),
                infer={"enumerate": "parallel"},
            )

Apologies in case something similar has been covered before; I checked for similar posts and did not find any. Any assistance would be greatly appreciated.

it seems very unlikely that anyone is going to want to read 250 lines of code unless you give a lot more context (and even then you’re asking for a lot). what does nonsense mean? what are you modeling? etc


Apologies. This is a genetic sequence model that includes switch errors and recombination events. It is intended to act as a smoothing layer on top of a prediction layer, which means that many of the sequences will be repeats. By nonsense I mean that different observed sequences lead to the same inferred hidden sequence. I would expect the ELBO loss to converge to zero only if the joint probability P(X,Z) is close to 1, right?
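For discrete observations, the loss that SVI reports is the negative ELBO, and it decomposes as loss = -log p(x) + KL(q(z) || p(z | x)). Both terms are non-negative (p(x) <= 1 when x is discrete), so a loss near 0 requires the model to assign probability close to 1 to the observed data and q to match the exact posterior. A minimal pure-Python check on a hypothetical toy model with one observation and a binary latent (the numbers are made up for illustration):

```python
import math

# Hypothetical toy joint p(x_obs, z) for one observed x and a binary latent z.
p_joint = {0: 0.30, 1: 0.10}      # p(x_obs, z=0), p(x_obs, z=1)
p_x = sum(p_joint.values())       # marginal likelihood p(x_obs)


def neg_elbo(q):
    """Negative ELBO: -sum_z q(z) * (log p(x, z) - log q(z)); terms with q(z) = 0 contribute 0."""
    return -sum(q[z] * (math.log(p_joint[z]) - math.log(q[z])) for z in q if q[z] > 0)


posterior = {z: p / p_x for z, p in p_joint.items()}
arbitrary_q = {0: 0.5, 1: 0.5}

print(neg_elbo(posterior), -math.log(p_x))  # equal: the bound is tight at the exact posterior
print(neg_elbo(arbitrary_q))                # larger: the gap is KL(q || posterior)
```

Because -log p(x) is itself bounded away from 0 unless the data are nearly certain under the model, a loss that reaches 0 is more likely a sign that the objective being optimized is not a valid ELBO than a sign of a good fit.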

Given the length of the code there is always some possibility that the model sections contain an error, but I have tested them and not found any. Having just got into the framework over the last couple of months, I’m a bit sceptical that I haven’t committed some howling error in how I’ve tried to use the framework. I have read the tutorials, forum posts and the source code and thought what I was doing seemed sensible. Obviously, I was encouraged to see the ELBO loss tend to 0.
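One cheap way to test helpers such as trans() and inhidsdist() in isolation is to assert the properties they should have. The sketch below re-creates those two constructions as hypothetical standalone functions (make_trans, make_init) and feeds them random parameters. With muy on the simplex, row i of outer(thetay, muy) + diag(1 - thetay) sums to thetay[i] + 1 - thetay[i] = 1, and (muy / thetay), once normalized, is exactly the stationary distribution of that matrix, so both facts can be checked numerically:

```python
import torch


def make_trans(thetay, muy):
    # Same construction as trans(): with prob thetay[i] jump to a state drawn from muy,
    # otherwise stay in state i (the 1 - thetay[i] diagonal term).
    return torch.outer(thetay, muy) + torch.diag(1.0 - thetay)


def make_init(thetay, muy):
    # Same construction as inhidsdist(): muy / thetay, normalized to sum to 1.
    w = muy / thetay
    return w / w.sum()


K = 16
torch.manual_seed(0)
muy = torch.softmax(torch.randn(K, dtype=torch.float64), dim=0)  # a point on the simplex
thetay = torch.rand(K, dtype=torch.float64) * 0.1 + 1e-3          # values in (0, 1)

P = make_trans(thetay, muy)
pi = make_init(thetay, muy)

assert (P >= 0).all()
assert torch.allclose(P.sum(-1), torch.ones(K, dtype=torch.float64))  # rows are distributions
assert torch.allclose(pi @ P, pi)                                     # pi is stationary for P
```

A related point worth checking: as far as I know, pyro.param maps its initial value through the inverse of the constraint's transform, so a simplex-constrained param initialized with values that do not sum to 1 (such as torch.rand(K) for muy) starts from a different point on the simplex than the one passed in. Normalizing the initial value avoids that surprise.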

In terms of the code, the key parts are at the bottom of both the model and the guide; the upper parts are just helper functions that construct the emission and transition parts of what is almost (but not quite) a hidden Markov process. I decided to use Pyro because I saw that it can express Bayesian networks, which saves me from rolling two time steps into one to turn this into a true hidden Markov process and ending up with giant state and transition matrices.

Shortened version of the code, hopefully this is more instructive:

def PAHMM_model(sequences, args):

    #lay out the parameters of the various distributions
    sigma = pyro.param("sigma", torch.tensor([0.01],device=device),constraint=constraints.unit_interval)

    muy = pyro.param("muy", torch.rand(K,device=device),constraint=constraints.simplex)

    tmuxgivy = pyro.param("tmuxgivy", torch.eye(K,device=device)*0.9 + torch.rand((K,K),device=device),constraint=constraints.simplex)

    thetay = pyro.param("thetay", torch.rand(K,device=device)*0.1,constraint=constraints.unit_interval)

    epsilonyx = pyro.param("epsilonyx",torch.rand((K,K),device=device),constraint=constraints.unit_interval)

 
    def transition():
        # some basic torch operations on the parameters
        ...

    def emission():
        # some more torch operations involving the parameters
        ...

    probs_x = transition()

    probs_y = emission()

    probs_s = sigma



    with pyro.plate("sequences", num_sequences) as batch:
        prob_init = inhidsdist()

        s = pyro.sample("s_0", dist.Bernoulli(probs_s),
                        infer={"enumerate": "parallel"},
                        )

        x1 = pyro.sample("x1_0", dist.Categorical(prob_init), infer={"enumerate": "parallel"},
                         )
        x2 = pyro.sample("x2_0", dist.Categorical(prob_init), infer={"enumerate": "parallel"},
                         )
        x1_list = []
        x2_list = []
        x1_list.append(x1)
        x2_list.append(x2)

        y1 = pyro.sample("y1_0", dist.Categorical(inemdis()[x1]), obs=data1[batch, 0],
                         )
        y2 = pyro.sample("y2_0", dist.Categorical(inemdis()[x2]), obs=data2[batch, 0],
                         )
        for t in pyro.markov(range(1,lengths)):
            x1_help = ((1 - s) * x1 + s * x2).long()
            x2_help = ((1 - s) * x2 + s * x1).long()
            probs_x1_t = probs_x[x1_help]
            # pyro.sample returns a tensor since it inherits from torch distribution
            x1 = pyro.sample(
                "x1_{}".format(t),
                dist.Categorical(probs_x1_t),
                infer={"enumerate": "parallel"},
            )
            probs_x2_t = probs_x[x2_help]
            x2 = pyro.sample(
                "x2_{}".format(t),
                dist.Categorical(probs_x2_t),
                infer={"enumerate": "parallel"},
            )
            y1_help = ((1 - s) * y1 + s * y2).long()
            y2_help = ((1 - s) * y2 + s * y1).long()
            s = pyro.sample("s_{}".format(t), dist.Bernoulli(probs_s),
                            infer={"enumerate": "parallel"},
            )
            probs_y1_t = probs_y[y1_help, x1_help, x1]
            y1 = pyro.sample(
                "y1_{}".format(t),
                dist.Categorical(probs_y1_t),
                obs=data1[batch, t],
            )
            probs_y2_t = probs_y[y2_help, x2_help, x2]
            y2 = pyro.sample(
                "y2_{}".format(t),
                dist.Categorical(probs_y2_t),
                obs=data2[batch, t],
            )
            x1_list.append(x1)
            x2_list.append(x2)
    return x1_list, x2_list

def PAHMM_guide(sequences,args):
    K = args.hidden_dim

    # lay out the parameters of the various distributions
    sigma = pyro.param("sigma", torch.tensor([0.01], device=device), constraint=constraints.unit_interval)

    muy = pyro.param("muy", torch.rand(K, device=device), constraint=constraints.simplex)

    tmuxgivy = pyro.param("tmuxgivy", torch.eye(K, device=device) * 0.9 + torch.rand((K, K), device=device),
                          constraint=constraints.simplex)

    thetay = pyro.param("thetay", torch.rand(K, device=device) * 0.1, constraint=constraints.unit_interval)

    epsilonyx = pyro.param("epsilonyx", torch.rand((K, K), device=device), constraint=constraints.unit_interval)

    
    def transition():
        # same torch operations as the model
        ...

    def emission():
        # same torch operations as the model
        ...

    probs_x = transition()

    probs_y = emission()

    probs_s = sigma

    with pyro.plate("sequences", num_sequences) as batch:
        prob_init = inhidsdist()

        s = pyro.sample("s_0", dist.Bernoulli(probs_s),
                        infer={"enumerate": "parallel"},
                        )

        x1 = pyro.sample("x1_0", dist.Categorical(prob_init),
                         infer={"enumerate": "parallel"},
                         )
        x2 = pyro.sample("x2_0", dist.Categorical(prob_init),
                         infer={"enumerate": "parallel"},
                         )

        for t in pyro.markov(range(1, lengths)):
            x1_help = ((1 - s) * x1 + s * x2).long()
            x2_help = ((1 - s) * x2 + s * x1).long()
            probs_x1_t = probs_x[x1_help]
            # pyro.sample returns a tensor since it inherits from torch distribution
            x1 = pyro.sample(
                "x1_{}".format(t),
                dist.Categorical(probs_x1_t),
                infer={"enumerate": "parallel"},
            )
            probs_x2_t = probs_x[x2_help]
            x2 = pyro.sample(
                "x2_{}".format(t),
                dist.Categorical(probs_x2_t),
                infer={"enumerate": "parallel"},
            )
            s = pyro.sample("s_{}".format(t), dist.Bernoulli(probs_s),
                infer={"enumerate": "parallel"},
            )

why do you have a guide? afaict you have all discrete latent variables and so you can sum these out explicitly ala hmm.py

def PAHMM_guide(sequences,args):
    pass

Thanks for the reply. The hmm.py example uses an AutoDelta guide, though, which autogenerates a discrete guide exposing all the discrete latent variables whose names begin with “probs_”, doesn't it? I started by copying that example's use of an AutoDelta guide, but I guess I missed some of the finer points and exposed the params instead. When you say sum them out, do you mean summing over the enumeration of all the values they might take, so that I should expose the sampled latent variables in the autoguide? So that I can understand the framework better: what is the (potential) issue with manually coding a guide that gives the values of the latent variables in this way? Much appreciated.

hmm.py does a somewhat overcomplex thing in that it effectively demotes the dirichlet latent variables to penalized param statements, i.e. param statements that are regularized by MAP (in particular probs_x are still point estimates). if probs_x etc had been param statements from the get go, the guide could have been empty.

for a model with appropriate markov structure pyro can automatically enumerate/sum out/integrate out the entire discrete latent space. in effect it can compute the “optimal” guide. whenever this is possible there’s no point in providing a guide, since you’d just be attempting to learn something via optimization that you can compute in closed form.


Thanks for the reply. I have now run it with an empty guide, and the ELBO loss stops decreasing after a small amount of initial progress. I have also tried replacing the multiple indexing with Vindex, to no avail; shouldn't that make the enumeration work correctly?

I appreciate the insight; I can see it in the source code now. So this is the point of enumeration, and of how the enumeration dimensions are structured.

Do you have any other ideas about the strange behaviour? Is the ELBO loss tending to 0 with a hand-coded guide just a consequence of the guide being an approximation, or could something be going wrong with the enumeration dimensions that only becomes apparent when I use an empty guide?

Thanks for the help/insight.

your model is quite complex. i suggest you start with the simplest possible model and build up complexity gradually, making sure that indexing/broadcasting is handled correctly. in particular you might benefit from (re)reading this tutorial as well as this tutorial

i suspect that with your hand coded guide (which, afaict, is basically incorrect since it doesn’t introduce variational parameters) you were being driven into a weird/trivial corner of parameter space. i don’t think you should take that “near zero elbo” as a sign of “near success”


Thanks for the response. I have read and re-read those tutorials many times. I think the issue is that there is quite a complexity gap between those examples and the kind of real-world models that make the framework useful. In principle, is there nothing wrong with using a bunch of torch matrix operations in the model, or with the overall idea? If so, I will spend more time debugging and building up the overall model.

As for variational parameters in the guide, what do you mean by that? The guide does have named parameters that can presumably be trained (they match the model's parameter names; is that the problem?).

Thanks.

I was taking the loss going down to 0 as a sign that something weird was happening, not necessarily that the inference was on the verge of working, although getting close is a prerequisite :wink: I always like to explore all avenues in case I'm missing something or can learn something from it. It's also useful for me to know how to write a guide, as I have a modified DMM model to train as well.

by matching parameter names you’re matching parameters. you need unique parameter names if you want unique parameters.

you can use whatever torch ops you want provided you respect declared markovian structure, write code in a fully vectorized way that allows for automatic introduction of enumeration dimensions on the left, etc


Thanks, yes I thought that might be a problem in the guide.

I will check and re-run (most likely quite a few times).

One more question: how would I prevent the automatic introduction of enumeration dimensions on the left? By manually introducing them myself and screwing it all up, I assume?

Much appreciated.

e.g. if you added extraneous unsqueeze statements that messed up alignment of different time slices
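To make the unsqueeze point concrete: inside a plate at dim -1, Pyro places enumerated values in their own dims further left (assumed here to be -2 and -3, as for two consecutive time steps). Ordinary right-aligned broadcasting then pairs every value of one variable with every value of the other. An extra trailing unsqueeze shifts one variable's values into the other's enumeration dim, so they are silently paired along a diagonal instead. A pure-torch illustration:

```python
import torch

K = 3
# Enumerated values laid out as Pyro would inside a plate at dim -1 (assumed layout).
x_prev = torch.arange(K).reshape(K, 1)     # enumerated at dim -2
x_curr = torch.arange(K).reshape(K, 1, 1)  # enumerated at dim -3

# Right-aligned broadcasting: one entry per (x_curr, x_prev) combination.
good = x_prev + 10 * x_curr
print(good.shape)  # (K, K, 1)

# An extraneous trailing unsqueeze moves x_prev's values into dim -3, where they
# line up with x_curr, so the cross product collapses to the diagonal x_prev == x_curr.
bad = x_prev.unsqueeze(-1) + 10 * x_curr
print(bad.shape)   # (K, 1, 1)
```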


Thanks for the response and all your help. Hopefully I can debug the model and get it working now.

for what it’s worth there are some hmm-like models that incorporate neural networks here if you want examples of more complex models, e.g. model7


Thanks :smiley: