Customize gaussian process classification link function

Hi All,

I am new to pyro and I am trying to use pyro for gpc. I’d like to use some customized link functions e.g. cloglog or GEV link but I didn’t find out how to do it from those tutorials. Is it possible to do so using pyro? Thanks.

Hi @joehu, welcome! You can specify the link function using response_function argument in MultiClass likelihood. Currently, it uses _softmax link function

def _softmax(x):
    return F.softmax(x, dim=-1)

Hi @fehiepsi thanks for your quick reply! It really helps! However, a follow-up question is that for the GEV link, there is an additional parameter r adjusting the skewness of the link i.e. P(Y_i=1 | f_i, r) = 1 - GEV(- f_i; r), where GEV represents the cdf of generalized extreme value distribution. When the parameter r is fixed, I can implement the gpc easily using the response_function, but how can I approximate the value of r using pyro?

You can define response_function with learning parameters as a PyTorch nn.Module or Pyro nn.Module (i.e. GEV is a nn.Module with a parameter r). See how we use a deep neural network CNN in dkl example or nn.Linear layers as mean functions in deep GP tutorial.

Hi @fehiepsi , I defined a GEV as a PyroModule like this:

from pyro.nn import PyroModule, PyroParam, PyroSample
import pyro.distributions as dist

class GEV_link(PyroModule):
    def __init__(self, prior_alpha, prior_beta):
        super().__init__()
        self.skew = PyroSample(dist.InverseGamma(prior_alpha, prior_beta))
    
    def forward(self, x):
        return torch.where(x<=torch.tensor(0.),torch.tensor(0.0000001), 1. - torch.exp(-(x**self.skew)))

gev_link = GEV_link(2,2)

And then I try to fit the gpc:


# Choose kernel and likelihood
kernel = gp.kernels.RBF(input_dim = N_FEATURES, variance = torch.tensor(1.),lengthscale = torch.tensor(10.))
likelihood = gp.likelihoods.Binary(response_function=gev_link)
#likelihood = gp.likelihoods.Binary()
gpc = gp.models.VariationalGP(X_train,y_train,kernel=kernel,jitter = 1e-03, likelihood=likelihood,whiten = 
                             True)

optim = Adam({"lr":0.001})
svi = SVI(gpc.model, gpc.guide, optim, loss=Trace_ELBO())
num_steps = 10000 
losses =np.zeros(num_steps)
pyro.clear_param_store()
start =time.time()
for i in range(num_steps):
    losses[i]=svi.step()
    if i %(num_steps//20) ==0:
        print("iteration %d. Loss %.4f" % (i,losses[i]))
        elapsed_time = time.time()
        print("elapsed time: %.2f" %(elapsed_time-start))
end = time.time()

It gives error:
ValueError: The parameter loc has invalid values
Trace Shapes:
Param Sites:
f_loc 800
f_scale_tril 800 800
Sample Sites:
likelihood.response_function.skew dist |
value |

Do I need to specify a guide function for the skewness parameter r and how? Thanks for your help.

You can either

  • use autoguide for the skewness parameter. I would recommend using this method unless you really want to customize the guide (rather than LogNormal that autoguide provides)
  • or wrap model, guide
class GEV_link(PyroModule):
    def __init__(self):
        super().__init__()
        self.skew = None
    ...

def model(...):
    skew = pyro.sample("skew", dist.InverseGamma(...))
    gpc.likelihood.response_function.skew = skew
    gpc.model()

def guide(...):
    skew = pyro.sample("skew", guide_for_skew)
    gpc.likelihood.response_function.skew = skew
    gpc.guide()

You can also use PyroParam instead of PyroSample for simplicity.

1 Like

Hi @fehiepsi , thanks for your reply and I really appreciate your help. I tried to wrap the model like your suggestion:

class GEV_link(PyroModule):
    def __init__(self):
        super().__init__()
        self.skew = None
    def forward(self, x):
        return torch.where(x<=torch.tensor(0.),torch.tensor(0.), 1. - torch.exp(-(x**self.skew)))
    
gev_link = Gev_link()

kernel = gp.kernels.RBF(input_dim = N_FEATURES, variance = torch.tensor(1.),lengthscale = torch.tensor(10.))
likelihood = gp.likelihoods.Binary(response_function=gev_link)
gpc = gp.models.VariationalGP(X_train,y_train,kernel=kernel,jitter = 1e-03, likelihood=likelihood,whiten = 
                             True)
    
def model(prior_alpha, prior_beta):
    skew = pyro.sample("skew", dist.InverseGamma(prior_alpha, prior_beta))
    gpc.likelihood.response_function.skew = skew
    gpc.model()

def guide():
    a = pyro.param("a", torch.tensor(2.))
    b = pyro.param("b", torch.tensor(1.))
    skew = pyro.sample("skew", dist.InverseGamma(a, b))
    gpc.likelihood.response_function.skew = skew
    gpc.guide()

gpc_model = model(2,2)
gpc_guide = guide()

# Inference
optim = Adam({"lr":0.001})
svi = SVI(gpc_model,gpc_guide,optim,loss=Trace_ELBO())
num_steps = 1000
losses =np.zeros(num_steps)
pyro.clear_param_store()
start =time.time()
for i in range(num_steps):
    losses[i]=svi.step()
    if i %(num_steps//20) ==0:
        print("iteration %d. Loss %.4f" % (i,losses[i]))
        elapsed_time = time.time()
        print("elapsed time: %.2f" %(elapsed_time-start))
end = time.time()
print("Loop take time %.2f"%(end-start))
plt.plot(losses)
plt.show()

But it still doesn’t work, error like this:

/Users/hu/Downloads/anaconda/anaconda3/lib/python3.8/site-packages/pyro/primitives.py:138: RuntimeWarning: trying to observe a value outside of inference at likelihood.y
  warnings.warn(
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-354-fe61d3e7b961> in <module>
     36 start =time.time()
     37 for i in range(num_steps):
---> 38     losses[i]=svi.step()
     39     if i %(num_steps//20) ==0:
     40         print("iteration %d. Loss %.4f" % (i,losses[i]))

~/Downloads/anaconda/anaconda3/lib/python3.8/site-packages/pyro/infer/svi.py in step(self, *args, **kwargs)
    143         # get loss and compute gradients
    144         with poutine.trace(param_only=True) as param_capture:
--> 145             loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
    146 
    147         params = set(

~/Downloads/anaconda/anaconda3/lib/python3.8/site-packages/pyro/infer/trace_elbo.py in loss_and_grads(self, model, guide, *args, **kwargs)
    138         loss = 0.0
    139         # grab a trace from the generator
--> 140         for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
    141             loss_particle, surrogate_loss_particle = self._differentiable_loss_particle(
    142                 model_trace, guide_trace

~/Downloads/anaconda/anaconda3/lib/python3.8/site-packages/pyro/infer/elbo.py in _get_traces(self, model, guide, args, kwargs)
    184         else:
    185             for i in range(self.num_particles):
--> 186                 yield self._get_trace(model, guide, args, kwargs)

~/Downloads/anaconda/anaconda3/lib/python3.8/site-packages/pyro/infer/trace_elbo.py in _get_trace(self, model, guide, args, kwargs)
     55         against it.
     56         """
---> 57         model_trace, guide_trace = get_importance_trace(
     58             "flat", self.max_plate_nesting, model, guide, args, kwargs
     59         )

~/Downloads/anaconda/anaconda3/lib/python3.8/site-packages/pyro/infer/enum.py in get_importance_trace(graph_type, max_plate_nesting, model, guide, args, kwargs, detach)
     53     if detach:
     54         guide_trace.detach_()
---> 55     model_trace = poutine.trace(
     56         poutine.replay(model, trace=guide_trace), graph_type=graph_type
     57     ).get_trace(*args, **kwargs)

~/Downloads/anaconda/anaconda3/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
    196         Calls this poutine and returns its trace instead of the function's return value.
    197         """
--> 198         self(*args, **kwargs)
    199         return self.msngr.get_trace()

~/Downloads/anaconda/anaconda3/lib/python3.8/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
    172             )
    173             try:
--> 174                 ret = self.fn(*args, **kwargs)
    175             except (ValueError, RuntimeError) as e:
    176                 exc_type, exc_value, traceback = sys.exc_info()

TypeError: __call__() missing 1 required positional argument: 'fn'

For your first approach, I am not sure if I understand it correctly, I tried like this:

gpc_guide = AutoGuide(gpc)

which also doesn’t work. I am wondering if the SVI can handle the situation with function torch.where ? Again, really appreciate your time and help.

The model and guide should have the same signature; and in svi.step(...) you need to provide arguments for model and guide (see SVI tutorial and the docs for svi.step). I’m not sure if there are other bugs.

SVI can handle the situation with function torch.where

I think it is fine. You can also use torch.clamp for that.

Hi @fehiepsi, thanks for your reply. I am really sorry that I still cannot figure out the implementation. I correct the signature issue like this: First define the response_function:

class GEV_link(PyroModule):
    def __init__(self):
        super().__init__()
        self.skew = None        
    def forward(self, x):
        return torch.where(x<=torch.tensor(0.),torch.tensor(0.), 1. - torch.exp(-(x**self.skew)))

link = GEV_link()

Then wrap the model:

def model():
    skew = pyro.sample("skew", dist.InverseGamma(2.,1.))
    gpc.likelihood.response_function.skew = skew
    gpc.model()

def guide():
    a = pyro.param("a", torch.tensor(2.), constraint=constraints.positive)
    b = pyro.param("b", torch.tensor(1.), constraint=constraints.positive)
    skew = pyro.sample("skew", dist.InverseGamma(a, b))
    gpc.likelihood.response_function.skew = skew
    gpc.guide()

After the first svi.step, the error occurs:

ValueError: The parameter concentration has invalid values

So I checked the pyro.param like this:

for name in pyro.get_param_store().get_all_param_names():
    print(name, pyro.param(name).data.numpy())

The output shows all parameters have nan values:

a nan
b nan
f_loc [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
 nan nan nan nan nan nan nan nan]
f_scale_tril [[nan  0.  0. ...  0.  0.  0.]
 [nan nan  0. ...  0.  0.  0.]
 [nan nan nan ...  0.  0.  0.]
 ...
 [nan nan nan ... nan  0.  0.]
 [nan nan nan ... nan nan  0.]
 [nan nan nan ... nan nan nan]]
kernel.lengthscale nan
kernel.variance nan

I am not sure why this happens, because when I fix the skew parameter in response_function like this:

def link(x , r=3):
    return torch.where(x<=torch.tensor(0.),torch.tensor(0.), 1. - torch.exp(-(x**r)))

The model works well:

iteration 0. Loss 3754.3865
elapsed time: 0.22
iteration 50. Loss 3709.1721
elapsed time: 9.23
iteration 100. Loss 3573.4878
elapsed time: 18.19
iteration 150. Loss 2879.5670
elapsed time: 27.15
iteration 200. Loss 2845.0823
elapsed time: 36.07

Hi @fehiepsi , I just found that when the skew parameter r is not integer the error also happens, even the response_function is fixed like this:

def link(x , r=3.8):
    return torch.where(x<=torch.tensor(0.),torch.tensor(0.), 1. - torch.exp(-(torch.pow(x, r))))

Although this link work reasonably:

link(torch.tensor([0.,0.1,0.5,0.9,2.]))

has result

tensor([0.0000e+00, 1.5849e-04, 6.9277e-02, 4.8833e-01, 1.0000e+00])

Do you have any idea why this response_function only works when skew parameter r is integer?

How about clamping x first?

x = torch.clamp(x, torch.finfo(x.dtype).eps)
return -torch.expm1(-torch.pow(x, r))

Hi @fehiepsi , now it works! Thanks a lot! I really appreciate your help!

Hi @fehiepsi , I tried to specify link skewness as a PyroParam like this and optimize with SVI:

class GEV_link(PyroModule):
    def __init__(self):
        super().__init__()
        self.skew = None
        
    def forward(self, x):
        x = torch.clamp(x, torch.finfo(x.dtype).eps)
        return -torch.expm1(-torch.pow(x, self.skew)) 

gpc.likelihood.response_function.skew = pyro.nn.PyroParam(torch.tensor(1.), constraint = constraints.positive)
svi = SVI(gpc.model,gpc.guide,optim,loss=Trace_ELBO())

which works and when I run

gpc.named_parameters()

it shows the skew as a parameter

likelihood.response_function.skew_unconstrained': Parameter containing:
tensor(0., requires_grad=True)}

However, when I tried to specify skewness using PyroSample:

gpc.likelihood.response_function.skew = pyro.nn.PyroSample(dist.LogNormal(0.0, 1.0))

It seems that the skewness parameter was not optimized and not inside gpc.named_paramters(). Is there something wrong with my code? How should I specify a prior for skewness parameter ? Thanks.

I think the optimized parameter will have name skew_map unless you define Normal autoguide (in this case, the parameters will be skew_loc, skew_scale_unconstrained). Could you double check?

Hi @fehiepsi , thanks for your reply, I specified the model like this:

kernel = gp.kernels.RBF(input_dim = 1, variance = torch.tensor(1.),lengthscale = torch.tensor(1.))
likelihood = gp.likelihoods.Binary(response_function=link)
gpc = gp.models.VariationalGP(X_train,y_train,kernel=kernel,jitter = 1e-03, likelihood=likelihood,whiten = 
                             True)
gpc.kernel.lengthscale = pyro.nn.PyroSample(dist.LogNormal(0.0, 1.0))
gpc.kernel.variance = pyro.nn.PyroSample(dist.LogNormal(0.0, 1.0))
gpc.likelihood.response_function.skew = pyro.nn.PyroSample(dist.LogNormal(0.0, 1.0))
svi = SVI(gpc.model,gpc.guide,optim,loss=Trace_ELBO())

The optimization can be done without error message, but the skew parameter doesn’t show up:

dict(gpc.named_parameters())

has output:

{'f_loc': Parameter containing:
 tensor([ 5.5759e-01,  8.6001e-01,  2.9717e-01,  2.7139e-01,  9.4732e-01,
         -6.3592e-01, -1.5857e-01,  9.3623e-01,  5.0435e-01, -1.9287e-02,
          ...
          7.9493e-02, -1.0671e-02, -1.4016e-03,  9.0551e-02,  1.1223e-01],
        requires_grad=True),
 'f_scale_tril_unconstrained': Parameter containing:
 tensor([[-9.8989e-01,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...
         [-1.5034e-02,  1.9787e-03,  4.2364e-03,  ...,  4.6530e-03,
           0.0000e+00,  0.0000e+00],
         [-2.7375e-02, -7.4072e-04, -1.7979e-02,  ..., -1.1488e-02,
          -1.0711e-02,  0.0000e+00],
         [-1.4268e-02, -3.7748e-02,  1.9061e-02,  ..., -2.3025e-02,
          -1.1589e-02,  5.2974e-03]], requires_grad=True),
 'kernel.lengthscale_map_unconstrained': Parameter containing:
 tensor(0.5843, requires_grad=True),
 'kernel.variance_map_unconstrained': Parameter containing:
 tensor(-0.5299, requires_grad=True)}

I also tried:

for name in pyro.get_param_store().get_all_param_names():
    print(name, pyro.param(name).data.numpy())

which doesn’t include skew parameter neither:

kernel.lengthscale_map 1.7936798
kernel.variance_map 0.5886754
f_loc [ 5.57592869e-01  8.60007167e-01  2.97169298e-01  2.71389902e-01
...
-1.83296129e-02 -1.74921509e-02 -1.50805945e-02  7.94932544e-02
 -1.06709376e-02 -1.40158192e-03  9.05509517e-02  1.12234637e-01]
f_scale_tril [[ 3.7161931e-01  0.0000000e+00  0.0000000e+00 ...  0.0000000e+00
   0.0000000e+00  0.0000000e+00]
 ...
 [-1.5033977e-02  1.9787024e-03  4.2363820e-03 ...  1.0046638e+00
   0.0000000e+00  0.0000000e+00]
 [-1.4267719e-02 -3.7748434e-02  1.9061258e-02 ... -2.3025490e-02
  -1.1588537e-02  1.0053115e+00]]

Oh, I see. Could you inherit GEV_link from pyro.contrib.gp.Parameterized instead of PyroModule? PyroModule does not support automatic autoguide, so you need to write the guide for skew parameter by yourself.

Thanks, it works!