Sampling during SVI with an AutoDelta guide

Hello!
I’m trying to run “supervised” HMM learning, starting from the example at https://pyro.ai/examples/hmm.html.
I modified the code so that the hidden states are also observed in the dataset (this should be an easier problem than the original, since we don’t need to enumerate over the hidden states).
Here is my code:

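import logging

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, JitTrace_ELBO, Trace_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

def model_0(sequences_y, sequences_x, lengths, args, batch_size=None, include_prior=True):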
    assert not torch._C._get_tracing_state()
    num_sequences, max_length, data_dim = sequences_y.shape
    with poutine.mask(mask=include_prior):
        dirichlet_param = 48*torch.eye(args.hidden_dim) + 2
        probs_x = pyro.sample("probs_x", dist.Dirichlet(dirichlet_param).to_event(1))
        print(probs_x[0,:])
        beta_param_alpha = 46*torch.eye(args.hidden_dim,m=data_dim) + 2
        beta_param_beta = 50 - beta_param_alpha
        probs_y = pyro.sample("probs_y", dist.Beta(beta_param_alpha, beta_param_beta).to_event(2))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    for i in pyro.plate("sequences", len(sequences_y), batch_size):
        length = lengths[i]
        sequence_y = sequences_y[i, :length,:]
        sequence_x = sequences_x[i, :length]
        x = sequence_x[0]
        for t in pyro.markov(range(length)):
            x = pyro.sample("x_{}_{}".format(i, t), dist.Categorical(probs_x[x]), obs=sequence_x[t])
            with tones_plate:
                pyro.sample("y_{}_{}".format(i, t), dist.Bernoulli(probs_y[x.squeeze(-1)]), obs=sequence_y[t])

def main(args):
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')   
    data = utils.load_data(args.compute_df)    
    # data is a dictionary with 2 keys: test and train. Each key maps to another dict with keys sequences and sequence_lengths
    # every subset is a tensor of shape {number of pieces} x {max beat number} x 89, where the last element of the last dimension is the key

    sequences_y = data['train']['sequences'][:,:,:-1]
    sequences_x = data['train']['sequences'][:,:,-1].long()
    lengths = data['train']['sequence_lengths']

    num_observations = float(lengths.sum())
    pyro.set_rng_seed(args.seed)
    pyro.clear_param_store()
    pyro.enable_validation(__debug__)

    guide = AutoDelta(poutine.block(model_0, expose=["probs_x","probs_y"]))

    optim = Adam({'lr': args.learning_rate})
    Elbo = JitTrace_ELBO if args.jit else Trace_ELBO
    elbo = Elbo(max_plate_nesting=1, jit_options={"time_compilation": args.time_compilation})
    svi = SVI(model_0, guide, optim, elbo)
    
    logging.info('Step\tLoss')
    for step in range(args.num_steps):
        loss = svi.step(sequences_y,sequences_x, lengths, args=args, batch_size=args.batch_size)

I understand that during SVI the sample statements in the model take their values from the guide.
But which distribution are they drawn from in this autoguide case?
Given the Dirichlet parameter in the model, I would expect the first row of the probs_x sample (printed by the print statement in model_0) to have its largest value at index 0.
But this is not the case; the code prints:
tensor([0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833, 0.0833])

Strangely enough, this is not the case for probs_y, which behaves as I would expect from a sample from a Beta distribution with those parameters.

Thank you very much for your help!

It seems that the default init strategy init_to_median of AutoDelta does not apply to multivariate distributions. To get the expected behavior, you can switch to init_to_prior.

Thank you! I would never have guessed something like this without your help. That seems like a strange default choice to me. I guess there is some numerical reason that makes this initialization complicated? :thinking:

Anyway, I cannot find the init_to_prior function you are talking about. Do you mean init_to_mean?
I’m a bit lost in the code structure and I can’t understand where this function is called and how to change it.

Oops, sorry, the name should be init_to_sample. You can find all the initialization strategies in pyro.infer.autoguide.initialization, and you can pass any of them to AutoDelta as init_loc_fn.

Ohh, thank you very much! This worked!