I have a variable Y ~ Gamma(shape, rate),
but my data are censored: there is a threshold (say 1.0) such that whenever the true value of Y is higher, the recorded observation is 1.0. I want to model Y through a linear combination of some features X.
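Written out, the likelihood this setup calls for is shown below. The notation is mine, not from the code: r is the rate, β₀ and β are the intercept and coefficients, δᵢ is the censoring flag (`censored_label`), c = 1.0 is the threshold, and f and F are the Gamma density and CDF.

```latex
\begin{aligned}
\mu_i &= \exp\!\left(\beta_0 + x_i^\top \beta\right), \qquad \alpha_i = \mu_i r, \qquad Y_i \sim \operatorname{Gamma}(\alpha_i, r), \\
\log L(\beta_0, \beta, r) &= \sum_{i:\,\delta_i = 0} \log f(y_i \mid \alpha_i, r) + \sum_{i:\,\delta_i = 1} \log\bigl(1 - F(c \mid \alpha_i, r)\bigr).
\end{aligned}
```

The first sum is the uncensored part of the model. The second sum is what a Bernoulli site observed at 1 with probability p = 1 − F(c) contributes, because log Bernoulli(1 | p) = log p.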
The model I try to employ looks like this:
import torch
import pyro
import pyro.distributions as dist
from scipy.stats import gamma as scipygamma

def model_gamma_cen(X, y, censored_label):
    min_value = torch.finfo(X.dtype).eps
    max_p_value = 1.0 - torch.finfo(X.dtype).eps
    max_value = torch.finfo(X.dtype).max / 100.0
    # Prior on the intercept
    intercept_prior = dist.Normal(0.0, 1.0)
    intercept = pyro.sample("beta_intercept", intercept_prior)
    # Coefficient priors, one per feature column of X;
    # to_event(1) treats the coefficient vector as a single sample
    coefficients_prior = dist.Normal(0.0, 1.0).expand([X.shape[1]]).to_event(1)
    betas = pyro.sample("beta_coefficients", coefficients_prior)
    # Linear predictor: intercept + X @ betas
    linear_combination = intercept + torch.matmul(X, betas)
    # Log link: log(mean) = linear combination, so mean = exp(linear combination)
    mean = torch.exp(linear_combination).clamp(min=min_value, max=max_value)
    # Rate parameter, clamped away from zero
    rate = pyro.sample("rate", dist.HalfCauchy(scale=10.0)).clamp(min=min_value)
    # Since mean = shape / rate, shape = mean * rate
    shape = mean * rate
    # The data split into two groups: censored and non-censored
    with pyro.plate("data", y.shape[0]):
        # Non-censored rows: y is observed directly from the Gamma distribution
        with pyro.poutine.mask(mask=(censored_label == 0.0)):
            pyro.sample("obs", dist.Gamma(shape, rate), obs=y)
        # Censored rows: the likelihood is P(Y > y) = 1 - CDF(y),
        # computed with scipy.stats.gamma (scipy uses scale = 1 / rate)
        with pyro.poutine.mask(mask=(censored_label == 1.0)):
            scipy_gamma_dist = scipygamma(a=shape.detach().numpy(), scale=1.0 / rate.detach().numpy())
            # Probability that the true value exceeds the recorded one
            truncation_probability_np = 1.0 - scipy_gamma_dist.cdf(y.detach().numpy())
            # Match X's dtype so Bernoulli.log_prob does not mix float32 and float64
            truncation_prob = torch.tensor(truncation_probability_np, dtype=X.dtype).clamp(min=min_value, max=max_p_value)
            pyro.sample("censorship", dist.Bernoulli(truncation_prob), obs=torch.tensor(1.0))
Is there anything wrong with my censoring logic, or with my attempt to combine scipy.stats.gamma with Pyro?
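The Bernoulli trick on its own adds log(1 − F(c)) for each censored row, which is the right term. The weak point is the `.detach().numpy()` round trip. The scipy result re-enters Pyro as a constant tensor, so censored rows contribute no gradient with respect to the intercept, the coefficients, or the rate. Gradient-based inference such as SVI or NUTS then effectively ignores them. As far as I know, `torch.special.gammaincc` is not a full replacement either, because its documentation notes that backward with respect to its first argument (the shape) is not supported.

Below is a minimal sketch that keeps the censored term inside autograd by integrating the Gamma density over the tail numerically. The helper name `gamma_log_sf`, the change of variables t = c / (1 − s), and the default grid size are my own assumptions, not a Pyro or PyTorch API. I have only checked the accuracy on the simple cases in the tests.

```python
import math

import torch
from torch.distributions import Gamma


def gamma_log_sf(c, shape, rate, n_points=4000):
    """Hypothetical helper: log P(Y > c) for Y ~ Gamma(shape, rate).

    Differentiable in both shape and rate. The substitution t = c / (1 - s)
    maps the tail [c, inf) onto s in [0, 1), and the integral is evaluated
    with the trapezoid rule in log space. Requires c > 0.
    """
    dtype = shape.dtype
    # Trailing axis so c (scalar or one value per row) broadcasts over the grid
    c = torch.as_tensor(c, dtype=dtype).unsqueeze(-1)
    # Grid on [0, 1); the integrand vanishes at s = 1, so that endpoint is dropped
    s = torch.linspace(0.0, 1.0, n_points + 1, dtype=dtype)[:-1]
    t = c / (1.0 - s)  # points in [c, inf)
    log_jacobian = torch.log(c) - 2.0 * torch.log1p(-s)  # log(dt/ds), dt/ds = c / (1 - s)^2
    h = 1.0 / n_points
    # Trapezoid weights: h/2 at s = 0 and h elsewhere (the s = 1 term is zero)
    log_w = torch.full_like(s, math.log(h))
    log_w[0] = math.log(h / 2.0)
    # Gamma log-density on the grid, broadcast over a batch of (shape, rate)
    log_density = Gamma(shape.unsqueeze(-1), rate.unsqueeze(-1)).log_prob(t)
    return torch.logsumexp(log_density + log_jacobian + log_w, dim=-1)
```

Inside the model, the scipy lines and the Bernoulli site in the censored branch could then, I believe, be replaced by `pyro.factor("censored", gamma_log_sf(y, shape, rate))` under the same `poutine.mask`. Because this evaluates the tail for every row, all `y` values must be positive. The cost is about `n_points` density evaluations per row, so the grid size trades accuracy against speed.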