I have a model for analyzing multi-stage decisions, included below:
def multi_stage_model(M, N, K, last_feature_idx_to_use, y, X, stage, beta_X=None):
    """NumPyro model of a K-stage screening process followed by a binary outcome.

    At stage ``k`` the prior probability that a row is "positive" is
    ``phi = expit(X[:, :n_k] @ beta_X[:n_k])`` with
    ``n_k = last_feature_idx_to_use[k]``.  Signals are modelled as
    N(delta_k, 1) for positives and N(0, 1) for negatives.  A row passes
    stage ``k`` when its signal exceeds the point at which the posterior
    probability of being positive equals the stage threshold ``t_k``.
    Pass/fail at each stage enters the log density through
    ``numpyro.factor``.  The final outcome ``y`` enters through a masked
    Bernoulli ``numpyro.sample``.

    Args:
        M: Not used by the model body (kept for call compatibility).
        N: Length of the ``beta_X`` coefficient vector.
        K: Number of screening stages.  Only 1 or 2 are supported, because
            exactly two thresholds and two deltas are defined below.
        last_feature_idx_to_use: Indexable; entry ``k`` is how many leading
            columns of ``X`` are used at stage ``k``.
        y: Binary outcome.  It only contributes for rows with
            ``stage == K + 1``.
        X: Feature matrix, sliced as ``X[:, :n]``.
        stage: Per-row stage code.  ``stage == k + 1`` means the row dropped
            out at stage ``k + 1``.  ``stage == K + 1`` means it passed every
            stage and its outcome ``y`` is observed.
        beta_X: Ignored.  ``beta_X`` is always drawn from its prior.

    Raises:
        ValueError: If ``K`` is not 1 or 2.
    """
    if K not in (1, 2):
        # Fail loudly instead of indexing past the two thresholds/deltas, or
        # reaching the outcome step with no stage quantities computed (K < 1).
        raise ValueError(
            f"multi_stage_model supports K=1 or K=2 stages, got K={K}"
        )

    # Priors for feature coefficients (any beta_X argument is overwritten).
    beta_X = numpyro.sample('beta_X', dist.Normal(0, 1).expand([N]))

    # Stage thresholds on the logit scale.  The non-negative gap keeps
    # threshold 2 >= threshold 1.
    logit_threshold_0 = numpyro.sample('logit_threshold_0', dist.Normal(-1, 1))
    gap_0_1 = numpyro.sample('gap_0_1', dist.HalfNormal(jnp.sqrt(0.5)))
    logit_thresholds_in_p_space = jnp.array(
        [logit_threshold_0, logit_threshold_0 + gap_0_1]
    )

    # Signal separations (d'), also ordered through a non-negative gap.
    deltas_0 = numpyro.sample('delta_0', dist.HalfNormal(jnp.sqrt(0.1)))
    gap_delta = numpyro.sample('gap_delta', dist.HalfNormal(jnp.sqrt(0.3)))
    deltas = jnp.array([deltas_0, deltas_0 + gap_delta])
    numpyro.deterministic('delta_0_calculated', deltas[0])
    numpyro.deterministic('delta_1_calculated', deltas[1])

    for k in range(K):
        n_features = last_feature_idx_to_use[k]
        # Linear predictor for this stage; phi = expit(z) is the prior P(positive).
        z = jnp.dot(X[:, :n_features], beta_X[:n_features])
        phi = jax.scipy.special.expit(z)
        delta = deltas[k]

        # log[(phi / (1 - phi)) * ((1 - t) / t)] with t = expit(logit_t)
        # simplifies exactly to z - logit_t.  Writing it this way avoids the
        # 1/(1 - phi) and log(0) blow-ups (and NaN gradients) when phi
        # saturates to 0 or 1 in floating point.
        log_ratio = z - logit_thresholds_in_p_space[k]
        # Signal value x where the posterior odds equal the threshold odds:
        # delta * x - delta**2 / 2 = -log_ratio.
        threshold_in_signal_space = (delta**2 - 2 * log_ratio) / (2 * delta)

        # Probability mass above the threshold from positives and negatives.
        mass_above_threshold_pos = phi * (
            1 - jax.scipy.stats.norm.cdf(threshold_in_signal_space, loc=delta, scale=1)
        )
        mass_above_threshold_neg = (1 - phi) * (
            1 - jax.scipy.stats.norm.cdf(threshold_in_signal_space, loc=0, scale=1)
        )
        search_rate = jnp.clip(
            mass_above_threshold_pos + mass_above_threshold_neg, 1e-6, 1 - 1e-6
        )

        if k == 0:
            # Stage 1: rows that moved on to stage 2 or beyond pass; every
            # other row (stage <= 1) counts as not passing.
            passing_mask = stage > 1
            not_passing_mask = ~passing_mask
        else:
            # Later stages only score rows that reached this stage.
            passing_mask = stage >= k + 2
            not_passing_mask = stage == k + 1

        numpyro.factor(
            f'stage_{k + 1}_passing',
            jnp.sum(passing_mask * jnp.log(search_rate)),
        )
        numpyro.factor(
            f'stage_{k + 1}_not_passing',
            jnp.sum(not_passing_mask * jnp.log1p(-search_rate)),
        )

    # Final outcome y, observed only for rows that passed all K stages.
    # hit_rate is P(positive | passed the last stage), taken from the last
    # loop iteration (k = K - 1).
    final_stage_mask = stage == K + 1
    hit_rate = jnp.clip(mass_above_threshold_pos / search_rate, 1e-6, 1 - 1e-6)

    # Why this is NOT equivalent to
    #   numpyro.factor('final_outcome',
    #                  jnp.sum(jnp.log(hit_rate) * y + jnp.log(1 - hit_rate) * (1 - y)))
    # The mask handler below drops the Bernoulli term for every row with
    # stage != K + 1.  The plain jnp.sum instead scores y for *all* rows,
    # including those who dropped out earlier and have no real outcome.
    # That is a different likelihood, so MCMC converges to a different
    # posterior.
    #
    # A factor that matches this sample statement must apply the mask itself
    # and be placed OUTSIDE the `with` block:
    #   numpyro.factor('final_outcome', jnp.sum(final_stage_mask * (
    #       y * jnp.log(hit_rate) + (1 - y) * jnp.log1p(-hit_rate))))
    # If y holds NaN placeholders for masked rows, replace them first,
    # e.g. jnp.where(final_stage_mask, y, 0).
    #
    # NOTE(review): inside the mask handler the per-row mask appears to be
    # broadcast against a scalar factor, which counts it once per final-stage
    # row.  Verify this against numpyro's MaskedDistribution.
    with numpyro.handlers.mask(mask=final_stage_mask):
        numpyro.sample('obs', dist.Bernoulli(probs=hit_rate), obs=y)
When I replace the last line, `numpyro.sample('obs', dist.Bernoulli(probs=hit_rate), obs=y)`, with
`numpyro.factor('final_outcome', jnp.sum(jnp.log(hit_rate) * y + jnp.log(1 - hit_rate) * (1 - y)))`
and fit the model with MCMC on the same data, the two versions converge to very different results. Why is that? What is the difference between the `sample` statement and the `factor` statement?