Question on local/global random variable in pyro HMM example

Hi,

I am confused by the pyro HMM example.

In the very basic model (model_1, related code shown below), it seems that the discrete state variable x_{} is shared across the batch and every sequence in a batch shares the same state.

In my understanding, in an HMM the transition matrix and the observation matrix are global variables, whereas the state variables are local, which means that each sequence has its own state variables.

Could anyone help me with the code and HMM concepts?

Thanks!

    with pyro.iarange("sequences", len(sequences), batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        for t in range(lengths.max()):
            with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                # On the next line, we'll overwrite the value of x with an updated
                # value. If we wanted to record all x values, we could instead
                # write x[t] = pyro.sample(...x[t-1]...).
                x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_iarange:
                    pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y[x]),
                                obs=sequences[batch, t])

Hi @ruijiang, you are correct that in an HMM the state variable is local to each sequence in a batch. In the Pyro HMM example, the state variable "x_{}".format(t) is actually a tensor of independent variables, one per sequence. You can see this from the outermost pyro.iarange("sequences", ..., dim=-2) in the model: inside that iarange context every variable is local and is vectorized over the 2nd dim from the right (dim=-2). Does that make sense?
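To make "one state per sequence" concrete, here is a small NumPy analogy (not Pyro code): a forward simulation of an HMM in which x holds one state per sequence on dim -2, so each sequence follows its own trajectory. The names probs_x, lengths and sample_categorical are hypothetical stand-ins. In the actual Pyro example x is enumerated rather than sampled during inference, and poutine.mask only zeroes out log-probabilities; here the mask is merely recorded.

```python
import numpy as np

rng = np.random.default_rng(0)

num_states, batch_size = 3, 4
lengths = np.array([5, 3, 4, 2])            # hypothetical sequence lengths
# Hypothetical transition matrix: row i holds p(x_t | x_{t-1} = i).
probs_x = 0.1 + 0.7 * np.eye(num_states)
probs_x /= probs_x.sum(axis=-1, keepdims=True)

def sample_categorical(probs, rng):
    """Draw one index per row of `probs` (shape (..., K)) by inverse-CDF sampling."""
    u = rng.random(probs.shape[:-1] + (1,))
    idx = (probs.cumsum(axis=-1) < u).sum(axis=-1)
    return np.minimum(idx, probs.shape[-1] - 1)   # guard against float round-off

x = np.zeros((batch_size, 1), dtype=int)    # one state per sequence; batch on dim -2
states, masks = [], []
for t in range(lengths.max()):
    mask = (t < lengths)[:, None]           # shape (batch_size, 1), like (t < lengths).unsqueeze(-1)
    # probs_x[x] has shape (batch_size, 1, num_states), so each sequence
    # transitions from its own previous state.
    x = sample_categorical(probs_x[x], rng)
    states.append(x)
    masks.append(mask)
```

Every entry of `states` has shape (batch_size, 1): the batch lives on dim -2, matching dim=-2 in the iarange.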

BTW we’re thinking of renaming pyro.iarange to pyro.plate to make that clearer.

Thanks @fritzo for the explanation. My confusion is gone:)

I think it is a good idea to rename iarange to plate, as the usage of iarange differs significantly from irange. Previously I used irange, with which we had to explicitly name each local variable or add an extra batch dimension, as in the following example.

    for i in pyro.irange("data_loop", len(data)):
        # observe datapoint i using the bernoulli likelihood
        pyro.sample("obs_{}".format(i), dist.Bernoulli(f), obs=data[i])

iarange is much more convenient, as we don’t have to handle the ‘batch’ dimension or vectorization explicitly.
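The difference can be sketched without Pyro: an irange-style loop produces one log-likelihood term per named site, while an iarange-style site produces all terms at once along a batch dimension. The NumPy version below uses toy data and a hypothetical success probability f, and only illustrates that both styles contribute the same summed log-likelihood.

```python
import numpy as np

data = np.array([1., 0., 1., 1., 0., 1.])
f = 0.6  # hypothetical Bernoulli success probability

def bernoulli_log_prob(value, p):
    """log p(value) for a Bernoulli(p); works elementwise on arrays."""
    return value * np.log(p) + (1 - value) * np.log1p(-p)

# irange style: one named site per datapoint, accumulated in a Python loop.
loop_total = 0.0
for i in range(len(data)):
    loop_total += bernoulli_log_prob(data[i], f)

# iarange style: a single vectorized site whose batch dimension indexes the data.
vectorized_total = bernoulli_log_prob(data, f).sum()
```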

Hi @fritzo,

I have some questions related to local variables:

  • In the HMM example, is there a convenient way to get the posterior of the state variable ‘x’, even though it is marginalized out?

  • More generally, how can I access the posterior of local random variables if amortized inference is not used? I checked the size of ‘x’ in the example, and it seems to have nothing to do with the batch size.
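One likely reason the size of x does not follow the batch size, assuming parallel enumeration is active for that site, is that the value returned at an enumerated site is the whole support arange(K), placed on a new dimension to the left of the iarange dims; a per-sequence shape only appears once it broadcasts against observed data. A NumPy sketch with a toy one-step model (hypothetical prior and probs_y) shows the shapes, and how summing out the enumeration dimension gives a per-sequence posterior:

```python
import numpy as np

num_states, batch_size = 3, 5
prior = np.array([0.5, 0.3, 0.2])      # hypothetical p(x)
probs_y = np.array([0.1, 0.5, 0.9])    # hypothetical p(y = 1 | x)
y = np.array([1, 0, 1, 1, 0], dtype=float).reshape(batch_size, 1)  # batch on dim -2

# Enumerated "sample": the full support on dim -3, with no batch_size in its shape.
x = np.arange(num_states).reshape(num_states, 1, 1)

# log p(x, y) broadcasts to (num_states, batch_size, 1).
log_joint = np.log(prior[x]) + y * np.log(probs_y[x]) + (1 - y) * np.log1p(-probs_y[x])

# Sum out the enumeration dim to get log p(y), then normalize to get p(x | y).
log_marginal = np.log(np.exp(log_joint).sum(axis=0))
posterior = np.exp(log_joint - log_marginal)   # shape (num_states, batch_size, 1)
```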

Thanks!

Hi @ruijiang,

You can examine the enumerated variables using either TraceEnum_ELBO.compute_marginals() or TraceEnum_ELBO.sample_posterior(). Note however that these are very new, and still have some known bugs regarding batching. We’re hoping to get these fully working before the next Pyro release. If you end up using these, let us know your experience so we can improve the interface!
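For intuition about what these marginals should look like in the HMM case: with the global parameters held fixed, the posterior marginals of the enumerated states are the classic smoothing marginals p(x_t | y_1..y_T) computed by the forward-backward algorithm. The NumPy sketch below (toy init, trans, emit and ys, all hypothetical) implements forward-backward and checks it against brute-force enumeration of every state path; it is a reference for sanity-checking outputs, not Pyro's implementation.

```python
import itertools
import numpy as np

init = np.array([0.6, 0.4])                 # hypothetical p(x_0)
trans = np.array([[0.7, 0.3], [0.2, 0.8]])  # trans[i, j] = p(x_t = j | x_{t-1} = i)
emit = np.array([[0.9, 0.1], [0.3, 0.7]])   # emit[k, y] = p(y_t = y | x_t = k)
ys = [0, 1, 1, 0]

def smoothing_marginals(init, trans, emit, ys):
    """Return p(x_t | all ys) for every t via forward-backward."""
    T, K = len(ys), len(init)
    alpha = np.zeros((T, K))
    beta = np.ones((T, K))
    alpha[0] = init * emit[:, ys[0]]
    for t in range(1, T):
        alpha[t] = (alpha[t - 1] @ trans) * emit[:, ys[t]]
    for t in range(T - 2, -1, -1):
        beta[t] = trans @ (emit[:, ys[t + 1]] * beta[t + 1])
    gamma = alpha * beta
    return gamma / gamma.sum(axis=1, keepdims=True)

def brute_force_marginals(init, trans, emit, ys):
    """Same quantity by enumerating every state path (only feasible for tiny T)."""
    T, K = len(ys), len(init)
    gamma = np.zeros((T, K))
    for path in itertools.product(range(K), repeat=T):
        p = init[path[0]] * emit[path[0], ys[0]]
        for t in range(1, T):
            p *= trans[path[t - 1], path[t]] * emit[path[t], ys[t]]
        for t, k in enumerate(path):
            gamma[t, k] += p
    return gamma / gamma.sum(axis=1, keepdims=True)

marginals = smoothing_marginals(init, trans, emit, ys)
```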

@ruijiang I forgot: another way to access the marginalized-out variables is to train a second SVI guide that fixes the variables learned in the first guide. We have an example of this in the GMM tutorial.
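Why the second-guide trick works can be checked by hand: once the global values are frozen, the ELBO for each local q(z_i) is maximized by the exact posterior p(z_i | x_i, globals), so a trained assignment_probs parameter should approach the "responsibilities" computed below. The data and global values are made up for illustration, and the sketch uses NumPy rather than Pyro.

```python
import numpy as np

data = np.array([0.0, 1.0, 10.0, 11.0, 12.0])
# Hypothetical global values, as if learned by a first (MAP / AutoDelta) guide.
weights = np.array([0.4, 0.6])
locs = np.array([0.5, 11.0])
scale = 1.0

def normal_log_pdf(x, loc, scale):
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)

# log p(z_i = k, x_i | globals), shape (num_data, K)
log_joint = np.log(weights) + normal_log_pdf(data[:, None], locs, scale)
# Normalizing over k gives p(z_i = k | x_i, globals): the optimum of the
# ELBO over each local q(z_i) when the globals are held fixed.
log_norm = np.logaddexp.reduce(log_joint, axis=1, keepdims=True)
assignment_probs = np.exp(log_joint - log_norm)
```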

Thanks @fritzo.

I am a little confused about what is going on if we enumerate the guide in the GMM example.

Let z_i be the discrete variable, g the global variable, and i the index over independent samples.

  • With MAP estimation, it is clear that if we enumerate z in the model, we find the maximum of p(y, g). After learning we get a point estimate of g.

  • With VB, if we enumerate z in the guide, the parameter θ would vanish.
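For the MAP case, the objective with z enumerated in the model is, up to the log-prior term log p(g), the marginal log-likelihood log p(y | g) = Σ_i log Σ_k p(y_i, z_i = k | g). Because the iarange makes the data points conditionally independent, the sum over all K^N joint assignments factorizes into N small sums. A NumPy check with hypothetical global values:

```python
import itertools
import numpy as np

data = np.array([-0.5, 0.3, 4.2])
weights = np.array([0.3, 0.7])   # hypothetical point estimates of the globals g
locs = np.array([0.0, 4.0])
scale = 1.0

def normal_log_pdf(x, loc, scale):
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)

# Model-side enumeration: sum out each z_i locally, costing O(N * K).
log_joint = np.log(weights) + normal_log_pdf(data[:, None], locs, scale)   # (N, K)
log_lik_enum = np.logaddexp.reduce(log_joint, axis=1).sum()

# Brute force over all K**N joint assignments gives the same log p(y | g).
log_lik_brute = np.logaddexp.reduce([
    sum(log_joint[i, k] for i, k in enumerate(zs))
    for zs in itertools.product(range(len(weights)), repeat=len(data))
])
```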

I think I have misunderstood what Pyro optimizes when enumeration is involved. Is there any documentation on the objective function when enumeration is involved?


Still, I gave the ‘second guide’ approach a try, but it is much slower than MAP estimation.

I am trying to implement a Switching Dynamic Linear System with Pyro (see the model in Sec. 2.1 of the paper).

My related code is as follows:

  • z is the discrete state and x is the continuous state vector of the linear dynamic system
  • the input sequences form a tensor of shape (num_sequences, max_length, data_dim)
  • model:
    import torch
    import pyro
    import pyro.distributions as dist
    from pyro import poutine
    # config_enumerate and constraints are used by the guide below.
    from pyro.infer import config_enumerate
    from torch.distributions import constraints

    @poutine.broadcast
    def model_sssm(sequences, lengths, args, batch_size=None, include_prior=True):
        num_sequences, max_length, data_dim = sequences.shape
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length

        n_seg = args.num_segment
        n_lat = args.num_latent
        n_out = data_dim
        normal_std = 1e-1
        with poutine.mask(mask=torch.tensor(include_prior)):
            # HMM transition matrix, shape (K, K)
            hmm_dyn = pyro.sample("hmm_dyn",
                                  dist.Dirichlet(0.9 * torch.eye(n_seg) + 0.1)
                                      .independent(1))
            ## SSM
            ssm_dyn = pyro.sample("ssm_dyn",
                                  dist.Normal(0, normal_std).expand_by([n_seg, n_lat, n_lat])
                                      .independent(3))

            ssm_bias = pyro.sample("ssm_bias",
                                   dist.Normal(0, 1.).expand_by([n_seg, n_lat, 1])
                                       .independent(3))

            ssm_noise = pyro.sample("ssm_noise",
                                    dist.Gamma(1e0, 1e0).expand_by([n_seg, n_lat, 1])
                                        .independent(3))

            ## observation
            obs_weight = pyro.sample("obs_weight",
                                     dist.Normal(0, 1).expand_by([n_seg, n_out, n_lat])
                                         .independent(3))

            obs_bias = pyro.sample("obs_bias",
                                   dist.Normal(0, 1).expand_by([n_seg, n_out, 1])
                                       .independent(3))

            obs_noise = pyro.sample("obs_noise",
                                    dist.Gamma(1e0, 1e0).expand_by([n_seg, n_out, 1])
                                        .independent(3))

        with pyro.iarange("sequences", len(sequences), batch_size, dim=-2) as batch:
            ## ssm init state
            # use a (n_lat, 1) matrix (instead of a vector) to ease batch operations:
            # for higher-dimensional inputs, matmul does a batched matrix-matrix
            # product instead of a matrix-vector product.
            ssm_init = pyro.sample("ssm_init", dist.Normal(0, 1).expand_by([n_lat, 1]).independent(2))

            # the RVs below are all local variables.
            lengths = lengths[batch]
            z = 0
            x = ssm_init
            for t in range(lengths.max()):
                with poutine.mask(mask=(t < lengths).unsqueeze(-1)):
                    # On the next lines, we overwrite the values of z and x with updated
                    # values. To record all values, we could instead write
                    # z[t] = pyro.sample(...z[t-1]...), and similarly for x.
                    z = pyro.sample("z_{}".format(t), dist.Categorical(hmm_dyn[z]),
                                    infer={"enumerate": "parallel"})
                    x = pyro.sample("x_{}".format(t),
                                    dist.Normal(torch.matmul(ssm_dyn[z], x) + ssm_bias[z],
                                                torch.sqrt(1. / ssm_noise[z])).independent(2))
                    obs_lat = torch.matmul(obs_weight[z], x) + obs_bias[z]
                    pyro.sample("y_{}".format(t),
                                dist.Normal(obs_lat, torch.sqrt(1. / obs_noise[z])).independent(2),
                                obs=sequences[batch, t].unsqueeze(-1))
  • guide:
    @poutine.broadcast
    @config_enumerate(default="parallel")
    def guide(sequences, lengths, args, batch_size=None, include_prior=True):
        num_sequences, max_length, data_dim = sequences.shape
        n_seg = args.num_segment

        with pyro.iarange("sequences", len(sequences), batch_size, dim=-2) as batch:
            lengths = lengths[batch]
            for t in range(lengths.max()):
                ret_prob = pyro.param('assignment_probs_{}'.format(t), torch.ones(len(lengths), 1, n_seg) / n_seg,
                                      constraint=constraints.unit_interval)
                z = pyro.sample("z_{}".format(t), dist.Categorical(ret_prob))
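The comment in model_sssm about storing the latent state as an (n_lat, 1) column relies on batched matmul broadcasting. The NumPy sketch below (random placeholder parameters, hypothetical sizes) checks that shape arithmetic against a per-sequence loop, and also shows what happens when z carries an enumeration dimension of size n_seg on dim -3: the mean of the continuous x_t picks up that dimension too. torch.matmul follows the same broadcasting rules for batch dimensions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_seg, n_lat, bs = 3, 2, 4

ssm_dyn = rng.normal(size=(n_seg, n_lat, n_lat))  # one dynamics matrix per segment
ssm_bias = rng.normal(size=(n_seg, n_lat, 1))
z = rng.integers(n_seg, size=(bs, 1))             # one discrete state per sequence (dim -2)
x = rng.normal(size=(bs, 1, n_lat, 1))            # column-vector latent state per sequence

# Batched matrix-matrix product: (bs, 1, n_lat, n_lat) @ (bs, 1, n_lat, 1).
x_next_mean = ssm_dyn[z] @ x + ssm_bias[z]        # -> (bs, 1, n_lat, 1)

# Same result computed one sequence at a time, as a check.
loop = np.stack([ssm_dyn[z[i, 0]] @ x[i, 0] + ssm_bias[z[i, 0]] for i in range(bs)])

# Under parallel enumeration z would be arange(n_seg) on dim -3 instead.
z_enum = np.arange(n_seg).reshape(n_seg, 1, 1)
x_enum_mean = ssm_dyn[z_enum] @ x + ssm_bias[z_enum]   # -> (n_seg, bs, 1, n_lat, 1)
```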

Could you help me figure out why it is so slow?

Thanks,
Ruijiang

Hi @fritzo,

I also tried TraceEnum_ELBO.compute_marginals() on the switching dynamic linear system, but ran into shape issues:

  • With the following code, the batch_shape of the discrete variables (z_{}) is [bs, bs], and the full shape is [bs, bs, n_state], where bs is the batch size.
    r1 = elbo.compute_marginals(model, guide, sequences, lengths, args, batch_size=args.batch_size)
  • If I check the trace with the following code, the batch_shapes in trace_model.nodes are as expected ([bs, 1], since we use dim=-2).
    trace_model = poutine.trace(model).get_trace(sequences, lengths, args, batch_size=args.batch_size)
  • When I print the shapes of the distributions and tensors in the model, they look as expected.
  • If I change to dim=-1 in iarange(), a related error occurs:
    ValueError: at site "z_0", invalid log_prob shape
      Expected [3], actual [3, 3]

I also tried compute_marginals on the HMM example, and the shapes are fine there. So I think something is wrong with my model (posted in my previous reply), but I couldn’t find where the issue is.
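A batch shape of [bs, bs] is the typical symptom of a tensor indexed by sequence on dim -1 meeting one indexed on dim -2. One hypothesis worth checking (not confirmed in this thread) is the observation in model_sssm: sequences[batch, t].unsqueeze(-1) has shape (bs, data_dim, 1), so once the event shape (data_dim, 1) is stripped its batch part is (bs,) on dim -1, while the distribution's batch shape is (bs, 1). The NumPy sketch below reproduces that shape arithmetic with a hand-written Gaussian log-density; loc and obs are placeholders.

```python
import numpy as np

bs, data_dim = 3, 2
loc = np.zeros((bs, 1, data_dim, 1))   # like obs_lat: batch (bs, 1), event (data_dim, 1)
obs = np.ones((bs, data_dim, 1))       # like sequences[batch, t].unsqueeze(-1)

def gaussian_log_prob(value, loc, scale=1.0):
    """Normal log-density, summed over the two rightmost (event) dims."""
    elementwise = -0.5 * ((value - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)
    return elementwise.sum(axis=(-2, -1))

bad = gaussian_log_prob(obs, loc)             # (bs, bs): obs's batch part sits on dim -1
good = gaussian_log_prob(obs[:, None], loc)   # (bs, 1): obs reshaped to (bs, 1, data_dim, 1)
```

If this turns out to be the cause, the torch analogue of the fix would be inserting the missing singleton dim, e.g. sequences[batch, t].unsqueeze(-1).unsqueeze(-3); this is an untested suggestion.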

Thanks.
Ruijiang

I am trying to implement Switching Dynamic Linear System with Pyro

Nice! Do you have an interest in contributing that to pyro/examples/? I think it would be easier to discuss this model in a PR rather than a forum thread.

Is there any documentation on the objective function when enumeration is involved?

The objective function is simply the ELBO. Even when we do MAP estimation, we maximize the ELBO with a trivial Delta guide.

$$\mathrm{ELBO} = \sum_{z} q(z \mid x) \log \frac{p(z, x)}{q(z \mid x)}$$
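A tiny discrete example makes the formula concrete. With a hypothetical table p(z, x) for three values of z at a fixed x, the ELBO equals log p(x) when q is the exact posterior, is lower for any other q, and for a Delta guide at z* reduces to log p(z*, x) (a point mass has zero entropy), which is why maximizing it recovers MAP:

```python
import numpy as np

p_joint = np.array([0.10, 0.25, 0.05])   # hypothetical p(z, x) for z in {0, 1, 2} at a fixed x
log_evidence = np.log(p_joint.sum())     # log p(x)

def elbo(q, p_joint):
    """sum_z q(z|x) log(p(z, x) / q(z|x)), skipping z with q(z|x) = 0."""
    support = q > 0
    return np.sum(q[support] * (np.log(p_joint[support]) - np.log(q[support])))

posterior = p_joint / p_joint.sum()
uniform = np.full(3, 1.0 / 3)
delta_at_mode = np.array([0.0, 1.0, 0.0])   # a Delta guide on the MAP value z* = 1
```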

With enumeration, we split z into, say, three parts: non-enumerated z1, guide-enumerated z2, and model-enumerated z3.

$$\mathrm{ELBO} = \sum_{z_1} \sum_{z_2} q(z_1, z_2 \mid x) \log \frac{\sum_{z_3} p(z_1, z_2, z_3, x)}{q(z_1, z_2 \mid x)}$$

Note that the outer sum over $z_1$ (weighted by $q(z_1 \mid x)$) is implemented via Monte Carlo, whereas the outer sum over $z_2$ (weighted by $q(z_2 \mid x)$) is implemented via weighted enumeration in the guide, and the inner sum $\sum_{z_3}$ in the numerator is implemented via weighted enumeration in the model. I’ll try to add a clearer explanation to our upcoming enumeration tutorial.
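The split can be checked numerically. In the NumPy sketch below, the joint table p and the guide tables are made up; z3 is summed inside the log as in model-side enumeration, z2 is summed with guide weights outside the log, and z1 is handled either exactly or by Monte Carlo. The two values should agree up to Monte Carlo error, and both lie below log p(x):

```python
import numpy as np

rng = np.random.default_rng(0)
K1, K2, K3 = 2, 3, 2
# Hypothetical joint p(z1, z2, z3, x) at the observed x, as a table.
p = 0.01 + 0.1 * rng.random((K1, K2, K3))
# Hypothetical guide q(z1, z2 | x) = q(z1 | x) q(z2 | z1, x).
q_z1 = np.array([0.3, 0.7])
q_z2_given_z1 = np.array([[0.2, 0.5, 0.3], [0.6, 0.1, 0.3]])
q = q_z1[:, None] * q_z2_given_z1                  # (K1, K2)

# z3 enumerated in the model: summed *inside* the log.
log_numerator = np.log(p.sum(axis=2))              # (K1, K2)
# z2 enumerated in the guide: weighted sum outside the log, for each z1.
per_z1_term = (q_z2_given_z1 * (log_numerator - np.log(q))).sum(axis=1)

# Exact value, enumerating z1 as well.
elbo_exact = (q_z1 * per_z1_term).sum()

# Monte Carlo over z1 (a non-enumerated site), exact over z2 and z3.
z1_samples = rng.choice(K1, size=100_000, p=q_z1)
elbo_mc = per_z1_term[z1_samples].mean()
```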

Hi @fritzo,

I am interested in contributing to the examples.

For the ‘switching linear dynamic models’, should I move the discussion from the forum to GitHub issues?

Thanks @fritzo, the explanation of enumeration with these equations is quite clear.

Just a small question about a special case:

If a variable z4 is enumerated in both the model and the guide, then there will be NO sum of p(z4) over z4 in the ELBO. Am I correct?

‘switching linear dynamic models’ … to github issues?

Yes, a GitHub discussion would be great! A few of us Pyro devs are also working on tutorials for the upcoming 0.3 release (planned for NIPS).

If a variable z4 is enumerated in both model and guide

Hmm, I’m not sure what “enumerated in both model and guide” means. Pyro supports Monte Carlo sampling, guide-side enumeration, and model-side enumeration. In the above example, z2 is enumerated in the guide and replayed in the model. It is not summed out in the log numerator, but it does appear in p(z2).

If a variable z4 is enumerated in both model and guide

In the GMM example, @config_enumerate(default='parallel') applies to the discrete random variable assignment in both the model and the guide:

    import torch
    from torch.distributions import constraints
    import pyro
    import pyro.distributions as dist
    from pyro import poutine
    from pyro.contrib.autoguide import AutoDelta
    from pyro.infer import config_enumerate

    K = 2  # number of mixture components

    @config_enumerate(default='parallel')
    @poutine.broadcast
    def model(data):
        # Global variables.
        weights = pyro.sample('weights', dist.Dirichlet(0.5 * torch.ones(K)))
        scale = pyro.sample('scale', dist.LogNormal(0., 2.))
        with pyro.iarange('components', K):
            locs = pyro.sample('locs', dist.Normal(0., 10.))

        with pyro.iarange('data', len(data)):
            # Local variables.
            assignment = pyro.sample('assignment', dist.Categorical(weights))
            pyro.sample('obs', dist.Normal(locs[assignment], scale), obs=data)

    global_guide = AutoDelta(poutine.block(model, expose=['weights', 'locs', 'scale']))

    @config_enumerate(default="parallel")
    @poutine.broadcast
    def full_guide(data):
        # Global variables.
        with poutine.block(hide_types=["param"]):  # Keep our learned values of global parameters.
            global_guide(data)

        # Local variables.
        with pyro.iarange('data', len(data)):
            assignment_probs = pyro.param('assignment_probs', torch.ones(len(data), K) / K,
                                          constraint=constraints.unit_interval)
            pyro.sample('assignment', dist.Categorical(assignment_probs))

Ah, I see the source of confusion. When @config_enumerate or infer={'enumerate': ...} appears in the model, it only applies if the variable has not already been enumerated in the guide. If a variable has been enumerated in the guide, then the model will simply replay the guide-enumerated variable (with its corresponding nonstandard shape), and the ELBO will be

$$\mathrm{ELBO} = \sum_{z_4} q(z_4 \mid x) \log \frac{p(z_4, x)}{q(z_4 \mid x)}$$

I’ll try to make this clear in the upcoming enumeration tutorial.
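The practical difference between the two cases can be checked with toy numbers (hypothetical p(z4, x) and q(z4 | x) below): enumerating z4 in the guide gives the weighted sum outside the log, as in the ELBO for z4 above, while enumerating it only in the model puts the sum inside the log and yields log p(x) exactly. The gap is KL(q || p(z4 | x)), which vanishes only when q is the exact posterior.

```python
import numpy as np

p_joint = np.array([0.02, 0.08, 0.10])   # hypothetical p(z4, x) at the observed x
q = np.array([0.5, 0.3, 0.2])            # hypothetical guide q(z4 | x)

# z4 enumerated in the guide (and replayed in the model): sum outside the log.
elbo_guide_enum = np.sum(q * (np.log(p_joint) - np.log(q)))

# z4 enumerated only in the model: sum inside the log, i.e. log p(x).
elbo_model_enum = np.log(p_joint.sum())

# The difference is KL(q || p(z4 | x)).
posterior = p_joint / p_joint.sum()
kl = np.sum(q * (np.log(q) - np.log(posterior)))
```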