Relationship between guide_trace.log_prob_sum() and KL divergence?

I implemented a customized loss function for SVI following a discussion in the pyro-ppl/pyro GitHub issues.

The loss function looks like:

from pyro import poutine

def customized_loss(model, guide, *args, **kwargs):
    # run the guide, then replay the model against the guide's sampled values
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
    model_trace = poutine.trace(poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)

    model_log_p = model_trace.log_prob_sum()
    guide_log_p = guide_trace.log_prob_sum()

    additional_term = get_additional_term(...)

    return -model_log_p + guide_log_p + additional_term
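
For reference, a callable like this can be passed directly to SVI as the loss (a minimal sketch; the optimizer settings here are placeholders, not something prescribed by the thread):

from pyro.infer import SVI
from pyro.optim import Adam

# hypothetical setup: plug the custom callable in as the SVI loss
svi = SVI(model, guide, Adam({"lr": 1e-3}), loss=customized_loss)
# each call to svi.step(...) then evaluates customized_loss(model, guide, ...)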

Basically, my loss function is:

loss = -E_q[log p(x|z)] + KL(q(z|x) || p(z)) + additional_term

My first question is: what’s the relationship between guide_log_p and the KL divergence term in the loss function?

My second question is: sometimes I observe guide_log_p to be positive, which does not make sense to me because a log likelihood should always be negative. Is my understanding wrong? Or, if my understanding is correct, how could I debug this?

Thanks a lot!


a single-sample estimator of that KL divergence looks like

KL = log(q(z|x) / p(z))

with z ~ q(z|x)

so that’s exactly what -model_log_p + guide_log_p contains (except that it also contains the additional log p(x|z) terms that enter into the expected log likelihood).
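
concretely, a single-sample estimate of that KL can be computed like this (a toy 1-d example with stand-in distributions, not your model):

import pyro.distributions as dist

q = dist.Normal(1.0, 0.5)   # stand-in for q(z|x)
p = dist.Normal(0.0, 1.0)   # stand-in for the prior p(z)
z = q.sample()
kl_single_sample = q.log_prob(z) - p.log_prob(z)  # single-sample estimate of KL(q || p)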

as for your second question, for continuous random variables guide_log_p is actually the logarithm of a probability density and as such it can take on either sign.
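
for example (an illustrative distribution, unrelated to your model):

import torch
import pyro.distributions as dist

# a density can exceed 1, so its log can be positive
dist.Normal(0.0, 0.1).log_prob(torch.tensor(0.0))  # ~ 1.38 > 0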


Thanks for your explanation! Did you mean that guide_log_p is the estimate of the expected log(q(z|x)) and model_log_p is the estimate of the expected log(p(x|z)) term plus the log(p(z)) term?

If this is the case, how can I calculate the value of the expected log likelihood, i.e., the first term in my loss function?

Thanks a lot!


Thanks for your explanation! Did you mean that guide_log_p is the estimate of the expected log(q(z|x)) and model_log_p is the estimate of the expected log(p(x|z)) term plus the log(p(z)) term?

yes that’s right

If this is the case, how can I calculate the value of the expected log likelihood, i.e., the first term in my loss function?

you can pick out individual terms like so:

my_log_p = model_trace.nodes["my_observation"]["log_prob_sum"]
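
and if you want the full expected log likelihood estimate rather than a single site, you can sum over all observed sites (a sketch; it assumes log_prob_sum() or compute_log_prob() has already been called on the trace so the per-site sums are populated):

expected_log_lik = sum(
    site["log_prob_sum"]
    for site in model_trace.nodes.values()
    if site["type"] == "sample" and site["is_observed"]
)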

Is there a way to isolate the KL divergence in this framework, for instance to create a Beta-VAE (https://arxiv.org/pdf/1804.03599.pdf)? i.e.,

loss = -E_q[log p(x|z)] + beta * KL(q(z|x) || p(z))


you can use poutine.scale to scale the log probabilities of any terms you like, see e.g. https://github.com/uber/pyro/blob/dev/examples/dmm/dmm.py
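
for example, here’s a minimal sketch (the tiny model/guide and the value of beta are only illustrative, not from the thread):

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine

beta = 4.0  # hypothetical beta > 1, as in the Beta-VAE paper

def model(x):
    # scaling the prior term log p(z) by beta ...
    with poutine.scale(scale=beta):
        z = pyro.sample("z", dist.Normal(0.0, 1.0))
    pyro.sample("obs", dist.Normal(z, 1.0), obs=x)

def guide(x):
    loc = pyro.param("loc", torch.tensor(0.0))
    log_scale = pyro.param("log_scale", torch.tensor(0.0))
    # ... and log q(z|x) by the same beta turns the KL term into beta * KL(q(z|x) || p(z))
    with poutine.scale(scale=beta):
        pyro.sample("z", dist.Normal(loc, log_scale.exp()))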
