KeyError: model and guide site names do not match


#1

I have enclosed the error traceback below, along with two images of the model and guide functions. I don't understand what the KeyError means.

To give some background: my prior weights are drawn from a categorical distribution, and my variational posterior is also categorical.

I am passing these sampled weights to a neural network for Bayesian inference.


runfile('C:/Users/SRIKANTH R/.spyder-py3/pyro_bayesian_ternary.py', wdir='C:/Users/SRIKANTH R/.spyder-py3')

File "C:\Users\SRIKANTH R\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 705, in runfile
execfile(filename, namespace)

File "C:\Users\SRIKANTH R\Anaconda3\lib\site-packages\spyder\utils\site\sitecustomize.py", line 102, in execfile
exec(compile(f.read(), filename, 'exec'), namespace)

File "C:/Users/SRIKANTH R/.spyder-py3/pyro_bayesian_ternary.py", line 183, in <module>
instance.do_inference()

File "C:/Users/SRIKANTH R/.spyder-py3/pyro_bayesian_ternary.py", line 156, in do_inference
elbo=svi.step(images,labels)

File "C:\Users\SRIKANTH R\Anaconda3\lib\site-packages\pyro\infer\svi.py", line 99, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)

File "C:\Users\SRIKANTH R\Anaconda3\lib\site-packages\pyro\infer\trace_elbo.py", line 126, in loss_and_grads
loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(model_trace, guide_trace)

File "C:\Users\SRIKANTH R\Anaconda3\lib\site-packages\pyro\infer\trace_elbo.py", line 95, in _differentiable_loss_particle
log_r = _compute_log_r(model_trace, guide_trace)

File "C:\Users\SRIKANTH R\Anaconda3\lib\site-packages\pyro\infer\trace_elbo.py", line 21, in _compute_log_r
log_r_term = log_r_term - guide_trace.nodes[name]["log_prob"]

File "C:\Users\SRIKANTH R\Anaconda3\lib\site-packages\networkx\classes\reportviews.py", line 178, in __getitem__
return self._nodes[n]

KeyError: 'cn1_prior_dist_sample'
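The failing frame is `_compute_log_r`, whose line `guide_trace.nodes[name]["log_prob"]` looks up a model site's name in the guide trace. The following is a simplified, dependency-free illustration of that lookup, not Pyro's actual code: traces are modeled as plain dicts, and `compute_log_r` and `missing_in_guide` are hypothetical helpers written only to show why a model site name with no same-named guide site surfaces as a `KeyError`.

```python
import math


def compute_log_r(model_trace, guide_trace):
    """Toy version of the failing ELBO step (NOT Pyro's code).

    Traces are plain dicts mapping site name -> site info. For every latent
    (non-observed) model site, the guide's log-probability is looked up under
    the SAME name, so a name missing from the guide raises KeyError.
    """
    log_r = 0.0
    for name, site in model_trace.items():
        log_r += site["log_prob"]
        if not site["is_observed"]:
            log_r -= guide_trace[name]["log_prob"]
    return log_r


def missing_in_guide(model_trace, guide_trace):
    """Latent model sites that have no same-named site in the guide."""
    return {name for name, site in model_trace.items()
            if not site["is_observed"] and name not in guide_trace}


# Hypothetical traces mirroring the names in the traceback above.
model_trace = {
    "cn1_prior_dist_sample": {"log_prob": math.log(1 / 3), "is_observed": False},
    "obs": {"log_prob": -1.2, "is_observed": True},
}
guide_trace = {"cn1_dist_sample": {"log_prob": math.log(0.5)}}

try:
    compute_log_r(model_trace, guide_trace)
    missing = None
except KeyError as err:
    missing = err.args[0]          # the model site name the guide lacks

print(missing)                     # prints cn1_prior_dist_sample
print(missing_in_guide(model_trace, guide_trace))

# Renaming the guide site to match the model removes the error.
fixed_guide_trace = {"cn1_prior_dist_sample": {"log_prob": math.log(0.5)}}
fixed_log_r = compute_log_r(model_trace, fixed_guide_trace)
```

In other words, the KeyError is reporting the model site name that the lookup could not find in the guide.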


#2

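It looks like your model names the sample site cn1_prior_dist_sample, but your guide names it cn1_dist_sample. Trace_ELBO looks up each latent sample site of the model under the same name in the guide trace, so the mismatch shows up as a KeyError. You should set pyro.enable_validation(True) to catch these errors for you.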


#3

Thanks a lot, I will try it out.