Hi,
I want to simulate two agents communicate to each other using SVI, however my code gives me Keyerror:‘0_comm_type_0’, this is a sample name. I have been trapped here for a week, please help me. I have no idea what’s going on.
And this is my minimal code for reproducing the error:
import pyro
import torch
import logging
import pyro.distributions as dist
import pyro.distributions.constraints as constraints
from pyro.infer import config_enumerate, infer_discrete
class Agent:
def __init__(self, id):
self.id = id
@config_enumerate
def comm_model(self):
for round in pyro.markov(range(1)):
comm_type = pyro.sample('{}_comm_type_{}'.format(self.id, round), dist.Categorical(
probs=torch.ones(2)), infer={'enumerate': None})
if comm_type == torch.tensor(0):
comm = pyro.sample('{}_comm_state_{}'.format(self.id, round), dist.Categorical(
probs=torch.ones(2)))
else:
comm = pyro.sample(
'{}_comm_policy_{}'.format(self.id, round), dist.Categorical(probs=torch.ones(2)))
@config_enumerate(default='parallel')
def comm_guide(self):
for round in pyro.markov(range(1)):
comm_type_param = pyro.param(
'{}_comm_type_param_{}'.format(self.id, round), torch.ones(2), constraint=constraints.positive)
comm_type = pyro.sample('{}_comm_type_{}'.format(self.id, round), dist.Categorical(
probs=comm_type_param), infer={'enumerate': 'sequential'})
def comm(self):
adam = pyro.optim.Adam({'lr': 0.01})
elbo = pyro.infer.TraceEnum_ELBO(num_particles=1)
svi = pyro.infer.SVI(self.comm_model, self.comm_guide, adam, elbo)
for step in range(10):
loss = svi.step()
guide_trace = pyro.poutine.trace(self.comm_guide).get_trace()
trained_model = pyro.poutine.replay(self.comm_model, trace=guide_trace)
inferred_model = infer_discrete(trained_model, temperature=1,
first_available_dim=-4) # avoid conflict with data plate
trace = pyro.poutine.trace(inferred_model).get_trace()
agents = {}
for id in range(2):
agents[id] = Agent(id=id)
for id in range(2):
foci_agent = agents[id]
foci_agent.comm()
Thanks!