How to achieve variable scaling across multi-dimensional sites?

Hi there,
I’m trying to scale a sample site by the number of times the parameter would be sampled, in the context of a minibatched MAP model. I’m wondering if it’s possible to replace sampling the same sites many times with a single scaled sample, and get the same ELBO with much less computation:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, TraceEnum_ELBO

xinds = torch.randint(0, 4, (300,))

# Model that samples the exact same thing many times
def model(xinds):
    scale_tensor = torch.bincount(xinds, minlength=4)  # counts per index (unused in this naive version)
    with pyro.plate("data", 300):
        pyro.sample('param_with_cost_sample',
                    dist.Laplace(torch.zeros(300, 6, 1000),
                                 torch.ones(300, 6, 1000)).to_event(2))

def guide(xinds):
    pwc = pyro.param('param_with_cost', torch.ones(4, 6, 1000))
    with pyro.plate("data", 300):
        pyro.sample('param_with_cost_sample', dist.Delta(pwc[xinds]).to_event(2))

optim = pyro.optim.Adam({"lr": 0.1})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, optim, loss=elbo)
loss = svi.step(xinds)

If xinds is very long, however, this becomes very wasteful, or even impossible due to VRAM constraints. I’m wondering if there is a way to sample each site just once and then scale its log-probability by the number of times its index occurs, which I guess is sort of memoization…
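Just to make the equivalence I’m after concrete, here’s a quick standalone check in plain torch (the names are just illustrative): n identical copies of a log-density contribute the same as one copy scaled by n.

import torch
from torch.distributions import Laplace

d = Laplace(torch.zeros(6, 1000), torch.ones(6, 1000))
x = torch.randn(6, 1000)
n = 7

lp_repeated = d.log_prob(x.expand(n, 6, 1000)).sum()  # n identical sample sites
lp_scaled = n * d.log_prob(x).sum()                   # one site, scaled by n
assert torch.allclose(lp_repeated, lp_scaled)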

# More efficient ("memoized") model
from pyro import poutine

def model(xinds):
    # counts per index, broadcast over the (6, 1000) dims
    scale_tensor = torch.bincount(xinds, minlength=4).view(-1, 1, 1).expand(-1, 6, 1000)
    with poutine.scale(scale=scale_tensor):
        pyro.sample('param_with_cost_sample',
                    dist.Laplace(torch.zeros(4, 6, 1000),
                                 torch.ones(4, 6, 1000)).to_event(3))

def guide(xinds):
    scale_tensor = torch.bincount(xinds, minlength=4).view(-1, 1, 1).expand(-1, 6, 1000)
    pwc = pyro.param('param_with_cost', torch.ones(4, 6, 1000))
    with poutine.scale(scale=scale_tensor):
        pyro.sample('param_with_cost_sample', dist.Delta(pwc).to_event(3))

optim = pyro.optim.Adam({"lr": 0.1})
elbo = TraceEnum_ELBO(max_plate_nesting=1)
svi = SVI(model, guide, optim, loss=elbo)
loss = svi.step(xinds)  # raises the ValueError below

However, I think multi-dimensional scaling with poutine doesn’t work, as this gives:

ValueError: at site "param_with_cost_sample", invalid log_prob shape
  Expected [], actual [1000]
  Try one of the following fixes:
  - enclose the batched tensor in a with pyro.plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions

Am I doing something wrong, or is there another way to achieve this? Thanks!

Can you try to provide additional details? I’m afraid I don’t follow.

In any case, if xinds is constant, you should only be computing torch.bincount(xinds, minlength=4) once, outside the model.

Sorry, I meant xinds is not constant but a random minibatch each step. My question is essentially: why can’t I give poutine.scale a tensor with the same shape as the sample? It seems it can only be a single scalar.

This is just a simplified example of a single step that recreates the error!

It would seem that by using .to_event(3) instead of plates you’re basically erasing the tensor dimensions, and thus making scaling by tensorial quantities impossible.
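Concretely, after .to_event(3) the site’s log_prob is a scalar, so there are no dimensions left for a tensor-valued scale to line up with:

import torch
import pyro.distributions as dist

d = dist.Laplace(torch.zeros(4, 6, 1000), torch.ones(4, 6, 1000))
x = d.sample()
print(d.log_prob(x).shape)              # torch.Size([4, 6, 1000])
print(d.to_event(3).log_prob(x).shape)  # torch.Size([])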

Perhaps I’ve been underestimating the importance of plates… I was imagining these dimensions as being treated as a single event, but I guess they aren’t. So this seems to work; is it crazy?

import torch
import pyro
import pyro.distributions as dist
from pyro import poutine


def model(xinds):
    # counts per index (+ eps so zero-count rows don't get an exactly-zero scale)
    scale_tensor = torch.bincount(xinds, minlength=5).view(-1, 1, 1).expand(-1, 6, 1000) + 1e-10
    with pyro.plate("p5", 5, dim=-3), pyro.plate("p6", 6, dim=-2), pyro.plate("p1000", 1000, dim=-1):
        with poutine.scale(scale=scale_tensor):
            pyro.sample('param_with_cost_sample',
                        dist.Laplace(torch.zeros(5, 6, 1000), torch.ones(5, 6, 1000)))


def guide(xinds):
    scale_tensor = torch.bincount(xinds, minlength=5).view(-1, 1, 1).expand(-1, 6, 1000) + 1e-10
    pwc = pyro.param('param_with_cost', torch.ones(5, 6, 1000))
    with pyro.plate("p5", 5, dim=-3), pyro.plate("p6", 6, dim=-2), pyro.plate("p1000", 1000, dim=-1):
        with poutine.scale(scale=scale_tensor):
            pyro.sample('param_with_cost_sample', dist.Delta(pwc))

pyro.clear_param_store()
optim = pyro.optim.Adam({"lr": 0.1})
elbo = pyro.infer.Trace_ELBO(max_plate_nesting=3)
svi = pyro.infer.SVI(model, guide, optim, loss=elbo)
for i in range(10):
    xinds = torch.randint(0, 4, (300,))
    loss = svi.step(xinds)
    print(loss)
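As a rough sanity check (naive_model/naive_guide here are just renamed copies of my first attempt), the scaled loss should match the naive repeated-sampling loss on the same parameters, up to the 1e-10 fudge on zero-count rows:

def naive_model(xinds):
    with pyro.plate("data", len(xinds)):
        pyro.sample('s', dist.Laplace(torch.zeros(len(xinds), 6, 1000),
                                      torch.ones(len(xinds), 6, 1000)).to_event(2))

def naive_guide(xinds):
    pwc = pyro.param('param_with_cost', torch.ones(5, 6, 1000))
    with pyro.plate("data", len(xinds)):
        pyro.sample('s', dist.Delta(pwc[xinds]).to_event(2))

xinds = torch.randint(0, 4, (300,))
naive_loss = pyro.infer.Trace_ELBO().loss(naive_model, naive_guide, xinds)
scaled_loss = pyro.infer.Trace_ELBO(max_plate_nesting=3).loss(model, guide, xinds)
print(naive_loss, scaled_loss)  # deterministic here, since the guide is a Delta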

I appreciate the insight! Thanks for your help!