Multiple sample sites error — help?


#1

I’m trying to create a simple working example of a Laplace family guide as described in https://github.com/pyro-ppl/pyro/issues/1817. I’m getting an error RuntimeError: Multiple sample sites named 'sdhyper' which I don’t understand. Any clues would be appreciated. Here’s my code:

from __future__ import print_function

import torch
from torch.distributions import constraints
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import ClippedAdam
from pyro import poutine
import hessian
import numpy as np
ts = torch.tensor


torch.manual_seed(478301986) #Gingles

pyro.enable_validation(True)
pyro.set_rng_seed(0)


def model(G = 3, N = 4):
    sdhyper = pyro.sample('sdhyper', dist.Gumbel(0.,1.))
    gmeans = pyro.sample('gmeans', dist.StudentT(7.,0.,torch.exp(sdhyper)).expand([G]))
    gs = []
    for g in pyro.plate('groups', G):
        gs.append(pyro.sample(f'x_{g}',
                dist.Gumbel(gmeans[g],torch.exp(gmeans[(g+1)%G])).expand([N])))

BASE_PSI =.01

def infoToM(Info,psi):
    tlen = len(psi)
    M = torch.zeros(tlen,tlen)
    lseterms = torch.zeros(3)
    for i in range(tlen):
        lseterms[1] = -Info[i,i] + psi[i]
        lseterms[2] = -abs(Info[i,i]) + psi[i]
        for j in range(tlen):
            if j != i:
                lseterms[2] += abs(Info[i,j])
        M[i,i] = psi[i] * torch.logsumexp(lseterms / psi[i],0)
    return M

def guide(G = 3, N = 4):


    hat_data = dict()

    sdhyperhat = pyro.param('sdhyperhat', ts(0.))
    hat_data.update(sdhyper=sdhyperhat)
    gmeanshat = pyro.param('gmeanshat', torch.zeros(G))
    hat_data.update(gmeans=gmeanshat)
    gs = []
    for g in range(G):
        gs.append(pyro.param(f'xhat_{g}', torch.zeros(N)))
        hat_data.update({f'x_{g}':gs[g]})

    #Get hessian

    hessCenter = pyro.condition(model,hat_data)
    trace1 = poutine.trace(hessCenter)
    trace2 = trace1.get_trace() #*args,**kwargs)
    logPosterior = trace2.log_prob_sum()
    Info = -hessian.hessian(logPosterior, hat_data.values())#, allow_unused=True)

    thetaMean = torch.cat([thetaPart.view(-1) for thetaPart in hat_data.values()],0)
    tlen = len(thetaMean)

    #declare global-level psi params
    globalpsi = pyro.param('globalpsi',torch.ones(tlen)*BASE_PSI,
                constraint=constraints.positive)
    M = infoToM(Info,globalpsi)
    adjusted = Info+M
    print("matrix?",Info.size(),M.size(),[(float(Info[i,i]),float(M[i,i])) for i in range(tlen)])#,np.linalg.det(adjusted))
    theta = pyro.sample('theta',
                    dist.MultivariateNormal(thetaMean, precision_matrix=Info+M),
                    infer={'is_auxiliary': True})

    #decompose theta into specific values
    tmptheta = theta
    for pname, phat in hat_data.items():
        print(f"adding {pname} from theta" )
        elems = phat.nelement()
        pdat, tmptheta = tmptheta[:elems], tmptheta[elems:]
        pyro.sample(pname, dist.Delta(pdat.view(phat.size())))




def trainGuide():
    svi = SVI(model, guide, ClippedAdam({'lr': 0.005}), Trace_ELBO())

    pyro.clear_param_store()
    losses = []
    for i in range(3001):
        loss = svi.step()
        losses.append(loss)
        if i % 100 == 0:
            print(f'epoch {i} loss = {loss}')

    ##

    plt.plot(losses)
    plt.xlabel('epoch')
    plt.ylabel('loss')

    ##

    for (key, val) in sorted(pyro.get_param_store().items()):
        print(f"{key}:\n{val}")

#2

You’ll need to wrap your # Get hessian code block in with poutine.block(). The error is due to the .get_trace() call which exposes all pyro statements to enclosing handlers. Whenever you play a model inside a guide, you need to poutine.block() to prevent those statements from being seen. See e.g. use of poutine.block in pyro.contrib.autoguide.