How to model parametric trend + GP residuals?


#1

Hi @fehiepsi and all,

I’d like to model a time series as a parametric trend plus a nonparametric residual. Currently I’m using scipy.optimize.curve_fit to model the trend and pyro.contrib.gp to model the residuals with respect to this trend. Is there a way to do all of this in Pyro, say with syntax like

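import torch
import pyro
import pyro.contrib.gp as gp

residual_model = gp.models.GPRegression(...)  # <--- not quite right...

def model(x, y):
    # Parametric linear trend with learnable slope a and intercept b.
    a = pyro.param("a", torch.tensor(0.0))
    b = pyro.param("b", torch.tensor(1.0))
    trend = a * x + b
    residual = y - trend
    # Score the residuals under the GP.
    pyro.sample("residual", residual_model, obs=residual)  # <--- desired syntax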
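Numerically, the desired pyro.sample("residual", residual_model, obs=residual) would score the residual vector under the GP's marginal likelihood. A minimal NumPy sketch of that quantity, assuming GP regression with an RBF kernel and Gaussian observation noise (the helpers rbf_kernel, gp_log_marginal, and model_log_prob are hypothetical illustrations, not pyro.contrib.gp API):

```python
import numpy as np

def rbf_kernel(x1, x2, variance=1.0, lengthscale=1.0):
    """Squared-exponential (RBF) kernel matrix between 1-D inputs (hypothetical helper)."""
    sqdist = (x1[:, None] - x2[None, :]) ** 2
    return variance * np.exp(-0.5 * sqdist / lengthscale**2)

def gp_log_marginal(x, r, noise=0.1, variance=1.0, lengthscale=1.0):
    """log N(r | 0, K(x, x) + noise * I): GP marginal likelihood of residuals r."""
    n = x.shape[0]
    cov = rbf_kernel(x, x, variance, lengthscale) + noise * np.eye(n)
    chol = np.linalg.cholesky(cov)
    alpha = np.linalg.solve(chol, r)  # L^{-1} r, so alpha @ alpha = r^T cov^{-1} r
    return (-0.5 * alpha @ alpha
            - np.log(np.diag(chol)).sum()   # equals -0.5 * log det(cov)
            - 0.5 * n * np.log(2 * np.pi))

def model_log_prob(x, y, a, b, **gp_params):
    """Score y as a linear trend a * x + b plus a zero-mean GP residual."""
    trend = a * x + b
    residual = y - trend
    return gp_log_marginal(x, residual, **gp_params)

x = np.linspace(0.0, 5.0, 20)
y = 2.0 * x + 1.0  # noise-free trend, so the residual at the true (a, b) is zero
print(model_log_prob(x, y, a=2.0, b=1.0) > model_log_prob(x, y, a=0.0, b=0.0))  # True
```

Because the trend enters only through the residual, maximizing this score over (a, b) together with the kernel parameters is the joint fit the question asks for.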

#2

Hi @fritzo, this particular model could be handled if the GP models supported a mean_function (currently the GP module in Pyro assumes a zero mean function). I will implement it soon if that is all you need.
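Why a mean_function covers this case: a Gaussian density depends on the observations and the mean only through their difference. So a GP with mean function m(x) = a x + b gives y the same likelihood that a zero-mean GP gives the residuals y - m(x). Here K is the kernel matrix on x and sigma^2 the noise variance (notation assumed for this note):

$$
\log \mathcal{N}\big(y \mid m(x),\; K + \sigma^2 I\big) \;=\; \log \mathcal{N}\big(y - m(x) \mid 0,\; K + \sigma^2 I\big), \qquad m(x) = a\,x + b
$$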

To combine a GP model with another model (and train both at the same time), I think we can use

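import torch
import pyro

# residual_model is a pyro.contrib.gp model (e.g. GPRegression) constructed beforehand.

def model(x, y):
    a = pyro.param("a", torch.tensor(0.0))
    b = pyro.param("b", torch.tensor(1.0))
    trend = a * x + b
    residual = y - trend
    # Use the current residuals as the GP's training targets, then run its model.
    residual_model.set_data(x, residual)
    residual_model.model()

def guide(x, y):
    residual_model.guide()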
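For intuition about what joint training buys over fitting the trend first with curve_fit (ordinary least squares, which ignores correlation in the residuals), here is a sketch that holds the kernel hyperparameters fixed. That is an assumption: the SVI loop above would instead take gradient steps on a, b, and the kernel parameters together. With the covariance fixed, the (a, b) that maximize the GP marginal likelihood are the generalized least squares (GLS) estimate. This closed form is a swapped-in technique for illustration, not what pyro.contrib.gp does internally, and gls_trend is a hypothetical helper:

```python
import numpy as np

def gls_trend(x, y, cov):
    """Maximize log N(y | a*x + b, cov) over (a, b) for a fixed covariance.

    This is generalized least squares: beta = (D^T C^-1 D)^-1 D^T C^-1 y,
    where D = [x, 1] is the design matrix of the linear trend.
    """
    design = np.column_stack([x, np.ones_like(x)])
    cinv_design = np.linalg.solve(cov, design)  # C^-1 D
    cinv_y = np.linalg.solve(cov, y)            # C^-1 y
    beta = np.linalg.solve(design.T @ cinv_design, design.T @ cinv_y)
    return beta  # array([a, b])

# Synthetic data: linear trend plus a correlated (GP-like) residual.
x = np.linspace(0.0, 5.0, 30)
cov = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2 / 0.5**2) + 0.05 * np.eye(x.size)
rng = np.random.default_rng(0)
residual = rng.multivariate_normal(np.zeros(x.size), cov)
y = 2.0 * x + 1.0 + residual

a_hat, b_hat = gls_trend(x, y, cov)
print(a_hat, b_hat)
```

When the kernel parameters are also learned, no closed form remains, which is why optimizing everything through one (model, guide) pair is convenient.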

Regarding the syntax pyro.sample("residual", residual_model, obs=residual): do you want your GP model to be a distribution? I can’t find a good solution for that. We might be able to get close by separating the likelihood from .model(); the syntax would then look like

f = residual_model.model(x)
pyro.sample("residual", residual_model.likelihood(f), obs=residual)
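Separating the likelihood means the residual is generated in two steps: latent function values f ~ N(0, K), then residual ~ N(f, sigma^2 I). With a Gaussian likelihood, integrating out f gives back the zero-mean N(0, K + sigma^2 I) that the current single .model() call scores. A NumPy simulation sketch that checks this empirically (the RBF kernel and the noise value are arbitrary assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.array([0.0, 0.5, 1.0, 2.0])
kernel = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2)  # RBF, unit variance and lengthscale
noise = 0.1                                              # Gaussian likelihood variance

num_samples = 200_000
# Step 1: latent f ~ N(0, K), the role the proposed residual_model.model(x) would play.
f = rng.multivariate_normal(np.zeros(x.size), kernel, size=num_samples)
# Step 2: Gaussian likelihood, residual ~ N(f, noise * I).
residual = f + np.sqrt(noise) * rng.standard_normal(f.shape)

empirical_cov = np.cov(residual, rowvar=False)
expected_cov = kernel + noise * np.eye(x.size)
print(np.abs(empirical_cov - expected_cov).max())  # small; shrinks as num_samples grows
```

The split pays off when the likelihood is not Gaussian (e.g. counts), where f can no longer be integrated out in closed form.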

If you want it this way, I will think about how to redesign the current GP module.


#3

@fehiepsi your (model, guide) pair looks great, as long as I can call .set_data() at each SVI.step() and gradients are correctly propagated through .set_data(). Can you confirm that this is the case?


#4

Hi @fritzo, I made a simple gist based on your model. It seems that the model is able to learn.

Could you give me some ideas to check if gradients are correctly propagated?

Ah, you need to use the Pyro dev version; batch training support was added only recently.
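On checking whether gradients are correctly propagated: a standard test is to compare the gradient the framework reports against a central finite-difference estimate of the same loss. Below is a NumPy sketch for the trend-plus-GP loss with fixed kernel parameters (an assumption). A hand-derived gradient stands in for what autograd would return; in PyTorch, torch.autograd.gradcheck automates this comparison for double-precision inputs. The helper names are hypothetical:

```python
import numpy as np

def neg_log_lik(params, x, y, cov):
    """0.5 * r^T C^-1 r with r = y - (a*x + b); constant terms dropped."""
    a, b = params
    r = y - (a * x + b)
    return 0.5 * r @ np.linalg.solve(cov, r)

def analytic_grad(params, x, y, cov):
    """Hand-derived gradient (stand-in for autograd): dL/da = -x^T C^-1 r, dL/db = -1^T C^-1 r."""
    a, b = params
    r = y - (a * x + b)
    cinv_r = np.linalg.solve(cov, r)
    return np.array([-x @ cinv_r, -cinv_r.sum()])

def finite_diff_grad(f, params, eps=1e-6):
    """Central finite differences, one coordinate at a time."""
    grad = np.zeros_like(params)
    for i in range(params.size):
        step = np.zeros_like(params)
        step[i] = eps
        grad[i] = (f(params + step) - f(params - step)) / (2 * eps)
    return grad

x = np.linspace(0.0, 3.0, 10)
y = 1.5 * x - 0.5 + np.cos(2 * x)
cov = np.exp(-0.5 * (x[:, None] - x[None, :]) ** 2) + 0.1 * np.eye(x.size)
params = np.array([0.3, 0.7])

g_analytic = analytic_grad(params, x, y, cov)
g_numeric = finite_diff_grad(lambda p: neg_log_lik(p, x, y, cov), params)
print(np.allclose(g_analytic, g_numeric, rtol=1e-4, atol=1e-6))  # True
```

In the Pyro setting, the analogous check compares a.grad and b.grad after loss.backward() with finite differences of the ELBO loss; if set_data() broke the graph, a.grad would be missing or zero while the finite differences are not.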