Hi~
My code behaves strangely, but I have no idea where the error comes from. Please help me.
The code is as follows:
If you run the code, it raises an error: `Size of label 'd' for operand 1 (4) does not match previous terms (2)`. However, if you comment out the `prompt` variable, it raises an `AssertionError` instead.
I guess there is a conflict between the two `pyro.markov` loops, but why?
Thanks~
# %%
import logging

import pyro
import pyro.distributions as dist
import torch
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate, infer_discrete, JitTraceEnum_ELBO
from pyro.ops.indexing import Vindex
class Agent:
    """An agent whose discrete policy/communication model is fit with SVI.

    Every latent site in the model is discrete (Categorical) and is enumerated
    in the model itself via ``@config_enumerate`` + ``TraceEnum_ELBO``, which is
    why the guide is empty.
    """

    def __init__(self,
                 id
                 ):
        """Create an agent.

        :param id: identifier used as a prefix for this agent's sample-site names.
        """
        self.id = id

    @config_enumerate
    def generative_model(self):
        """Sample a policy per step ``T``, then per-round prompt/comm variables.

        ``pyro.markov`` recycles enumeration dimensions once a site leaves its
        Markov window. The policy sampled in the ``T`` context is therefore only
        safe to use *inside* that context. The ``round`` loop is nested here so
        that ``comm`` can depend on it. Consuming the policy after the ``T``
        context exits lets ``prompt``/``comm_type`` reuse its dimension, which
        yields errors such as
        "Size of label ... does not match previous terms".
        """
        for T in pyro.markov(range(1)):
            cor_level_policy = pyro.sample(
                '{}_cor_level_policy_{}'.format(self.id, T),
                dist.Categorical(probs=torch.ones(4)))
            # Namespaced like the other sites so the name stays unique per
            # agent and per step (a bare 'current_policy' would collide as
            # soon as the T loop runs more than once).
            cp = pyro.deterministic(
                '{}_current_policy_{}'.format(self.id, T), cor_level_policy)
            # Nested inside the T markov context on purpose: keeps `cp`'s
            # enumeration dim reserved while the sites below are enumerated.
            # NOTE(review): if the T loop is extended past one step, include T
            # in these site names as well to keep them unique.
            for rnd in pyro.markov(range(1)):
                pyro.sample('{}_prompt_{}'.format(self.id, rnd),
                            dist.Categorical(probs=torch.ones(2)))
                pyro.sample('{}_comm_type_{}'.format(self.id, rnd),
                            dist.Categorical(probs=torch.ones(2)))
                # Row `cp` of the 4x4 identity: comm deterministically copies
                # the current policy; Vindex keeps the enumerated batch dims.
                pyro.sample('{}_comm_{}'.format(self.id, rnd),
                            dist.Categorical(probs=Vindex(torch.eye(4))[cp]))

    def guide(self):
        """Empty guide: all latent sites are enumerated in the model."""
        pass

    def comm(self, num_steps=10, lr=0.01):
        """Run SVI on the model and return the per-step ELBO losses.

        :param num_steps: number of SVI steps (default 10).
        :param lr: Adam learning rate (default 0.01).
        :returns: list of the losses returned by ``svi.step()``, in order.
        """
        adam = pyro.optim.Adam({'lr': lr})
        # The model has no pyro.plate, so state that instead of letting the
        # ELBO guess it by tracing the model.
        elbo = TraceEnum_ELBO(max_plate_nesting=0, num_particles=1)
        svi = SVI(self.generative_model, self.guide, adam, elbo)
        losses = []
        for _ in range(num_steps):
            loss = svi.step()
            logging.info("Elbo loss: %s", loss)
            losses.append(loss)
        return losses
if __name__ == "__main__":
    # logging defaults to WARNING; without this the per-step "Elbo loss"
    # messages emitted via logging.info inside Agent.comm are never shown.
    logging.basicConfig(level=logging.INFO)
    # Build every agent first, then train them in id order.
    agents = {agent_id: Agent(id=agent_id) for agent_id in range(2)}
    for foci_agent in agents.values():
        foci_agent.comm()