Logistic Regression Not Matching PyMC3


Hey Team

I’m attempting to perform a simple logistic regression with the NUTS sampler, following PyMC3’s tutorial so I can compare the two, but I’m struggling to get similar results.

For simplicity, I’ve attached a notebook with runnable code to this post. I think I’m going wrong in the model definition. In PyMC3 it’s defined as:

import pymc3 as pm

# data is the pandas DataFrame from the tutorial
with pm.Model() as logistic_model:
    pm.glm.GLM.from_formula('income ~ age + age2 + educ + hours', data, family=pm.glm.families.Binomial())
    trace_logistic_model = pm.sample(2000, chains=1, tune=1000, init='adapt_diag')

In Pyro, I define it as follows:

import pyro
import pyro.distributions as dist

def model(income, hours, educ, age2, age):
    # Cauchy distribution used to constrain intercept (a) to be positive.
    a = pyro.sample("a", dist.Cauchy(0., 1000.))
    b_h = pyro.sample("b_hours", dist.Normal(0., 1.))
    b_e = pyro.sample("b_educ", dist.Normal(0., 1.))
    b_a2 = pyro.sample("b_age2", dist.Normal(0., 1.))
    b_a = pyro.sample("b_age", dist.Normal(0., 1.))
    sigma = pyro.sample("sigma", dist.Uniform(0., 1.))
    mean = a + b_h * hours + b_e * educ + b_a2 * age2 + b_a * age
    with pyro.iarange("data", len(hours)):
        pyro.sample("obs", dist.Bernoulli(logits=mean.squeeze()), obs=income)

I’m unsure whether I need to squeeze the mean, or whether I need to pass it through a sigmoid function first. Looking at PyTorch’s Bernoulli source, I don’t believe I need to.
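PyTorch’s Bernoulli (which Pyro’s Bernoulli wraps) accepts either probs= or logits=, and with logits= the sigmoid is applied internally. The pure-Python sketch below does not use Pyro or PyTorch; its helper names are hypothetical. It checks that a logits-parameterized Bernoulli log-likelihood equals the probs-parameterized one evaluated at sigmoid(z), and that applying the sigmoid yourself before passing logits= would apply it twice.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def softplus(z):
    # Numerically stable log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def bernoulli_logpmf_from_probs(y, p):
    # Standard Bernoulli log-pmf: y*log(p) + (1 - y)*log(1 - p).
    return y * math.log(p) + (1 - y) * math.log(1 - p)


def bernoulli_logpmf_from_logits(y, z):
    # Logits parameterization: the sigmoid is implicit.
    # y*z - log(1 + exp(z)) equals y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z)).
    return y * z - softplus(z)


if __name__ == "__main__":
    z = 0.7  # an arbitrary linear-predictor value
    for y in (0, 1):
        print(y, bernoulli_logpmf_from_logits(y, z),
              bernoulli_logpmf_from_probs(y, sigmoid(z)))
    # Passing sigmoid(z) as logits squashes twice and gives a different likelihood.
    print(bernoulli_logpmf_from_logits(1, sigmoid(z)))
```

The two parameterizations agree for both outcomes, so the linear predictor can be passed straight to logits= without an explicit sigmoid.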

Any help and advice would be appreciated! I’ve looked through this post but couldn’t solve it.

Thanks in advance.


Hi @JamesTrick, there are a few points that might explain why your results differ:

  • Pyro’s NUTS adaptation schedule differs from PyMC3’s: we follow Stan and use windowed adaptation, whereas PyMC3 adapts the mass matrix at every warmup step.
  • PyMC3’s GLM seems to use different priors from your Pyro model. In addition, I think it is better to use Normal(0, 10) priors instead of Normal(0, 1).
  • Looking at the trace plot of the PyMC3 model, the intercept samples have not yet reached the typical set. I suspect you need to use a larger number of iterations.
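Windowed adaptation in the Stan style splits warmup into an initial fast buffer, a series of doubling slow windows (the mass matrix is re-estimated at the end of each), and a terminal fast buffer; the fast buffers tune only the step size. The pure-Python sketch below computes such a schedule using Stan’s documented defaults (75-step initial buffer, 25-step first slow window, 50-step terminal buffer). The function name is hypothetical, and this is not Pyro’s actual implementation, whose details may differ.

```python
def stan_window_schedule(num_warmup, init_buffer=75, term_buffer=50, base_window=25):
    """Return (phase, start, end) tuples with inclusive step indices."""
    if init_buffer + base_window + term_buffer > num_warmup:
        # Stan falls back to proportional buffers (15%/75%/10%) here; omitted in this sketch.
        raise ValueError("warmup too short for the default buffers")
    schedule = [("fast", 0, init_buffer - 1)]  # step size only
    slow_end = num_warmup - term_buffer        # exclusive end of the slow phase
    start, size = init_buffer, base_window
    while start < slow_end:
        end = start + size  # exclusive end of this slow window
        # If the next (doubled) window would not fit, stretch this one to the end of the slow phase.
        if end + 2 * size > slow_end:
            end = slow_end
        schedule.append(("slow", start, end - 1))  # mass matrix updated at window end
        start, size = end, 2 * size
    schedule.append(("fast", slow_end, num_warmup - 1))  # final step-size tuning
    return schedule


if __name__ == "__main__":
    for phase, start, end in stan_window_schedule(1000):
        print(phase, start, end, end - start + 1)
```

For the 1000 tuning steps used in the PyMC3 call, this gives slow windows of 25, 50, 100, 200 and 500 steps, so the mass matrix is re-estimated only five times, rather than at every warmup step.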

But I think that if you run longer, you should get results similar to the PyMC3 tutorial. Please let me know if that is not the case, and I’ll try to find where the problem is.

Some suggestions which might be helpful for you:

  • pyro.iarange is deprecated; you can use pyro.plate instead:
        with pyro.plate("data", len(income)):
            pyro.sample("obs", dist.Bernoulli(logits=mean), obs=income)
  • it seems that you don’t need to squeeze the mean; passing logits=mean as you did should be fine, since Bernoulli applies the sigmoid to the logits internally
  • use HalfCauchy instead of Cauchy if you want to constrain the intercept to be positive; a Cauchy prior also puts mass on negative values
  • sigma is sampled but never used in the mean calculation (a Bernoulli likelihood has no separate noise term), so I think you can remove it
  • because the dataset is large, running on a GPU might help with this model