How to integrate a Pyro model into C++

Is there a C++ API for Pyro? I have a model (several .p files) trained in Pyro, and I want to use it for prediction in C++. What should I do?

Thanks

Hi @sj2019, great question!

tl;dr: trace your trained model with torch.jit, torch.jit.save() it from Python, then torch::jit::load() it in C++.
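For reference, here is a minimal sketch of the save side using a plain PyTorch module (no Pyro yet). The `nn.Linear` network, the input shape, and the file name are placeholders for illustration. Loading the file back with `torch.jit.load()` in Python is a quick way to check the archive before pointing `torch::jit::load()` at it.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Stand-in for a trained network (hypothetical; substitute your own module).
net = nn.Linear(3, 1)
example_input = torch.randn(4, 3)

# Record the ops executed on example_input into a TorchScript module.
traced = torch.jit.trace(net, example_input)

# Serialize to a file; this archive is what torch::jit::load() reads in C++.
path = os.path.join(tempfile.mkdtemp(), "net.pt")
torch.jit.save(traced, path)

# Loading back in Python uses the same serialized format the C++ loader uses.
loaded = torch.jit.load(path)
with torch.no_grad():
    same = torch.allclose(loaded(example_input), net(example_input))
print(same)  # True
```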

My team is also exploring ways to serve Pyro models in C++. While I haven't used the torch.jit.save() / torch::jit::load() path myself, I've used torch.jit.trace() to speed up models. If you go this route, I'd recommend using PyTorch modules with pyro.module() rather than storing params directly in the param store via pyro.param(), because a module's parameters travel with it when you trace and torch.jit.save() it. A pattern I believe should work is the following (again, YMMV, since I haven't tried it):

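import torch
import torch.nn as nn
from pyro import poutine

class Model(nn.Module):
    ...
    def forward(self, data):
        ...

class Guide(nn.Module):
    ...
    def forward(self, data):
        ...

# torch.jit.ScriptModule is a base class, not a decorator, and torch.jit.script()
# can't compile Pyro's poutine handlers, so this most likely needs tracing instead.
class Predictor(nn.Module):
    def __init__(self, model, guide):
        super().__init__()
        self.model = model
        self.guide = guide

    def forward(self, data):
        # Sample latents from the guide, then replay the model with those values.
        tr = poutine.trace(self.guide).get_trace(data)
        return poutine.replay(self.model, trace=tr)(data)

# ...train using SVI...

# Finally, trace and save the trained predictor. check_trace=False because
# random sampling makes repeated runs differ.
predictor = Predictor(model, guide)
traced = torch.jit.trace(predictor, (data,), check_trace=False)
torch.jit.save(traced, "predictor.torch")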

Then I believe you can load your model with something like:

#include <torch/script.h>
auto predictor = torch::jit::load("predictor.torch");

Let me know how this goes. We could definitely use a tutorial on this, so we welcome feedback on anything you learn. I've added this question as a tutorial request: https://github.com/pyro-ppl/pyro/issues/2054

Cheers,
Fritz