Customize Gaussian process classification link function

Hi All,

I am new to Pyro and I am trying to use it for Gaussian process classification (GPC). I'd like to use a customized link function, e.g. cloglog or a GEV link, but I couldn't find out how to do this from the tutorials. Is it possible to do so using Pyro? Thanks.

Hi @joehu, welcome! You can specify the link function using the response_function argument of the MultiClass likelihood. Currently, it uses the _softmax link function:

def _softmax(x):
    return F.softmax(x, dim=-1)

You can define response_function with learnable parameters as a PyTorch nn.Module or Pyro nn.Module (e.g. GEV is an nn.Module with a parameter r). See how we use a deep neural network (CNN) in the DKL example, or nn.Linear layers as mean functions in the deep GP tutorial.
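To make the idea of a custom response function more concrete, here is a minimal sketch in plain PyTorch: a parameter-free cloglog inverse link and a GEV-style nn.Module with one learnable shape parameter. The names cloglog_response, GEVResponse and xi are made up for this example, and the GEV form used, p = 1 - exp(-(1 - xi*f)^(-1/xi)), is one common parameterization that is assumed here, not something taken from Pyro. Both are binary links, so they would presumably be passed as response_function to a binary likelihood and not to the softmax-based MultiClass.

```python
import torch
import torch.nn as nn


def cloglog_response(f):
    """Inverse complementary log-log link: p = 1 - exp(-exp(f))."""
    # -expm1(-u) computes 1 - exp(-u) accurately for small u.
    return -torch.expm1(-torch.exp(f))


class GEVResponse(nn.Module):
    """GEV-style inverse link with a learnable shape xi (assumed form)."""

    def __init__(self, xi=-0.2):
        super().__init__()
        # Hypothetical parameter name; xi must be nonzero in this form.
        self.xi = nn.Parameter(torch.tensor(float(xi)))

    def forward(self, f):
        # Clamp the base so that the fractional power is never taken of a
        # non-positive number (which would produce nan).
        base = torch.clamp(1.0 - self.xi * f, min=torch.finfo(f.dtype).eps)
        return -torch.expm1(-torch.pow(base, -1.0 / self.xi))


f = torch.linspace(-2.0, 2.0, 5, dtype=torch.float64)
link = GEVResponse(xi=-0.2).double()
p = link(f)          # values in [0, 1], increasing in f
p.sum().backward()   # the gradient reaches the learnable parameter link.xi
```

As xi approaches 0, this GEV form approaches the cloglog function, which gives a quick sanity check on the formula.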

You can either

  • use autoguide for the skewness parameter. I would recommend this method unless you really want to customize the guide (rather than use the LogNormal that autoguide provides)
  • or wrap the model and guide:
class GEV_link(PyroModule):
    def __init__(self):
        super().__init__()
        self.skew = None
    ...

def model(...):
    skew = pyro.sample("skew", dist.InverseGamma(...))
    gpc.likelihood.response_function.skew = skew
    gpc.model()

def guide(...):
    skew = pyro.sample("skew", guide_for_skew)
    gpc.likelihood.response_function.skew = skew
    gpc.guide()

You can also use PyroParam instead of PyroSample for simplicity.

The model and guide should have the same signature, and in svi.step(...) you need to provide the arguments for the model and guide (see the SVI tutorial and the docs for svi.step). I'm not sure if there are other bugs.

Can SVI handle the situation with the function torch.where?

I think it is fine. You can also use torch.clamp for that.

How about clamping x first?

x = torch.clamp(x, torch.finfo(x.dtype).eps)
return -torch.expm1(-torch.pow(x, r))
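For context, the two lines above could be wrapped into a complete function roughly as follows. This is only a sketch: the function name clamped_power_link and the example values of x and r are assumptions, while the body mirrors the snippet. Without the clamp, torch.pow of a negative x with a fractional r returns nan, so clamping first keeps both the output and the gradient with respect to r finite.

```python
import torch


def clamped_power_link(x, r):
    """Return 1 - exp(-x**r), clamping x to a small positive value first."""
    # torch.finfo(...).eps is the machine epsilon of x's dtype; it serves as
    # the lower bound so that x**r is well defined for any real exponent r.
    x = torch.clamp(x, torch.finfo(x.dtype).eps)
    # -expm1(-u) is a numerically accurate way to compute 1 - exp(-u).
    return -torch.expm1(-torch.pow(x, r))


x = torch.tensor([-1.0, 0.0, 0.5, 2.0])     # includes non-positive inputs
r = torch.tensor(0.5, requires_grad=True)   # example exponent, chosen arbitrarily
out = clamped_power_link(x, r)
out.sum().backward()
print(out, r.grad)                          # all entries are finite
```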

Hi @fehiepsi , now it works! Thanks a lot! I really appreciate your help!

I think the optimized parameter will have the name skew_map unless you define a Normal autoguide (in which case the parameters will be skew_loc and skew_scale_unconstrained). Could you double check?

Oh, I see. Could you inherit GEV_link from pyro.contrib.gp.Parameterized instead of PyroModule? PyroModule does not support autoguide, so with it you need to write the guide for the skew parameter yourself.

Thanks, it works!