Returning values from the guide

Hello All,

I am working with the Deep Markov Model (DMM).

After the model is trained and evaluated, I would like to get the final learned z values (the latent states) for both the training and the test data back from the guide. What is the correct way to do this?

Things that I have tried:

I tried returning the values from the guide using the following code.

Inside the `for` loop, at the end of each time step:

```python
z_prev = z_t
if epoch == args.num_epochs - 1 and trainortest == 'train':
    my_train_store[t] = z_prev
if epoch == args.num_epochs - 1 and trainortest == 'test':
    my_test_store[t] = z_prev
```

After the `for` loop:

```python
if epoch == args.num_epochs - 1:
    if trainortest == 'test':
        return my_test_store
    elif trainortest == 'train':
        return my_train_store
```

But because the model and the guide are never called directly (they are invoked from `svi.step` and `svi.evaluate_loss`), I have no way to receive the guide's return value.

I also tried extending `pyro.infer.svi`. In `svi.step`, I added `mystore` to receive the stored values from the guide:

```python
def step(self, *args, **kwargs):
    with poutine.trace(param_only=True) as param_capture:
        loss, mystore = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
```

But I get the following error:

```
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-114-c4c1e8976169> in <module>()
     30     args = parser.parse_args()
     31 
---> 32     main(args)

...
<ipython-input-106-af5ff8d16d55> in step(self, *args, **kwargs)
     76         with poutine.trace(param_only=True) as param_capture:
---> 77             loss, mystore = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
     78 
     79         params = set(site["value"].unconstrained()

TypeError: cannot unpack non-iterable float object
```

How can I return an extra value from the `self.loss_and_grads` function?

Or is there an easier way to get the learned z values for the training and test data?

Thanks in advance.

Why don't you just call the guide directly? It is a Python function.