I am wondering what the best way is to train a discriminator in Pyro.
Assume I have a latent representation z from which a discriminator should not be able to predict the binary target (0 or 1) correctly.
As a test, I used a simple linear model on the logit scale (i.e. logistic regression) as my discriminator to predict the target:
y = beta_intercept + beta @ z
However, if the discriminator works well, this should be penalized, i.e. the negative ELBO that Pyro minimizes should increase. That would require a negative scaling_factor, but flipping its sign in the site below throws an error:
ValueError: Expected scale > 0 but got 0
with plates["samples"], pyro.poutine.scale(scale=scaling_factor):
    # logits=y applies a per-sample sigmoid; torch's logits_to_probs uses softmax unless is_binary=True
    pyro.sample("labels", dist.Bernoulli(logits=y), obs=target)
Hence, I would like to understand what the go-to approach is for integrating a discriminator into the loss.
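Another commonly used pattern, borrowed from domain-adversarial training (DANN), is a gradient reversal layer: it is the identity in the forward pass and negates (and scales by lambd) the gradient in the backward pass. Placed between z and the discriminator, it would let the labels site keep a positive scaling_factor, so the discriminator parameters are trained to classify well while whatever produces z receives reversed gradients. A minimal PyTorch sketch; GradReverse, grad_reverse and lambd are hypothetical names, not part of Pyro:

```python
import torch


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by -lambd backward."""

    @staticmethod
    def forward(ctx, x, lambd):
        ctx.lambd = lambd
        # view_as returns a new tensor object that shares x's data
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient flowing back into x; lambd gets no gradient
        return -ctx.lambd * grad_output, None


def grad_reverse(x, lambd=1.0):
    return GradReverse.apply(x, lambd)
```

In the model this could be used as y = beta_intercept + grad_reverse(z, lambd) @ beta (assuming z has shape (N, D)), followed by the labels site with a positive scale. This relies on z being reparameterized so that gradients actually flow back through it to the guide.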