I am wondering what the best way is to train a discriminator in Pyro.
Assume I have a latent representation z from which a discriminator should not be able to correctly predict the binary target (0 or 1).
As a test, I used a simple linear model as my discriminator to predict the target (since y is used as logits, this is effectively a logistic regression):
y = beta_intercept + beta @ z
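For reference, the discriminator can be sketched as a PyTorch module in which `nn.Linear` holds `beta` as its weight and `beta_intercept` as its bias. Because `y` is used as a logit, the Bernoulli log-likelihood of the labels equals the negative binary cross-entropy, which is the quantity the `labels` site adds to the ELBO. `LinearDiscriminator`, the latent size, and the random data are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Bernoulli


class LinearDiscriminator(nn.Module):
    """Hypothetical discriminator: y = beta_intercept + beta @ z, returned as logits."""

    def __init__(self, latent_dim: int):
        super().__init__()
        # weight plays the role of beta, bias the role of beta_intercept
        self.linear = nn.Linear(latent_dim, 1)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # One logit per sample, shape (batch,)
        return self.linear(z).squeeze(-1)


torch.manual_seed(0)
z = torch.randn(8, 4)                       # hypothetical latent batch
target = torch.randint(0, 2, (8,)).float()  # hypothetical binary labels
disc = LinearDiscriminator(latent_dim=4)
y = disc(z)

# Bernoulli log-likelihood of the labels under the logits y ...
log_lik = Bernoulli(logits=y).log_prob(target).sum()
# ... equals the negative binary cross-entropy with logits
bce = F.binary_cross_entropy_with_logits(y, target, reduction="sum")
print(torch.allclose(log_lik, -bce))
```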
However, if the discriminator works well, this term increases the ELBO (in Pyro I minimize the negative ELBO), whereas I want an accurate discriminator to be penalized. I cannot flip `scaling_factor` to be negative, because `poutine.scale` throws an error: `ValueError: Expected scale > 0 but got 0`.
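```python
import pyro
import pyro.distributions as dist

# Discriminator term: Bernoulli likelihood of the labels given the logits y,
# weighted by scaling_factor (poutine.scale only accepts positive values)
with plates["samples"], pyro.poutine.scale(scale=scaling_factor):
    pyro.sample(
        "labels",
        dist.Bernoulli(logits=y),  # pass the logits directly; Bernoulli applies the sigmoid
        obs=target,
    )
```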
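A detail worth checking regarding `logits_to_probs(y)`, assuming it is the helper from `torch.distributions.utils`: its default `is_binary=False` applies a softmax over the last dimension rather than a sigmoid, so for a batch of per-sample logits the "probabilities" are coupled across samples and sum to 1. The sketch below compares the two modes with `Bernoulli(logits=...)`; the logit values are made up.

```python
import torch
from torch.distributions import Bernoulli
from torch.distributions.utils import logits_to_probs

y = torch.tensor([2.0, 0.0, -2.0])  # hypothetical per-sample logits

# Default is_binary=False: softmax across the batch dimension
softmax_probs = logits_to_probs(y)
# is_binary=True: elementwise sigmoid, which is what a Bernoulli needs
sigmoid_probs = logits_to_probs(y, is_binary=True)
# Bernoulli(logits=y) derives its probs with the sigmoid as well
bernoulli_probs = Bernoulli(logits=y).probs

print(softmax_probs, softmax_probs.sum())
print(sigmoid_probs, bernoulli_probs)
```

So either `logits_to_probs(y, is_binary=True)` or, more simply, `dist.Bernoulli(logits=y)` gives per-sample probabilities.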
So what is the go-to approach for integrating a discriminator into the loss?
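One common way to get adversarial behaviour without a negative scale is a gradient reversal layer, a technique from domain-adversarial training rather than a built-in Pyro feature. The discriminator is still trained to predict the target (so the `labels` site can keep a positive `scaling_factor`), while the gradient flowing back into z is multiplied by `-lam`, pushing the encoder to make z uninformative. The sketch below is plain PyTorch; `GradReverse`, `grad_reverse`, `lam`, and the toy `encoder`/`discriminator` are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GradReverse(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by -lam in the backward pass."""

    @staticmethod
    def forward(ctx, x, lam):
        ctx.lam = lam
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the gradient flowing into the encoder; lam gets no gradient
        return -ctx.lam * grad_output, None


def grad_reverse(x, lam=1.0):
    return GradReverse.apply(x, lam)


torch.manual_seed(0)
encoder = nn.Linear(3, 2)        # hypothetical encoder producing z
discriminator = nn.Linear(2, 1)  # y = beta_intercept + beta @ z
x = torch.randn(16, 3)
target = torch.randint(0, 2, (16,)).float()

z = encoder(x)
y = discriminator(grad_reverse(z, lam=0.5)).squeeze(-1)
disc_loss = F.binary_cross_entropy_with_logits(y, target)
disc_loss.backward()
# discriminator.weight.grad: ordinary gradient, improves the discriminator
# encoder.weight.grad: ordinary gradient times -0.5, makes z less predictive
```

In a Pyro model this would amount to computing `y` from `grad_reverse(z, lam)` before the `labels` site. The other common option is alternating optimization: update the discriminator on `z.detach()` with its own optimizer, then update the model with a penalty on the discriminator's accuracy. If only a negative weight is needed, `pyro.factor` adds an arbitrary, possibly negative, term to the model's log joint, but if the discriminator's own parameters are trained by that same objective they are simply pushed toward bad predictions.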