Understanding factor vs sample

I have a model for analyzing multi-stage decisions, included below:

```python
import jax.numpy as jnp
import jax.scipy.special
import jax.scipy.stats
import numpyro
import numpyro.distributions as dist


def multi_stage_model(M, N, K, last_feature_idx_to_use, y, X, stage, beta_X=None):
    # Prior for the feature coefficients (this overrides the beta_X argument)
    beta_X = numpyro.sample('beta_X', dist.Normal(0, 1).expand([N]))
    
    # Priors for thresholds and deltas, ensuring they are ordered
    logit_threshold_0 = numpyro.sample('logit_threshold_0', dist.Normal(-1, 1))
    gap_0_1 = numpyro.sample('gap_0_1', dist.HalfNormal(jnp.sqrt(0.5)))
    logit_thresholds_in_p_space = jnp.array([logit_threshold_0, logit_threshold_0 + gap_0_1])

    thresholds_in_p_space = jax.scipy.special.expit(logit_thresholds_in_p_space)

    deltas_0 = numpyro.sample('delta_0', dist.HalfNormal(jnp.sqrt(0.1)))
    gap_delta = numpyro.sample('gap_delta', dist.HalfNormal(jnp.sqrt(0.3)))
    deltas = jnp.array([deltas_0, deltas_0 + gap_delta])

    # Record the deltas as separately named deterministic sites
    delta_0 = numpyro.deterministic('delta_0_calculated', deltas[0])
    delta_1 = numpyro.deterministic('delta_1_calculated', deltas[1])

    
    # Loop over stages
    for k in range(K):
        # Compute phi for this stage using only the relevant features
        phi = jax.scipy.special.expit(jnp.dot(X[:, :last_feature_idx_to_use[k]], beta_X[:last_feature_idx_to_use[k]]))
        
        # Convert threshold from p-space to signal space for the current stage
        log_ratio = jnp.log((phi / (1 - phi)) * ((1 - thresholds_in_p_space[k]) / thresholds_in_p_space[k]))
        threshold_in_signal_space = (deltas[k]**2 - 2 * log_ratio) / (2 * deltas[k])
        
        # Calculate probability mass above the threshold for the current stage
        mass_above_threshold_pos = phi * (1 - jax.scipy.stats.norm.cdf(threshold_in_signal_space, loc=deltas[k], scale=1))
        mass_above_threshold_neg = (1 - phi) * (1 - jax.scipy.stats.norm.cdf(threshold_in_signal_space, loc=0, scale=1))
        search_rate = mass_above_threshold_pos + mass_above_threshold_neg
        search_rate = jnp.clip(search_rate, 1e-6, 1-1e-6)
        
        # Masking logic for the current stage
        if k == 0:
            stage_1_mask = (stage > 1)  # Stage 1, only those who move on to stage 2 or beyond
            numpyro.factor('stage_1_passing', jnp.sum(stage_1_mask * jnp.log(search_rate)))
            numpyro.factor('stage_1_not_passing', jnp.sum(~stage_1_mask * jnp.log(1 - search_rate)))

            
        else:
            stage_passing_mask = (stage >= (k + 2))  # Those who moved on to the next stage
            stage_not_passing_mask = (stage == k + 1)  # Those who dropped out at this stage

            # Calculate contributions only for those who passed the previous stages
            numpyro.factor(f'stage_{k+1}_passing', jnp.sum(stage_passing_mask * jnp.log(search_rate)))
            numpyro.factor(f'stage_{k+1}_not_passing', jnp.sum(stage_not_passing_mask * jnp.log(1 - search_rate)))

    # Final outcome (y) for those who passed all stages; hit_rate uses the last stage's quantities
    final_stage_mask = (stage == K + 1)
    hit_rate = mass_above_threshold_pos / search_rate
    hit_rate = jnp.clip(hit_rate, 1e-6, 1-1e-6)
    
    with numpyro.handlers.mask(mask=final_stage_mask):
        numpyro.sample('obs', dist.Bernoulli(probs=hit_rate), obs=y)
```

When I change the last line from `numpyro.sample('obs', dist.Bernoulli(probs=hit_rate), obs=y)` to `numpyro.factor('final_outcome', jnp.sum(jnp.log(hit_rate) * y + jnp.log(1 - hit_rate) * (1 - y)))` and fit the model with MCMC on the same data, the two versions converge to very different results. Why is that? What is the difference between the sample statement and the factor statement?

probably you are summing over too many dimensions in the sum and broadcasting the factor up?

Sorry, what do you mean by “summing over too many dimensions in the sum and broadcasting the factor up”? What precisely is the difference between the two versions in terms of what happens under the hood?

For reference, the numpyro.sample version converges to the correct parameters more often than the numpyro.factor version.

https://pyro.ai/examples/tensor_shapes.html

Sorry, I don’t see where your link above explains how factor statements are batched or broadcast. Could you explain more specifically how the two versions of the code differ?

i don’t have time to look at your code in detail, but i don’t think you want the sum: your factor statement is inside a plate, so the factor's log_prob should have a dimension that matches the size of the plate.

Sorry, I’m still not quite sure I understand. The factor statement is not inside a plate; it is inside a `with numpyro.handlers.mask(...)` block. Do these mask statements implicitly create a plate in the backend? I didn’t see anything on the Effect Handlers page of the NumPyro documentation that indicated as much.

like i said i don’t really have time to look at your code. maybe there’s no plate. that’s fine.

maybe it’s a numerical issue then, e.g. maybe you need jnp.log1p
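For reference, `jnp.log1p(-p)` computes log(1 - p) without first forming `1 - p`, which matters when `p` is tiny. A minimal float32 illustration with a made-up value of `p`:

```python
import jax.numpy as jnp

p = jnp.float32(1e-10)  # made-up, very small probability
naive = jnp.log(1 - p)  # 1 - 1e-10 rounds to 1.0 in float32, so this is 0.0
stable = jnp.log1p(-p)  # log(1 - p) computed without forming 1 - p first
print(naive, stable)
```

Whether this matters for the model above is unclear, since `hit_rate` is already clipped to [1e-6, 1 - 1e-6].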

Sorry, I am not asking for a look at the code, and I appreciate your help thus far. I’m trying to understand how factor differs from sample with obs. An article explaining the difference would be sufficient, but I have not found one in the tutorials; for example, the tutorial linked above explains what sample does but does not contrast it with factor statements.

i suggest looking at the code: pyro/primitives.py in pyro-ppl/pyro on GitHub (commit 455f7b3b8b21f8e93a96235fc6bd58cb60f8a3fa)

factor is basically a particular kind of sample statement with a provided log_prob for the observation.