Not sure if it helps, but you can also use ZeroInflatedDistribution for this likelihood. Here gate
is the probability of the extra zeros, i.e. 1 - P(S=1) in your problem. You can also reparameterize the part
sigma = numpyro.sample("sigma", dist.Exponential(1))
with numpyro.plate("individual_plate", num_individual):
    gamma = numpyro.sample("gamma", dist.Normal(0, sigma))
into
sigma = numpyro.sample("sigma", dist.Exponential(1))
with numpyro.plate("individual_plate", num_individual):
    gamma_base = numpyro.sample("gamma_base", dist.Normal(0, 1))
    gamma = numpyro.deterministic("gamma", gamma_base * sigma)
or do the same thing with the reparam handler:
reparam_model = numpyro.handlers.reparam(
    model, config={"gamma": numpyro.infer.reparam.LocScaleReparam(0)}
)