Is this setup for a manual logit-Normal guide correct?

Here’s a minimal working example:

import torch
import matplotlib.pyplot as plt
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.infer import SVI, Trace_ELBO

def model(data):
    mu = pyro.sample("mu", dist.Beta(1., 1.))
    pyro.sample("y", dist.Binomial(total_count=30, probs=mu), obs=torch.tensor(27.))

def guide(data):
    loc = pyro.param("loc", lambda: torch.tensor(0.))
    scale = pyro.param("scale", lambda: torch.tensor(1.), constraint=constraints.positive)
    # auxiliary unconstrained sample on the logit scale
    logit_mu = pyro.sample("logit_mu", dist.Normal(loc, scale), infer={"is_auxiliary": True})
    # deterministically map back to (0, 1) for the model's "mu" site
    mu = pyro.sample("mu", dist.Delta(torch.sigmoid(logit_mu)))
    return {"mu": mu}

pyro.clear_param_store()
adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())
losses = []
for j in range(300):
    loss = svi.step(None)
    losses.append(loss)
plt.plot(losses)

This runs and converges, but the results don’t look entirely consistent with an AutoDiagonalNormal guide on mu, which (as I understand it) should be exactly equivalent. Is my setup missing a Jacobian term? Thanks!
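For reference, a minimal sketch of the AutoDiagonalNormal baseline I’m comparing against:

from pyro.infer.autoguide import AutoDiagonalNormal

# mean-field Normal guide in unconstrained space, fitted with the same loop as above
auto_guide = AutoDiagonalNormal(model)
svi = SVI(model, auto_guide, adam, loss=Trace_ELBO())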


Yes, it looks like you’re missing a Jacobian term. It can be passed as the log_density argument to Delta. Is there a particular reason you’re doing things manually like this?
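Concretely, with mu = σ(z) and z the auxiliary Normal sample, the change of variables gives

$$\log q(\mu) = \log q(z) - \log \sigma'(z), \qquad \sigma'(z) = \sigma(z)\,\sigma(-z),$$

so the Delta site should carry log_density = -(F.logsigmoid(z) + F.logsigmoid(-z)).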


Thanks! So I think this is what I should be doing:

import torch.nn.functional as F

def log_sigmoid_deriv(x):
    # log of d/dx sigmoid(x) = sigmoid(x) * sigmoid(-x), computed stably
    return F.logsigmoid(x) + F.logsigmoid(-x)

def model(data):
    mu = pyro.sample("mu", dist.Beta(1., 1.))
    pyro.sample("y", dist.Binomial(total_count=30, probs=mu), obs=torch.tensor(27.))

def guide(data):
    loc = pyro.param("loc", lambda: torch.tensor(0.))
    scale = pyro.param("scale", lambda: torch.tensor(1.), constraint=constraints.positive)
    logit_mu = pyro.sample("logit_mu", dist.Normal(loc, scale), infer={"is_auxiliary": True})
    # subtract the log-Jacobian of the sigmoid so q(mu) is correctly normalized
    mu = pyro.sample("mu", dist.Delta(torch.sigmoid(logit_mu),
                                      log_density=-log_sigmoid_deriv(logit_mu)))
    return {"mu": mu}

pyro.clear_param_store()
adam = pyro.optim.Adam({"lr": 0.03})
svi = SVI(model, guide, adam, loss=Trace_ELBO())
losses = []
for j in range(300):
    loss = svi.step(None)
    losses.append(loss)
plt.plot(losses)

Re: doing it manually: yes, I want to set up a structured guide that includes some additional dependencies between sites, rather than the mean-field AutoDiagonalNormal.

Well, I think it’s probably better to define a TransformedDistribution using a SigmoidTransform and avoid the manual calculus; see e.g. here. Git-grepping the repo is pretty useful for this sort of thing (at least once you know what to look for).
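A minimal sketch of what that might look like, reusing the loc/scale params from your guide (the log-Jacobian is then handled by TransformedDistribution’s log_prob):

from torch.distributions.transforms import SigmoidTransform

def guide(data):
    loc = pyro.param("loc", lambda: torch.tensor(0.))
    scale = pyro.param("scale", lambda: torch.tensor(1.), constraint=constraints.positive)
    # a Normal pushed through a sigmoid, i.e. a logit-Normal on (0, 1);
    # log_prob automatically includes the log-Jacobian term
    logit_normal = dist.TransformedDistribution(dist.Normal(loc, scale), [SigmoidTransform()])
    mu = pyro.sample("mu", logit_normal)
    return {"mu": mu}

Note that no auxiliary site or manual log_density is needed in this version.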


Thanks, that looks like it would be cleaner (and less reliant on me getting the derivations right!)