Modeling a Gamma distributed censored event

I have some variable Y ~ Gamma(shape, rate).

However, my data is right-censored: there is some threshold value of Y (say 1.0) such that, if the true value is higher, the recorded observation is 1.0. I want to model Y as a function of a linear combination of some features, X.
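For reference, here is the log-likelihood that right-censoring implies. This is a sketch that assumes the log link and the parametrization from the code (mean = shape / rate). Here b is the rate, a_i is the shape of row i, c is the censoring threshold (1.0 here), δ_i = 1 marks a censored row, f and F are the Gamma density and CDF, and Γ(a, ·) is the upper incomplete gamma function:

```latex
\mu_i = \exp\left(\beta_0 + x_i^\top \beta\right), \qquad a_i = \mu_i \, b

\log L(\beta_0, \beta, b) = \sum_{i:\,\delta_i = 0} \log f(y_i \mid a_i, b) \;+\; \sum_{i:\,\delta_i = 1} \log\bigl(1 - F(c \mid a_i, b)\bigr)

1 - F(c \mid a, b) = \frac{\Gamma(a,\, b\,c)}{\Gamma(a)}
```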

The model I’m trying to use looks like this:

```python
import torch
import pyro
import pyro.distributions as dist
from scipy.stats import gamma as scipygamma


def model_gamma_cen(X, y, censored_label):
    min_value = torch.finfo(X.dtype).eps
    max_p_value = 1.0 - torch.finfo(X.dtype).eps
    max_value = torch.finfo(X.dtype).max / 100.0

    # Prior on the intercept
    intercept_prior = dist.Normal(0.0, 1.0)
    linear_combination = pyro.sample("beta_intercept", intercept_prior)

    # Also define coefficient priors
    coefficients_prior = dist.Normal(0.0, 1.0).expand([X.shape[1]])
    betas = pyro.sample("beta_coefficients", coefficients_prior)

    # Finally, calculate the linear combination of parameters and X values
    linear_combination = linear_combination + torch.matmul(X, betas)

    # Log link: log(mean) = linear combination of X, so mean = exp(linear combination)
    mean = torch.exp(linear_combination).clamp(min=min_value, max=max_value)

    # We will also define a rate parameter
    # (clamp the sampled value: a distribution object has no .clamp method)
    rate = pyro.sample("rate", dist.HalfCauchy(scale=10.0)).clamp(min=min_value)

    # Since mean = shape / rate, shape = mean * rate
    shape = mean * rate

    # The data is now divided into two outcomes: censored and non-censored
    with pyro.plate("data", y.shape[0]):
        # If the data is not censored, we just observe the y value,
        # which is dictated by the Gamma distribution
        with pyro.poutine.mask(mask=(censored_label == 0.0)):
            observation = pyro.sample("obs", dist.Gamma(shape, rate), obs=y)

        # If the data is censored, we need 1 - CDF of the Gamma distribution;
        # we use the CDF implementation in scipy.stats.gamma
        with pyro.poutine.mask(mask=(censored_label == 1.0)):
            scipy_gamma_dist = scipygamma(a=shape.detach().numpy(), scale=1 / rate.detach().numpy())

            # Find the truncation probability P(Y > y); cast back to the model's dtype
            truncation_probability_np = 1.0 - scipy_gamma_dist.cdf(y.detach().numpy())
            truncation_prob = torch.tensor(truncation_probability_np, dtype=X.dtype).clamp(min=min_value, max=max_p_value)

            censored_observation = pyro.sample(
                "censorship", dist.Bernoulli(truncation_prob), obs=torch.ones_like(truncation_prob)
            )
```
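Here is a quick sanity check of the mean/rate parametrization used in the model. With shape = mean * rate, PyTorch’s Gamma(concentration, rate) should have the intended mean and a variance of mean / rate. This is a minimal sketch with made-up numbers; pyro.distributions.Gamma wraps torch.distributions.Gamma.

```python
import torch

mean = torch.tensor([0.5, 2.0, 10.0])  # made-up means, e.g. exp(linear combination)
rate = torch.tensor(3.0)               # made-up shared rate

shape = mean * rate                    # since mean = shape / rate
d = torch.distributions.Gamma(concentration=shape, rate=rate)

print(d.mean)      # matches `mean`
print(d.variance)  # matches mean / rate
```

With a single shared rate, the variance is proportional to the mean, so the rate plays the role of a dispersion parameter.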

Is there anything wrong with my censoring logic, or in my attempt to combine scipy.stats.gamma with Pyro?
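On the censoring logic itself: observing a 1 from Bernoulli(p) adds log p to the model’s log density. The `censorship` site therefore contributes log P(Y > c) for each censored row, which is the term a right-censored likelihood needs. Below is a small check in plain PyTorch with made-up probabilities. In Pyro, `pyro.factor(name, log_p)` is a more direct way to add the same term.

```python
import torch

p = torch.tensor([0.9, 0.3, 1e-4])  # made-up survival probabilities P(Y > c)
bern = torch.distributions.Bernoulli(probs=p)

# Log-density contribution of observing a 1 from Bernoulli(p)
contribution = bern.log_prob(torch.ones_like(p))
print(contribution)
print(torch.log(p))  # same values
```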

I can’t follow what you’re doing without more details. In general, though, using non-Pyro distributions in a model or guide won’t work because, for example, gradients can’t flow to the parameters of the distribution (e.g. if you’re doing variational inference).
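Here is the gradient problem in concrete terms. Anything that goes through `.detach().numpy()` and comes back via `torch.tensor(...)` is a constant as far as autograd is concerned. This minimal sketch uses the exponential survival function exp(-rate * c), i.e. a Gamma with shape 1, so that only NumPy is needed.

```python
import numpy as np
import torch

rate = torch.tensor(2.0, requires_grad=True)
c = 1.0  # censoring threshold

# Pure PyTorch: the survival probability stays connected to `rate`
surv_torch = torch.exp(-rate * c)

# NumPy round trip, as in the scipy-based model: the autograd graph is cut
surv_numpy = torch.tensor(np.exp(-rate.detach().numpy() * c))

print(surv_torch.requires_grad)  # True
print(surv_numpy.requires_grad)  # False: no gradient can reach `rate`

surv_torch.backward()
print(rate.grad)  # d/d(rate) exp(-rate * c) = -c * exp(-rate * c)
```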

I’m using MCMC and NUTS, so I don’t think those depend on the gradients, do they?

EDIT: What more details should I provide?

HMC/NUTS uses gradients of the model’s log density to do inference, so the model specification needs to be written entirely in PyTorch.
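Concretely, NUTS works with the potential energy (the negative log joint density) and its gradient with respect to the unconstrained latent variables. This is a minimal sketch in plain PyTorch with made-up data, covering only the uncensored part of the Gamma GLM. The positive `rate` is moved to the log scale and its Jacobian term is added, which is similar in spirit to how Pyro handles constrained sites. The names `log_rate` and `potential_energy` are illustrative.

```python
import torch

torch.manual_seed(0)
X = torch.randn(20, 2)                                  # made-up features
y = torch.distributions.Gamma(2.0, 1.0).sample((20,))   # made-up positive targets

# Unconstrained latent variables that NUTS would move around
intercept = torch.zeros((), requires_grad=True)
betas = torch.zeros(2, requires_grad=True)
log_rate = torch.zeros((), requires_grad=True)


def potential_energy():
    rate = log_rate.exp()
    mean = torch.exp(intercept + X @ betas)
    shape = mean * rate
    log_joint = (
        torch.distributions.Normal(0.0, 1.0).log_prob(intercept)
        + torch.distributions.Normal(0.0, 1.0).log_prob(betas).sum()
        + torch.distributions.HalfCauchy(10.0).log_prob(rate)
        + log_rate  # log |d rate / d log_rate| for the exp transform
        + torch.distributions.Gamma(shape, rate).log_prob(y).sum()
    )
    return -log_joint


U = potential_energy()
grads = torch.autograd.grad(U, [intercept, betas, log_rate])
print([g.shape for g in grads])  # one gradient per latent variable
```

Any step that leaves PyTorch (NumPy, SciPy, `.detach()`) would make these gradients wrong or unavailable.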

As for more details: the logic/overview of your model and what you’re trying to accomplish (more context than code).

Oh, I see.

What I’m trying to do is build a generalized linear model with a Gamma-distributed response, which is easy enough. However, as I mentioned, some of my observations are censored. The data comes from a survey on prices, so some prices were just listed as “$1,000+”, which means I don’t know whether they were $1,000, $2,000, or $5,000.
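For survey data like this, the censoring flag can be built while parsing the price strings. This sketch uses a hypothetical `parse_price` helper and assumes that top-coded answers look like “$1,000+” and everything else is a plain dollar amount.

```python
def parse_price(raw: str) -> tuple[float, int]:
    """Return (value, censored_flag) for one survey answer (hypothetical format)."""
    text = raw.strip().replace("$", "").replace(",", "")
    if text.endswith("+"):
        # Top-coded answer: the true price is at least this value (right-censored)
        return float(text[:-1]), 1
    return float(text), 0


answers = ["$250", "$1,000+", "$730.50", "$1,000+"]
parsed = [parse_price(a) for a in answers]
y = [value for value, _ in parsed]
censored_label = [flag for _, flag in parsed]
print(y)               # [250.0, 1000.0, 730.5, 1000.0]
print(censored_label)  # [0, 1, 0, 1]
```

The resulting lists can then be turned into tensors (e.g. `torch.tensor(y)`) before they are passed to the model.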

I’d like to make sure I’m getting proper coefficients for my linear model by accounting for the censoring. For the censored observations, I’m trying to calculate the probability of exceeding the recorded value (1 - CDF of my Gamma-distributed variable). I took this approach from the old Pyro tutorial about modelling censored events.
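For NUTS, the CDF has to be computed in PyTorch. As far as I know, `torch.special.gammaincc` does not provide a gradient with respect to its first argument. That argument is the shape, which here depends on the coefficients through shape = mean * rate. This sketch instead sums the standard power series for the regularized lower incomplete gamma function, P(a, x) = x^a e^(-x) / Γ(a + 1) · Σ_k x^k / ((a + 1)(a + 2)…(a + k)) with x = rate · c, in log space. It is a truncated approximation, and the hypothetical `n_terms` should comfortably exceed rate · c.

```python
import torch


def gamma_log_survival(shape, rate, c, n_terms=200):
    """Approximate log P(Y > c) for Y ~ Gamma(shape, rate), differentiable in both.

    Sums the series P(a, x) = x^a e^{-x} / Gamma(a + 1) * sum_k x^k / ((a + 1) ... (a + k))
    with x = rate * c, truncated after `n_terms` terms, and returns log(1 - P).
    """
    a, x = torch.broadcast_tensors(shape, rate * c)
    k = torch.arange(1, n_terms, dtype=a.dtype)
    # log of the k-th term: k * log(x) - sum_{j=1..k} log(a + j)
    log_terms = k * x.unsqueeze(-1).log() - torch.cumsum(torch.log(a.unsqueeze(-1) + k), dim=-1)
    # prepend the k = 0 term, which equals 1 (log 1 = 0)
    log_terms = torch.cat([torch.zeros_like(log_terms[..., :1]), log_terms], dim=-1)
    log_p = a * x.log() - x - torch.lgamma(a + 1) + torch.logsumexp(log_terms, dim=-1)
    # keep log P strictly below 0; only matters when c is far in the right tail (P rounds to 1)
    log_p = log_p.clamp(max=-torch.finfo(a.dtype).eps)
    # log(1 - P) computed as log(-expm1(log P)) for accuracy
    return torch.log(-torch.expm1(log_p))


# Usage with made-up values: the gradient with respect to the shape now exists
shape = torch.tensor([0.5, 1.0, 2.0, 5.0], requires_grad=True)
rate = torch.tensor(2.0, requires_grad=True)
log_s = gamma_log_survival(shape, rate, c=1.0)
log_s.sum().backward()
print(log_s.exp())  # P(Y > 1.0) for each shape
print(shape.grad)   # d log P(Y > 1.0) / d shape
```

For shape 1 the result should reduce to exp(-rate · c), and for shape 2 it should reduce to exp(-x)(1 + x). Both are easy closed forms to test against.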

As far as I can tell, the general approach seems reasonable, although (I repeat for clarity) the detach statements will cause problems with HMC. The clamp statements might as well. It might also make sense to split your observations explicitly into censored and uncensored sets so that you don’t need the masks.