Strange KeyError When Using infer_discrete

Hi,

I want to simulate two agents communicating with each other using SVI, but my code raises KeyError: '0_comm_type_0', which is one of my sample site names. I have been stuck on this for a week and have no idea what's going on, so please help me.

Here is my minimal code for reproducing the error:

import pyro
import torch
import logging
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
from pyro.infer import config_enumerate, infer_discrete

class Agent:
    def __init__(self, id):
        self.id = id

    @config_enumerate
    def comm_model(self):
        for round in pyro.markov(range(1)):
            comm_type = pyro.sample('{}_comm_type_{}'.format(self.id, round), dist.Categorical(
                probs=torch.ones(2)), infer={'enumerate': None})
            if comm_type == torch.tensor(0):  
                comm = pyro.sample('{}_comm_state_{}'.format(self.id, round), dist.Categorical(
                    probs=torch.ones(2)))
            else:
                comm = pyro.sample(
                    '{}_comm_policy_{}'.format(self.id, round), dist.Categorical(probs=torch.ones(2)))

    @config_enumerate(default='parallel')
    def comm_guide(self):
        for round in pyro.markov(range(1)):
            comm_type_param = pyro.param(
                '{}_comm_type_param_{}'.format(self.id, round), torch.ones(2), constraint=constraints.positive)
            comm_type = pyro.sample('{}_comm_type_{}'.format(self.id, round), dist.Categorical(
                    probs=comm_type_param), infer={'enumerate': 'sequential'})

    def comm(self):
        adam = pyro.optim.Adam({'lr': 0.01})
        elbo = pyro.infer.TraceEnum_ELBO(num_particles=1)
        svi = pyro.infer.SVI(self.comm_model, self.comm_guide, adam, elbo)
        for step in range(10):
            loss = svi.step()

        guide_trace = pyro.poutine.trace(self.comm_guide).get_trace()
        trained_model = pyro.poutine.replay(self.comm_model, trace=guide_trace) 
        inferred_model = infer_discrete(trained_model, temperature=1,
                                        first_available_dim=-4)  # avoid conflict with data plate
        trace = pyro.poutine.trace(inferred_model).get_trace()

agents = {}
for id in range(2):
    agents[id] = Agent(id=id)
for id in range(2):
    foci_agent = agents[id]
    foci_agent.comm()

Thanks!

I did some research and found that if I comment out the if-else block, the error goes away.
I believe the problem is caused by using the comm_type variable to determine the downstream model structure. However, as the tutorial instructs, I'm using sequential enumeration for this variable on the guide side, which should work.
So what have I misunderstood?
I really need some help, thanks~

Hi @yaow, I see two issues with your model:

  1. There are no observation nodes, so the model is simply a prior. To do Bayesian inference you'll need a likelihood, i.e. Pyro observe statements that look like
    pyro.sample("something", SomeDist(...), obs=some_data)
    
  2. Dynamic model structure is OK (e.g. your if-else block), but the guide should reflect the same dynamic structure as the model, so you'll need a corresponding if-else block in the guide, something like
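@config_enumerate(default='parallel')
def comm_guide(self):
    for round in pyro.markov(range(1)):
        comm_type_param = pyro.param(
            f'{self.id}_comm_type_param_{round}',
            torch.ones(2),
            constraint=constraints.positive,
        )
        comm_type = pyro.sample(
            f'{self.id}_comm_type_{round}',
            dist.Categorical(comm_type_param),
            infer={'enumerate': 'sequential'},
        )
        # Mirror the model's branch; the sample site names must match the model's exactly.
        # The *_param_* names below are illustrative learnable parameters.
        if comm_type == 0:
            state_param = pyro.param(
                f'{self.id}_comm_state_param_{round}',
                torch.ones(2),
                constraint=constraints.positive,
            )
            comm = pyro.sample(
                f'{self.id}_comm_state_{round}',
                dist.Categorical(state_param),
            )
        else:
            policy_param = pyro.param(
                f'{self.id}_comm_policy_param_{round}',
                torch.ones(2),
                constraint=constraints.positive,
            )
            comm = pyro.sample(
                f'{self.id}_comm_policy_{round}',
                dist.Categorical(policy_param),
            )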

Thanks for the kind reply.
1. My original model does have some observed sample sites. I removed them one by one while building this minimal example and the error was still there, so I guess they are not the cause of the error. But thanks for the reminder~
2. As I understand it, the sample sites inside the if-else block all have discrete distributions and are enumerated out on the model side, which makes corresponding guide-side sample sites unnecessary, so I don't see what I did wrong. But your code does work. Could you please explain a little more? Thanks!