Implementing HPF and some doubts in pyro.plate shapes

Hello. I am trying to implement the HPF (hierarchical Poisson factorization) paper by Gopalan et al. in Pyro. So far I have a model that works reasonably well with an automatic guide. I am now trying to write a custom guide in which all my latent variables have a Gamma variational distribution. My model and guide look something like this:

# poutine.scale rescales the log-probability of every sample site in the
# decorated function by `scale`.
@poutine.scale(scale = 1e-4)
def model(data, args):
    """Poisson factorization model with Gamma priors on user/item factors.

    Draws a Gamma variable per user (ξ_u) and per item (η_i), uses them as
    the second Gamma argument of the K-component factors θ_uk and β_ik, and
    observes a subsample of the rows of ``data`` under a Poisson likelihood
    whose rate matrix is θ_ukᵀ · β_ik.

    Args:
        data: 2-D tensor; ``data.shape`` is unpacked as ``(n_users, n_items)``.
        args: sequence ``(a_prime, b_prime, a, c_prime, d_prime, c, K)``.
    """
    
    n_users, n_items = data.shape
    
    a_prime, b_prime, a, c_prime, d_prime, c, K = args
    
    
    # No `dim` is given for these three plates, so Pyro allocates one when
    # each plate is entered.
    user_plate = pyro.plate("user_plate", n_users)
    item_plate = pyro.plate("item_plate", n_items)
    comp_plate = pyro.plate("comp_plate", K)
    
    # NOTE(review): `ndata` and `sub_size` are not defined in this function;
    # presumably module-level globals with ndata == n_users -- confirm.
    # NOTE(review): the trace printed for this model shows site "y" with batch
    # shape [100, 100]. Poisson(λ[ind]).to_event(1) already has a batch
    # dimension of length sub_size, and a plate at dim=-2 broadcasts a second
    # one on top of it. dim=-1 looks like what was intended -- confirm.
    data_plate = pyro.plate("data_plate", size = ndata, subsample_size=sub_size, dim = -2)
    
    with user_plate:
        # One draw per user.
        ξ_u = pyro.sample("ξ_u", dist.Gamma(a_prime, a_prime/b_prime))
    
    with item_plate:
        # One draw per item.
        η_i = pyro.sample("η_i", dist.Gamma(c_prime, c_prime/d_prime))

    with comp_plate:
        # to_event(1) turns the user (resp. item) axis into an event dimension;
        # comp_plate then supplies the K batch dimension, so the printed trace
        # reports θ_uk as 40 | 999 and β_ik as 40 | 2000.
        θ_uk = pyro.sample("θ_uk", dist.Gamma(a, ξ_u).to_event(1))
        β_ik = pyro.sample("β_ik", dist.Gamma(c, η_i).to_event(1))

    with data_plate as ind:
        
        # (n_users, K) @ (K, n_items) -> one Poisson rate per (user, item) pair.
        λ = torch.mm(θ_uk.T, β_ik)
        # `ind` holds the subsampled indices; it selects the same rows from the
        # rate matrix and from the observations.
        pyro.sample("y", dist.Poisson(λ[ind]).to_event(1), obs = data[ind])
        
# Same scale factor as the decorator on `model`.
@poutine.scale(scale = 1e-4)
def custom_guide(data, args):
    """First attempt at a hand-written Gamma guide for ``model`` (raises under SVI).

    Samples the four latent sites of the model ("η_i", "ξ_u", "β_ik", "θ_uk")
    from Gamma distributions whose parameters are fixed tensors.

    Args:
        data: 2-D tensor; only ``data.shape`` is read, as ``(n_users, n_items)``.
        args: sequence ``(a_prime, b_prime, a, c_prime, d_prime, c, K)``;
            ``K`` must support ``.item()``.
    """
    
    n_users, n_items = data.shape
    
    a_prime, b_prime, a, c_prime, d_prime, c, K = args
    
    Kint = int(K.item())
    
    
    user_plate = pyro.plate("user_plate", n_users, dim = -1, use_cuda=True)
    item_plate = pyro.plate("item_plate", n_items, dim = -1, use_cuda = True)
    # NOTE(review): comp_plate is created with dim=-2 (the model leaves its dim
    # unspecified) and is never entered below.
    comp_plate = pyro.plate("comp_plate", K, dim = -2, use_cuda = True)        
    
    # These are plain constant tensors, not pyro.param sites, so an optimizer
    # has nothing to update in this guide.
    # NOTE(review): Gamma's first positional argument is the concentration
    # (shape), so the *_rates / *_shapes names look swapped below; both are 0.1
    # here, so the order makes no numerical difference -- confirm intent.
    β_rates, β_shapes  = 0.1*torch.ones((Kint, n_items)).to("cuda"), 0.1*torch.ones((Kint, n_items)).to("cuda")
    θ_rates, θ_shapes  = 0.1*torch.ones((Kint, n_users)).to("cuda"), 0.1*torch.ones((Kint, n_users)).to("cuda")
    
    
    # to_event(1) makes the last axis an event dimension, leaving a batch shape
    # of (Kint,) -- the "40 | 2000" and "40 | 999" rows of the guide trace.
    β_dist = dist.Gamma(β_rates, β_shapes).to_event(1)
    θ_dist = dist.Gamma(θ_rates, θ_shapes).to_event(1)
    
    # NOTE(review): the model draws η_i from (c_prime, d_prime) and ξ_u from
    # (a_prime, b_prime); the two hyperparameter pairs are swapped here.
    with item_plate:
        η_i = pyro.sample("η_i", dist.Gamma(a_prime, a_prime/b_prime))
           
    with user_plate:
        ξ_u = pyro.sample("ξ_u", dist.Gamma(c_prime, c_prime/d_prime))
    
    # Both sites are sampled outside any plate although their distributions
    # carry a batch dimension of length Kint. This is what the reported
    # 'invalid log_prob shape, Expected [], actual [40]' error complains about:
    # with no enclosing plate, Pyro expects an empty log_prob shape.
    β_ik = pyro.sample("β_ik", β_dist)
    θ_uk = pyro.sample("θ_uk", θ_dist)
    

If I replace my custom guide with an automatically generated guide such as AutoDiagonalNormal, it works fine. With my custom guide it does not, so I printed out trace.format_shapes() for both the model and the guide.

guide trace

  Trace Shapes:            
   Param Sites:            
  Sample Sites:            
user_plate dist      |     
          value  999 |     
       log_prob      |     
item_plate dist      |     
          value 2000 |     
       log_prob      |     
comp_plate dist      |     
          value   40 |     
       log_prob      |     
       η_i dist 2000 |     
          value 2000 |     
       log_prob 2000 |     
       ξ_u dist  999 |     
          value  999 |     
       log_prob  999 |     
      β_ik dist   40 | 2000
          value   40 | 2000
       log_prob   40 |     
      θ_uk dist   40 |  999
          value   40 |  999
       log_prob   40 |     

model trace

  Trace Shapes:                
   Param Sites:                
  Sample Sites:                
user_plate dist          |     
          value      999 |     
       log_prob          |     
item_plate dist          |     
          value     2000 |     
       log_prob          |     
comp_plate dist          |     
          value       40 |     
       log_prob          |     
data_plate dist          |     
          value      100 |     
       log_prob          |     
       ξ_u dist      999 |     
          value      999 |     
       log_prob      999 |     
       η_i dist     2000 |     
          value     2000 |     
       log_prob     2000 |     
      θ_uk dist       40 |  999
          value       40 |  999
       log_prob       40 |     
      β_ik dist       40 | 2000
          value       40 | 2000
       log_prob       40 |     
         y dist 100  100 | 2000
          value      100 | 2000
       log_prob 100  100 |     

Even though they seem to agree, I still get an error:

ValueError: at site "β_ik", invalid log_prob shape
  Expected [], actual [40]

This is kind of confusing. It would be really great if I could get some input on this.

Your guide has no pyro.param sites, so there is actually nothing for SVI to train. You can read the SVI tutorials to get the idea.

Ah, so I have to declare the variational parameters to be optimized with pyro.param. I forgot about that completely, since I always use an autoguide. Thanks, I will check it out.

# NOTE(review): the model shown earlier in this thread is decorated with
# scale=1e-4; model and guide are normally scaled by the same factor, so
# confirm that 1e-5 is intentional.
@poutine.scale(scale = 1e-5)
def custom_guide(data, args):
    """Mean-field guide with a learnable Gamma factor for every latent site.

    Registers eight parameters in the Pyro param store -- "ξ_rates",
    "ξ_shapes", "η_rates", "η_shapes", "β_rates", "β_shapes", "θ_rates",
    "θ_shapes" -- each initialised to ones on the "cuda" device and constrained
    to be greater than 0.01. It then samples "ξ_u", "η_i", "β_ik" and "θ_uk"
    from the corresponding Gamma distributions.

    Each ``*_shapes`` parameter is passed as the Gamma concentration and each
    ``*_rates`` parameter as the Gamma rate, so the names in the param store
    describe what the values actually are.

    Args:
        data: 2-D tensor; only ``data.shape`` is read, as ``(n_users, n_items)``.
        args: sequence ``(a_prime, b_prime, a, c_prime, d_prime, c, K)``. Only
            ``K`` is used here (it must support ``.item()``); the full unpack
            is kept so that a sequence of the wrong length still raises.
    """
    n_users, n_items = data.shape
    a_prime, b_prime, a, c_prime, d_prime, c, K = args
    Kint = int(K.item())

    # Lower bound keeps every concentration and rate strictly positive.
    lower_bounded = constraints.greater_than(0.01)

    def _learnable_gamma(symbol, *shape):
        # Build Gamma(concentration, rate) from two pyro.param tensors of ones.
        # Registration order (rates, then shapes) follows the original code.
        rates = pyro.param(
            symbol + "_rates",
            torch.ones(shape, device="cuda"),
            constraint=lower_bounded,
        )
        shapes = pyro.param(
            symbol + "_shapes",
            torch.ones(shape, device="cuda"),
            constraint=lower_bounded,
        )
        return dist.Gamma(shapes, rates)

    ## declare plates first ##
    user_plate = pyro.plate("user_plate", n_users, dim = -1)
    item_plate = pyro.plate("item_plate", n_items, dim = -1)
    comp_plate = pyro.plate("comp_plate", K, dim = -1)

    ## declare the variational parameters to be optimized ##
    # All parameters are created before any plate is entered, as in the
    # original, so no pyro.param statement runs inside a plate context.
    ξ_dist = _learnable_gamma("ξ", n_users)
    η_dist = _learnable_gamma("η", n_items)
    # to_event(1) makes the item/user axis an event dimension, leaving the
    # Kint axis as the batch dimension that comp_plate declares.
    β_dist = _learnable_gamma("β", Kint, n_items).to_event(1)
    θ_dist = _learnable_gamma("θ", Kint, n_users).to_event(1)

    ## declare the variational distributions for approximating the posterior ##
    with user_plate:
        pyro.sample("ξ_u", ξ_dist)

    with item_plate:
        pyro.sample("η_i", η_dist)

    with comp_plate:
        pyro.sample("β_ik", β_dist)
        pyro.sample("θ_uk", θ_dist)

    

This guide seems to work. Thanks!

But I have another question about plates and shapes.

# Minimal demo: how a plate at dim=-1 interacts with scalar parameters and
# with parameters that already have shape (n_users,).
n_users = 10
a_prime = torch.tensor(0.1, device="cuda")
b_prime = torch.tensor(0.2, device="cuda")

init_rates_ξ = torch.ones(n_users, device="cuda")
init_shapes_ξ = torch.ones(n_users, device="cuda")

ξ_rates  = pyro.param("ξ_rates_" , init_rates_ξ  ,  constraint =constraints.greater_than(0.01))
ξ_shapes = pyro.param("ξ_shapes_", init_shapes_ξ ,  constraint =constraints.greater_than(0.01))

# Both parameters have shape (10,).
print(ξ_rates.shape, ξ_shapes.shape)
with pyro.plate("user_plate", n_users, dim = -1):
    # Scalar parameters: the distribution has an empty batch shape, which the
    # plate expands to (10,).
    ξ_ui = pyro.sample("ξ_ui", dist.Gamma(a_prime, a_prime/b_prime))
    # Parameters of shape (10,): the batch shape is already (10,), which lines
    # up with the plate's dim=-1, so nothing is expanded.
    # Gamma takes (concentration, rate), so the shapes parameter goes first.
    ξ_u = pyro.sample("ξ_u", dist.Gamma(ξ_shapes, ξ_rates))

# Both samples therefore have shape (10,).
print(ξ_ui.shape, ξ_u.shape)

If we look at this code snippet, ξ_ui should have shape [10], since a_prime and b_prime are both scalars and get expanded inside the plate. But according to my understanding, ξ_u should have shape [10, 10]: the initial shape of ξ_rates and ξ_shapes is already [10], and they should also be expanded to the plate dimensions. However, both ξ_ui and ξ_u have shape [10]. Am I missing something?

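Because the distributions of ξ_ui and ξ_u both end up with a batch shape of [10] (the scalar parameters of ξ_ui are expanded by the plate, while the parameters of ξ_u already have shape [10]), user_plate simply declares this dim as independent, so the shape is still [10].
You could try setting the dim of user_plate to -2. In that case, the plate will expand the shape to [10, 10].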