Logistic regression with NUTS on the Titanic dataset is extremely slow, and the run time varies wildly

Hey guys.

I want to try Bayesian logistic regression on the Titanic dataset (using only 6 numeric features), but I found it to be extremely slow. Sampling the posterior with NUTS (num_samples = 500, warmup_steps = 100) can take 20 minutes! But SOMETIMES it takes only 1-2 minutes. That is still relatively slow, but how can the difference be so big?

Is the way I set up my model not optimal?

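Here’s the code I used.

import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import MCMC, NUTS
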
class LogisticRegressionV2(object):
    def __init__(self, C = 1.0):
        self.C = C # Inverse regularization strength (C = 1/lambda)
        
        self.n_samples = None 
        self.n_features = None #Dimension of the data (number of features)
        
        self.hmc_posterior = None
        self.coefficients_summary = None 
        
        self.marginal = None # Shape (num_mcmc_samples, n_features + 1)
        self.coefficients_MAP = None # Shape (1, n_features + 1)

        
    def prep_data(self, X):
        """
        Parameters
        
        X: numpy array, shape (n_samples, n_features)
        
        Returns
        
        X_plus_intercept: numpy array, shape (n_features +1, n_samples) (transposed and added ones for the intercept terms)
        """
        
        self.n_samples = X.shape[0]
        self.n_features = X.shape[1]
        
        self.coefficients_names = ["beta_%s"%i for i in range(self.n_features + 1)]
        
        # Add intercept
        X_plus_intercept = np.ones((self.n_features +1, self.n_samples))
        X_plus_intercept[-self.n_features:, :] = X.T.squeeze()

        return X_plus_intercept
        
    def logit_model(self, X_tensor = None, Y_tensor = None):
        """
        Parameters
        
        X_tensor: torch tensor, shape (n_features +1, n_samples) (includes intercept term. Result of prep_data(X))
        Y_tensor: torch tensor, shape (1, n_samples )

        """    
        gamma = np.sqrt(self.C)
        
        beta_prior = dist.Normal(torch.zeros(1, self.n_features + 1), torch.ones(1, self.n_features + 1) * gamma)
        beta = pyro.sample("beta", beta_prior)

        z = torch.mm(beta, X_tensor)
        p = torch.sigmoid(z)

        likelihood = dist.Bernoulli(p)

        pyro.sample("obs", likelihood, obs = Y_tensor)

    def fit(self, X, Y, num_samples = 100, warmup_steps = 50):
        """
        Parameters
        
        X: numpy array, shape (n_samples, n_features)
        Y: numpy array, shape (n_samples, ) or shape(n_samples, 1)

        """
        
        # X gets transformed into shape (n_features +1, n_samples). 
        # The extra dimension comes from the intercept terms, and the transpose allows matrix multiplication later.
        X = self.prep_data(X)
        Y = Y.reshape(-1, 1)
        
        X_tensor = torch.tensor(X, dtype = torch.float)
        Y_tensor = torch.tensor(Y, dtype = torch.float)
        Y_tensor = Y_tensor.transpose(0, 1)

        nuts_kernel = NUTS(self.logit_model, adapt_step_size = True, jit_compile = False, adapt_mass_matrix=False)
        # Keep a reference to the MCMC object so the posterior samples are not discarded
        self.hmc_posterior = MCMC(nuts_kernel, num_samples = num_samples, warmup_steps = warmup_steps, num_chains = 1)
        self.hmc_posterior.run(X_tensor, Y_tensor)
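A short derivation of why `gamma = np.sqrt(self.C)` is the right prior scale, taking the code comment `C = 1/lambda` at face value and assuming the intended penalty is the usual (λ/2)‖β‖². Here d is `n_features`, β_0 is the intercept, and ℓ_i(β) = −log Bernoulli(y_i | σ(βᵀx_i)) is the negative log-likelihood of observation i.

```latex
\begin{aligned}
\beta_j \sim \mathcal{N}(0, \gamma^2)
  &\;\Longrightarrow\; -\log p(\beta) = \frac{1}{2\gamma^2}\sum_{j=0}^{d} \beta_j^2 + \text{const} \\
\gamma^2 = C = \frac{1}{\lambda}
  &\;\Longrightarrow\; -\log p(\beta \mid X, y) = \sum_{i=1}^{n} \ell_i(\beta) + \frac{\lambda}{2}\lVert\beta\rVert_2^2 + \text{const}
\end{aligned}
```

So the MAP estimate is the L2-penalized maximum-likelihood fit with strength λ = 1/C, and γ = √C is a standard deviation, which is what the second argument of `dist.Normal` expects. Note that this prior also shrinks the intercept β_0.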

To load the dataset:

import pandas as pd

df = pd.read_csv("http://web.stanford.edu/class/archive/cs/cs109/cs109.1166/stuff/titanic.csv")
df['sex'] = df.Sex.apply(lambda s: 1 if s=='female' else 0)
df['has siblings or spouses aboard'] = (df['Siblings/Spouses Aboard'] > 0).astype(int)
df['has parents or children aboard'] = (df['Parents/Children Aboard'] > 0).astype(int)
df['has family aboard'] = (df['Siblings/Spouses Aboard'] + df['Parents/Children Aboard'] > 0).astype(int)

df_train = df.sample(frac=0.5, replace=False, random_state=42)
df_test  = df.loc[df.index.difference(df_train.index)]

features = ['Pclass', 'Age', 'Fare', 'Siblings/Spouses Aboard', 
            'Parents/Children Aboard', 'sex']
target = 'Survived'

X_train = df_train[features].values
y_train = df_train[target].values

X_test = df_test[features].values
y_test = df_test[target].values

Running it:

lr_v2 = LogisticRegressionV2()
lr_v2.fit(X_train, y_train, num_samples = 500, warmup_steps = 100)

Is there anything wrong/suboptimal with my implementation? Help would be greatly appreciated!

Hi @NedimB, I think it might be slow because NUTS takes so many leapfrog steps in each trajectory. Could you try the following adjustments and let me know if they help?

  • use likelihood = dist.Bernoulli(logits=z) (working with probs for a Bernoulli distribution is less numerically stable than working with logits)
  • use adapt_mass_matrix=True (to learn the underlying geometry of the posterior; without it, the U-turn termination condition in the No-U-Turn Sampler (NUTS) is less meaningful)
  • enable jit_compile=True to improve the speed
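To see why `logits=z` is the safer choice, here is a small PyTorch-only check. The value `z = 40` is an arbitrary, hypothetical tail value of the kind NUTS can propose early in warmup. It compares the log-likelihood of an observed 0, and its gradient, under both parameterizations. With `probs=torch.sigmoid(z)`, float32 rounds the probability to exactly 1, PyTorch clamps it away from 1 internally, and the gradient with respect to `z` vanishes. With `logits=z`, both the value and the gradient stay exact. A flat gradient like that gives HMC/NUTS no signal about which way to move.

```python
import torch
from torch.distributions import Bernoulli

y = torch.tensor(0.0)  # observed outcome

# Route 1: squash to a probability first, then Bernoulli(probs=...).
z_p = torch.tensor(40.0, requires_grad=True)
lp_probs = Bernoulli(probs=torch.sigmoid(z_p)).log_prob(y)
lp_probs.backward()

# Route 2: pass the linear predictor straight to Bernoulli(logits=...).
z_l = torch.tensor(40.0, requires_grad=True)
lp_logits = Bernoulli(logits=z_l).log_prob(y)
lp_logits.backward()

print(f"probs : log_prob={lp_probs.item():.2f}, grad={z_p.grad.item():.2f}")
print(f"logits: log_prob={lp_logits.item():.2f}, grad={z_l.grad.item():.2f}")
```

The `probs` route reports a clamped log-likelihood of about -15.9 (roughly the log of float32 machine epsilon) with zero gradient. The `logits` route reports -40 with gradient -1.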

Hey @fehiepsi.

Thanks for the tips.

I’m currently on another machine (which is a bit faster), but I have tried your tips, and at least I get more consistent timings now (around 30 seconds). But from memory, this still seems relatively slow (compared to PyMC3). I’ll also have to try it on the slower machine on Monday!

@fehiepsi

I only just now noticed that the sampling during the warm-up phase is ~10x slower. Is that expected?

After the warmup phase it peaks at around 35 iterations per second.

That is entirely possible. During warmup, the algorithm adapts the step size while it is running NUTS, and a small step size means more leapfrog steps (each one a gradient evaluation) per trajectory, so progress is slow.

> But from memory, this still seems relatively slow (compared to PyMC3).

Pyro’s NUTS implementation is unfortunately slower than PyMC3 or Stan for small models, due to certain unavoidable technical restrictions in the current PyTorch JIT. While 30 seconds is not much of an issue, I am quite sure you will see competitive performance if you use NumPyro, the NumPy backend we recently released. You might need to make some small changes to your model and to the inference API calls; these are documented in the repo.

@neerajprad Thanks for your answer!

But do I understand correctly that NumPyro does not work on Windows?

I think the issue there is that you will need to build jaxlib yourself: https://github.com/google/jax/issues/438. NumPyro and JAX are otherwise pure Python libraries that should work on Windows. Check out the last comment on that thread and let me know if that works for you.

I think we should add a disclaimer to our repo about limited support on Windows.

Hey @neerajprad.

Thank you again for your response.

Ah, I see… sadly it does not appear to be possible to do what the last commenter suggests while on Windows 7 (which I am using).

> I only just now noticed that the sampling during the warm-up phase is ~10x slower. Is that expected?

Yes, that can be expected. The adaptation scheme in Pyro and Stan differs from PyMC3’s, so it is quite likely that PyMC3 spends fewer leapfrog steps per trajectory during the warmup phase. 🙂

@fehiepsi @neerajprad

Hey guys.

This is just super strange.

I didn’t change anything in the code, just ran it the next day on the same machine, and suddenly it is 2-3 times slower. Why does this happen?

It is strange. Maybe the data changed, or some packages were updated and introduced a performance regression. Do you get the same results in both runs?

@fehiepsi No, the data has not changed, and neither have any packages. But it has happened multiple times now (as I mentioned earlier, on the other machine it would sometimes take 1.5-2 minutes and sometimes 20 minutes).

@NedimB My last hypothesis: maybe you just get some bad initial values. Try setting a random seed with pyro.set_rng_seed(0) at the beginning of your script.

I actually do that already! It has always been the same seed.

Hi @NedimB, I ran your script (thanks for providing a very easy-to-replicate example) and observed a consistent 24 s with pyro.set_rng_seed(0). I have no other idea about your problem beyond “another” last hypothesis: this could be a tqdm issue. Could you try disabling the progress bar with MCMC(..., disable_progbar=True)?

I also suspect this might be a Windows issue.

@neerajprad Do you have any other ideas on why the running time fluctuates between 2 and 20 minutes?

Hm, disabling the progress bar does speed up some experiments, but I have not yet had the chance to run it on the machine that shows the inconsistency. I will report back when I have. 🙂