Correct ELBO Contribution of a Discriminator

I am wondering what’s the best way to train a discriminator in Pyro.

Assume I have a latent representation z from which a discriminator should not be able to correctly predict the binary target (0 or 1).
As a test, I ran a simple linear model as my discriminator to predict the target (effectively logistic regression, since y is used as the logits of a Bernoulli):

y = beta_intercept + beta @ z

However, if the discriminator works well, that should increase the loss I minimize (the negative ELBO), so the discriminator's term needs a negative weight. But I cannot flip scaling_factor to be negative: it throws an error (ValueError: Expected scale > 0 but got 0).

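import pyro
import pyro.distributions as dist

# plates["samples"], y (the discriminator's logits), target and scaling_factor come from the surrounding model
with plates["samples"], pyro.poutine.scale(scale=scaling_factor):
    pyro.sample(
        "labels",
        dist.Bernoulli(logits=y),  # same as probs=logits_to_probs(y), but numerically more stable
        obs=target,
    )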
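Written out, the sign issue is the usual two-player adversarial objective. This is a sketch assuming alternating updates, with t the target, p_ψ(t | z) = Bernoulli(t; σ(β_0 + βᵀz)) the discriminator from the equation above (β_0 = beta_intercept, β = beta), q_φ(z | x) the encoder, and λ > 0 the magnitude of the negative weight on the discriminator term:

$$
\begin{aligned}
\mathcal{L}_{\text{disc}}(\psi) &= -\,\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\psi(t \mid z)\big] \\
\mathcal{L}_{\text{enc}}(\phi, \theta) &= -\,\mathrm{ELBO}(\phi, \theta) + \lambda\,\mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\psi(t \mid z)\big]
\end{aligned}
$$

The discriminator minimizes the first loss with the encoder held fixed; the encoder and decoder minimize the second with ψ held fixed. In the second loss the discriminator's log-likelihood enters the model's log joint with weight -λ, a negative scale, which is exactly what poutine.scale rejects.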

Hence my question: what is the go-to approach for integrating a discriminator into the loss?

I suggest following the pattern described here.
