Custom distribution for MCMC model

Hey,

I am trying to implement the model from this paper in Pyro. I have a working PyMC3 model (code at the bottom); my Pyro code is below. Without the JIT compiler the code seems to work but is very slow, about 2 samples a second on my machine. With the JIT compiler turned on, there are a bunch of warnings and the model does not converge. Any help would be most appreciated. Thanks.

Warnings
0/2000 [00:00, ?it/s]/Users/ryany/opt/anaconda3/envs/pryo2/lib/python3.7/site-packages/ipykernel_launcher.py:28: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can’t record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
/Users/ryany/opt/anaconda3/envs/pryo2/lib/python3.7/site-packages/ipykernel_launcher.py:33: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
/Users/ryany/opt/anaconda3/envs/pryo2/lib/python3.7/site-packages/ipykernel_launcher.py:33: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can’t record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
/Users/ryany/opt/anaconda3/envs/pryo2/lib/python3.7/site-packages/ipykernel_launcher.py:34: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
/Users/ryany/opt/anaconda3/envs/pryo2/lib/python3.7/site-packages/ipykernel_launcher.py:34: TracerWarning: Converting a tensor to a Python float might cause the trace to be incorrect. We can’t record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!

Pyro code:

import torch
import pyro
import pyro.distributions as dist
import pyro.infer.mcmc as mcmc
from pyro.infer import MCMC, NUTS


dat = torch.tensor([[100.,  63.1,  46.8,  38.2,  32.6,  28.9,  26.2,  24.1],
                    [ -1.,  36.9,  16.3,   8.6,   5.6,   3.7,   2.7,   2.1]])

class sBG(pyro.distributions.Distribution):
    def __init__(self, alpha, beta):
        self.alpha = alpha
        self.beta = beta
        
    def sample(self, sample_shape=torch.Size()):
        # dummy value; sample() is never called when the site is observed
        return torch.tensor(1.)
    
    def log_prob(self, data):
        active = data[0,:]
        lost = data[1,:]
        
        n = active.shape[0]

        # recursion defined in equation 7 of the paper
        p = [0., self.alpha / (self.alpha + self.beta)]
        s = [0., 1 - p[1]]
        for t in range(2, n):
            pt = ((self.beta + t - 2) / (self.alpha + self.beta + t - 1)) * p[t-1]
            p.append(pt)
            s.append(s[t-1] - p[t])
            
        p = torch.tensor(p)
        s = torch.tensor(s)
        
        # those who churned along the way...
        died = torch.mul(torch.log(p[1:]), lost[1:])
        # ...and those still active in the last period
        still_active = torch.log(s[-1]) * active[-1]
        logp = torch.sum(died) + still_active
        return logp 
        
def model(data):
    alpha = pyro.sample('alpha', dist.Uniform(0, 10))
    beta = pyro.sample('beta' , dist.Uniform(0, 10))    
    return pyro.sample('p', sBG(alpha, beta), obs=data)
    
nuts_kernel = NUTS(model, jit_compile=False)
mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=10)
mcmc.run(dat)

PyMC3 code (works):

import pymc3 as pm
import theano.tensor as tt

def sBG_model(data):
    num_treatments = data['active'].shape[0]
    n = len(data['active'])

    with pm.Model() as model:

        # uninformative priors
        alpha = pm.Uniform('alpha', 0.00001, 10.0,  testval=1)
        beta = pm.Uniform('beta', 0.00001, 10.0, testval=1)

        ## defined in equation 7 in the paper
        p = [0., alpha / (alpha + beta)]
        s = [0., 1 - p[1]]
        for t in range(2, n):
            pt = ((beta + t - 2) / (alpha + beta + t - 1)) * p[t-1]
            p.append(pt)
            s.append(s[t-1] - p[t])

        # theano type conversion
        p = tt.stack(p)
        s = tt.stack(s)

        def logp(active, lost):
            # Those who've churned along the way...
            died = tt.mul(tt.log(p[1:]), lost[1:])

            # and those still active in last period
            still_active = tt.log(s[-1]) * active[-1]
            return  tt.sum(died) + still_active

        retention = pm.DensityDist('retention', logp, observed=data)

    return model

@youngre you would need to use torch.stack instead of torch.tensor (similar to the Theano code). For that, you also need to define the initial values of p and s as torch.tensor(0.).
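
Concretely, the corrected log_prob would look something like this (a sketch of your method with just those two changes applied, everything else unchanged):

def log_prob(self, data):
    active = data[0, :]
    lost = data[1, :]
    n = active.shape[0]

    # initialize with tensors, not Python floats, so torch.stack works
    p = [torch.tensor(0.), self.alpha / (self.alpha + self.beta)]
    s = [torch.tensor(0.), 1 - p[1]]
    for t in range(2, n):
        pt = ((self.beta + t - 2) / (self.alpha + self.beta + t - 1)) * p[t-1]
        p.append(pt)
        s.append(s[t-1] - p[t])

    # torch.stack keeps alpha/beta in the autograd graph and the JIT trace,
    # unlike torch.tensor, which converts each entry to a constant
    p = torch.stack(p)
    s = torch.stack(s)

    died = torch.mul(torch.log(p[1:]), lost[1:])
    still_active = torch.log(s[-1]) * active[-1]
    return torch.sum(died) + still_active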

To silence the JIT warnings, you can pass jit_compile=True, ignore_jit_warnings=True. The data is small, so it should not be that slow (2 samples per second). Could you try the above fixes and let me know?
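
For example:

nuts_kernel = NUTS(model, jit_compile=True, ignore_jit_warnings=True)
mcmc = MCMC(nuts_kernel, num_samples=100, warmup_steps=10)
mcmc.run(dat)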

If the data is larger, you will want to use NumPyro. It would be very fast for this model. 🙂
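
If you do go the NumPyro route, a rough translation could look like the sketch below (untested; numpyro.factor is one way to add a custom log-likelihood term, and the model name and MCMC settings here are just illustrative):

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as ndist
from numpyro.infer import MCMC, NUTS

def sbg_model(active, lost):
    alpha = numpyro.sample('alpha', ndist.Uniform(0.00001, 10.0))
    beta = numpyro.sample('beta', ndist.Uniform(0.00001, 10.0))

    # same equation-7 recursion as above
    n = active.shape[0]
    p = [0., alpha / (alpha + beta)]
    s = [0., 1 - p[1]]
    for t in range(2, n):
        p.append(((beta + t - 2) / (alpha + beta + t - 1)) * p[t-1])
        s.append(s[t-1] - p[t])
    p = jnp.stack(p)
    s = jnp.stack(s)

    # add the custom log-likelihood directly to the joint density
    logp = jnp.sum(jnp.log(p[1:]) * lost[1:]) + jnp.log(s[-1]) * active[-1]
    numpyro.factor('retention', logp)

mcmc = MCMC(NUTS(sbg_model), num_warmup=500, num_samples=1000)
mcmc.run(jax.random.PRNGKey(0),
         active=jnp.asarray(dat[0].numpy()),
         lost=jnp.asarray(dat[1].numpy()))
mcmc.print_summary()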

That works. Thanks for the help!