Is this setup for a manual logit-Normal guide correct?

Here’s a minimal working example:

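```python
import torch
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO

def model(data):
    # Uniform Beta(1, 1) prior on the success probability
    mu = pyro.sample("mu", dist.Beta(1., 1.))
    # 27 successes observed out of 30 trials
    y = pyro.sample("y", dist.Binomial(total_count=30, probs=mu), obs=torch.tensor(27.))

def guide(data):
    loc = pyro.param("loc", lambda: torch.tensor(0.))
    scale = pyro.param("scale", lambda: torch.tensor(1.), constraint=constraints.positive)
    # Auxiliary Normal variable on the logit scale (the model has no such site)
    logit_mu = pyro.sample("logit_mu", dist.Normal(loc, scale), infer={"is_auxiliary": True})
    # Deterministically map logit_mu into (0, 1)
    mu = pyro.sample("mu", dist.Delta(torch.sigmoid(logit_mu)))
    return {"mu": mu}

adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())
for j in range(300):
    loss = svi.step(None)
```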

This runs and converges, but the results don’t look entirely consistent with those from an AutoDiagonalNormal guide on `mu` (as I understand it, the two should be exactly equivalent). Is my setup missing a Jacobian term? Thanks!

Yes, it looks like you’re missing a Jacobian. You can pass it as the `log_density` argument to `Delta`. Is there a particular reason you’re doing this manually?
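For reference, the missing term follows from the change-of-variables formula. Writing $x$ for `logit_mu` and $\mu = \sigma(x)$ for its sigmoid, the guide density of `mu` is:

$$
\begin{aligned}
q(\mu) &= \mathcal{N}(x;\ \mathrm{loc},\ \mathrm{scale})\,\left|\sigma'(x)\right|^{-1}, \qquad x = \operatorname{logit}(\mu),\\
\log q(\mu) &= \log \mathcal{N}(x;\ \mathrm{loc},\ \mathrm{scale}) - \log \sigma'(x),\\
\log \sigma'(x) &= \log\big(\sigma(x)\,(1-\sigma(x))\big) = \log\sigma(x) + \log\sigma(-x).
\end{aligned}
$$

The auxiliary Normal site already contributes the first term of $\log q(\mu)$, so the `Delta` site has to contribute $-\log\sigma'(x)$; the last line uses $1 - \sigma(x) = \sigma(-x)$.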

Thanks! So I think this is what I should be doing:

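```python
import torch
import torch.nn.functional as F
import pyro
import pyro.distributions as dist
from pyro.distributions import constraints
from pyro.infer import SVI, Trace_ELBO

def log_sigmoid_deriv(x):
    # log sigmoid'(x) = log sigmoid(x) + log(1 - sigmoid(x)) = logsigmoid(x) + logsigmoid(-x)
    return F.logsigmoid(x) + F.logsigmoid(-x)

def model(data):
    mu = pyro.sample("mu", dist.Beta(1., 1.))
    y = pyro.sample("y", dist.Binomial(total_count=30, probs=mu), obs=torch.tensor(27.))

def guide(data):
    loc = pyro.param("loc", lambda: torch.tensor(0.))
    scale = pyro.param("scale", lambda: torch.tensor(1.), constraint=constraints.positive)
    logit_mu = pyro.sample("logit_mu", dist.Normal(loc, scale), infer={"is_auxiliary": True})
    # Change of variables: log q(mu) = log q(logit_mu) - log sigmoid'(logit_mu),
    # so the Delta site carries the negative log-Jacobian
    mu = pyro.sample("mu", dist.Delta(torch.sigmoid(logit_mu),
                                      log_density=-log_sigmoid_deriv(logit_mu)))
    return {"mu": mu}

adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())
for j in range(300):
    loss = svi.step(None)
```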

Re: doing it manually: yes, I want to set up a structured guide with some additional dependencies between sites, rather than a mean-field guide like AutoDiagonalNormal.

Well, I think it’s probably better to define a `TransformedDistribution` with a `SigmoidTransform` and avoid the manual calculus; see e.g. here. `git grep`-ing the repo is pretty useful for this sort of thing (at least once you know what to look for).


Thanks, that looks cleaner (and less reliant on me doing the derivations correctly!)