Carlin et al. Smoking Longitudinal Binary Mixture Model

Hi all,

Can anyone help me with implementing the discrete mixture model from Carlin et al., "A case study on the choice, interpretation and checking of multilevel models for longitudinal binary outcomes"? I do not know how to code the mixture model detailed in the text (screenshot below).

Below is my attempt, but this obviously does not work.

import numpy as np

import numpyro
import numpyro.distributions as dist
from jax.numpy import DeviceArray  # jnp.DeviceArray here; newer JAX releases use jax.Array


def discrete_mixture(
    id_individual: DeviceArray,
    parsmk: DeviceArray,
    sex: DeviceArray,
    wave: DeviceArray,
    y: DeviceArray = None,
):
    num_individual = np.unique(id_individual).shape[0]
    num_data = parsmk.shape[0]

    # subject-level random intercepts gamma_i with SD sigma
    sigma = numpyro.sample("sigma", dist.Exponential(1))

    with numpyro.plate("individual_plate", num_individual):
        gamma = numpyro.sample("gamma", dist.Normal(0, sigma))

    # class-membership (susceptibility) model: phi0 + phi' z_i
    phi0 = numpyro.sample("phi0", dist.Normal(-1, 1))

    with numpyro.plate("phi_plate", 2):
        phi = numpyro.sample("phi", dist.Normal(0, 1))

    # smoking model for the susceptible class: eta0 + gamma_i + eta' x_ij
    eta0 = numpyro.sample("eta0", dist.Normal(0, 1))

    with numpyro.plate("eta_plate", 4):
        eta = numpyro.sample("eta", dist.Normal(0, 1))

    with numpyro.plate("data_plate", num_data):
        S_logits = phi0 + phi[0] * sex + phi[1] * parsmk
        S = numpyro.sample("S", dist.Bernoulli(logits=S_logits))

        logits = (
            eta0
            + gamma[id_individual]
            + eta[0] * sex
            + eta[1] * parsmk
            + eta[2] * (1 - sex) * wave
            + eta[3] * sex * wave
        )

        # this is the broken part: S is a traced array, so a Python if/else on it
        # does not do what I want here
        kwargs = {"probs": 1.0} if S else {"logits": logits}

        numpyro.sample("obs", dist.Bernoulli(**kwargs), obs=y)

You need to use constructs like jnp.where instead of Python if/else; for an example, see here.
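For instance, something along these lines keeps everything array-valued instead of branching in Python (just an illustration of the pattern, not the linked example; it mirrors the intent of the original kwargs line, where S == 1 forces the smoking probability to ~1):

    import jax.numpy as jnp

    # elementwise select instead of a Python if/else on the traced array S
    final_logits = jnp.where(S == 1, 100.0, logits)  # a logit of 100 ~ probability 1
    numpyro.sample("obs", dist.Bernoulli(logits=final_logits), obs=y)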

Sorry, I am trying to avoid the use of either if/else or a jnp.where. I could be wrong, but I don’t think it is necessary. My first post was an attempt to communicate what I was trying to code.

I switched out the kwargs = {"probs": 1.0} if S else {"logits": logits} with:

final_logits = (1 - S) * 100.0 + S * logits

But this model is still having issues reproducing the published results. I was able to get good agreement with the other models from the paper; it is just the mixture that is harder to get right.

You might find it easier to use MixtureSameFamily.
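If it helps, a rough sketch of what that might look like here, marginalizing S into a two-component mixture (whether this behaves inside a data plate may depend on the NumPyro version):

    import jax
    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist

    # mixing weights from the susceptibility model, components stacked on the last axis
    pi0 = jax.nn.sigmoid(S_logits)  # P(susceptible)
    mixing = dist.Categorical(probs=jnp.stack([pi0, 1.0 - pi0], axis=-1))
    mu = jnp.stack([jax.nn.sigmoid(logits), jnp.full_like(logits, 1e-8)], axis=-1)
    components = dist.Bernoulli(probs=mu)
    numpyro.sample("obs", dist.MixtureSameFamily(mixing, components), obs=y)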

I could not get MixtureSameFamily to work within a plate; I could post about that in another thread.

When I look at Carlin et al., it looks just like a two-component Bernoulli mixture, so I just wrote out the mixture by hand, given the equations (Slide 4):

\mu = \left[\, \operatorname{logit}^{-1}(\eta_0 + \gamma_i + \eta^T x_{ij}), \; 1 \times 10^{-8} \,\right]

\pi = \left[\, \operatorname{logit}^{-1}(\phi_0 + \phi^T z_i), \; 1 - \operatorname{logit}^{-1}(\phi_0 + \phi^T z_i) \,\right]

p(x \mid \mu, \pi) = \pi_0\, \mu_0^{x} (1-\mu_0)^{1-x} + \pi_1\, \mu_1^{x} (1-\mu_1)^{1-x}
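Roughly, a hand-written version of that marginalized likelihood could look like this (the helper name bernoulli_mixture_logpdf is just for illustration, not the exact code in my gist; it replaces sampling S and the Bernoulli "obs" site with a numpyro.factor):

    import jax.numpy as jnp
    from jax.nn import log_sigmoid

    def bernoulli_mixture_logpdf(y, logits, s_logits, eps=1e-8):
        # component 0: susceptible class, Bernoulli(logit^-1(logits))
        log_p0 = y * log_sigmoid(logits) + (1 - y) * log_sigmoid(-logits)
        # component 1: non-susceptible class, Bernoulli(eps) with eps ~ 0
        log_p1 = y * jnp.log(eps) + (1 - y) * jnp.log1p(-eps)
        # mixing weights: pi_0 = logit^-1(s_logits), pi_1 = 1 - pi_0
        log_pi0 = log_sigmoid(s_logits)
        log_pi1 = log_sigmoid(-s_logits)
        return jnp.logaddexp(log_pi0 + log_p0, log_pi1 + log_p1)

    # inside the data plate:
    # numpyro.factor("obs_loglik", bernoulli_mixture_logpdf(y, logits, S_logits))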

I am still having issues with reproducing the fit for their discrete mixture model. As you can see, the subject SD \tau is having a lot of issues. The lines in the traces below are the reported values from Carlin et al., Table 5.

You can see from the other models that I can get the fits to agree pretty well; the code is in the gist. Do you have any suggestions? Maybe the mixture model is still not specified correctly?

Not sure if it helps, but you can also use ZeroInflatedDistribution for this likelihood. Here the gate is 1 - P(S=1) in your problem (a sketch is below, after the reparameterization notes). You can also reparameterize the part

    sigma = numpyro.sample("sigma", dist.Exponential(1))

    with numpyro.plate("individual_plate", num_individual):
        gamma = numpyro.sample("gamma", dist.Normal(0, sigma))

into

    sigma = numpyro.sample("sigma", dist.Exponential(1))

    with numpyro.plate("individual_plate", num_individual):
        # non-centered parameterization: sample a standard normal, then scale by sigma
        gamma_base = numpyro.sample("gamma_base", dist.Normal(0, 1))
        gamma = numpyro.deterministic("gamma", gamma_base * sigma)

or using

reparam_model = numpyro.handlers.reparam(model,
    config={"gamma": numpyro.infer.reparam.LocScaleReparam(0)})

Updated notebook

Thank you for the help. I did not know about the automatic re-parameterization part of the library; I will definitely start to utilize it more. The re-parameterization helped not only with the discrete mixture but also with the logistic normal.

I switched out my janky Bernoulli mixture for the ZeroInflatedDistribution, which I agree is a good choice for the likelihood. When I use the ZeroInflatedDistribution with a Bernoulli as my base and 1 - P(S=1) as my gate, I am still not reproducing the majority of the fit. Unfortunately, I have some parameters with different magnitudes and a sign change. I also tried switching to non-informative priors, but this didn’t seem to make a difference toward reproducing the published fit.

There still might be something wrong in my discrete mixture model, but maybe the difference is due to the choice of inference: HMC versus Gibbs (BUGS)?

Typically, NUTS will give you the right results for this type of model. You might want to try initializing NUTS at the desired values (using the init_to_values strategy) to see if the posterior is multimodal. You can also compare the log density at the desired values with the log density at the values given by NUTS. If NUTS gives a better result, I would trust the NUTS result; in that case, you might want to double-check that your model matches the description in the paper.

Gotcha, I will make the comparisons and double-check what I have against the paper. Thank you for the help.

@fehiepsi

Do you mean the init_params argument for mcmc.run(...) when using the NUTS kernel? Sorry, I couldn’t find any explicit reference to init_to_values in the docs.
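For reference, a minimal sketch of the kind of initialization being suggested, assuming the intended strategy is numpyro.infer.init_to_value (the parameter values below are placeholders):

    from jax import random
    from numpyro.infer import MCMC, NUTS, init_to_value

    # initialize NUTS at chosen values to probe for multimodality;
    # `model` is whichever marginalized version of the mixture is being fit
    init_values = {"sigma": 0.5, "eta0": -2.0, "phi0": -1.0}
    kernel = NUTS(model, init_strategy=init_to_value(values=init_values))
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=1000)
    mcmc.run(random.PRNGKey(0), id_individual, parsmk, sex, wave, y=y)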