Hi there,
I’m trying to scale sampling at a site by the number of times the parameter would be sampled in the context of a minibatched MAP model. I’m wondering if it is possible to replace sampling the same sites many times with a scaled version to get the same ELBO with much less computation:
xinds = torch.randint(0, 4, (300,))


# Model that samples the exact same thing many times
def model(xinds):
    """Naive model: one Laplace(0, 1) prior site of event shape (6, 1000) per entry of ``xinds``.

    Args:
        xinds: 1-D integer tensor of indices; only its length is used here.
    """
    n_obs = xinds.shape[0]
    # pyro.plate's first positional argument is the plate *name*; the size
    # comes second. Passing the size alone would leave the plate unsized.
    with pyro.plate("data", n_obs):
        pyro.sample(
            'param_with_cost_sample',
            dist.Laplace(
                torch.zeros(n_obs, 6, 1000),
                torch.ones(n_obs, 6, 1000),
            ).to_event(2),
        )


def guide(xinds):
    """MAP guide: a Delta at the learnable ``param_with_cost`` row selected by each index.

    Args:
        xinds: 1-D integer tensor of indices into the first axis of the parameter.
    """
    pwc = pyro.param('param_with_cost', torch.ones(4, 6, 1000))
    # Same plate name and size as in the model so the two traces line up.
    with pyro.plate("data", xinds.shape[0]):
        pyro.sample('param_with_cost_sample', dist.Delta(pwc[xinds]).to_event(2))


optim = pyro.optim.Adam({"lr": 0.1})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, optim, loss=elbo)
loss = svi.step(xinds)
If xinds is very long, however, this becomes very wasteful, or even impossible due to VRAM constraints. I’m wondering whether there is a way to sample each site just once and then scale its contribution by the number of times its index occurs in xinds, which I guess is a sort of memoization…
# More efficient ("memoized") model
def model(xinds):
    """Sample the (4, 6, 1000) prior once, scaled by how often each index occurs in ``xinds``."""
    index_counts = torch.bincount(xinds, minlength=4)
    count_scale = index_counts.view(-1, 1, 1).expand(-1, 6, 1000)
    prior = dist.Laplace(torch.zeros(4, 6, 1000), torch.ones(4, 6, 1000)).to_event(3)
    # NOTE(review): to_event(3) leaves an empty batch shape while the scale is
    # (4, 6, 1000); this mismatch is presumably what triggers the
    # "invalid log_prob shape" ValueError. A plate of size 4 with to_event(2)
    # and a scale of shape (4,) may be what is needed -- verify against Pyro.
    with poutine.scale(scale=count_scale):
        pyro.sample('param_with_cost_sample', prior)


def guide(xinds):
    """Delta guide over the full (4, 6, 1000) parameter, using the same per-index count scaling."""
    index_counts = torch.bincount(xinds, minlength=4)
    count_scale = index_counts.view(-1, 1, 1).expand(-1, 6, 1000)
    print(count_scale.shape)
    point_estimate = pyro.param('param_with_cost', torch.ones(4, 6, 1000))
    posterior = dist.Delta(point_estimate).to_event(3)
    with poutine.scale(scale=count_scale):
        pyro.sample('param_with_cost_sample', posterior)


optim = pyro.optim.Adam({"lr": 0.1})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, optim, loss=elbo)
loss = svi.step(xinds)
However, I think multidimensional scaling with poutine.scale doesn’t work, as this gives:
ValueError: at site "param_with_cost_sample", invalid log_prob shape
Expected [], actual [1000]
Try one of the following fixes:
- enclose the batched tensor in a with pyro.plate(...): context
- .to_event(...) the distribution being sampled
- .permute() data dimensions
Am I doing something wrong, or is there another way to achieve this? Thanks!