Semisupervised VAE example not training

I have tried to use the code for the SS-VAE that I found in the pyro repo but I was not able to train the network on the CPU (on a Macbook Pro 2019 Catalina).
The weights of the encoder network are not updating at all.

This is my mini-example on the digit dataset from sklearn

import numpy as np
import torch
import pyro
from sklearn.datasets import make_blobs, load_digits
from pyro.optim import Adam
from pyro.infer import JitTraceEnum_ELBO, TraceEnum_ELBO, SVI
from pyro.infer import config_enumerate
import torch.nn as nn

from ss_vae_M2 import SSVAE # from the examples/vae/ folder

# Load the digits dataset from sklearn
one_hot = lambda x: nn.functional.one_hot(x, num_classes=10)
dict_digits = load_digits()
dict_digits.keys()
x = dict_digits["data"]
x = x - x.mean(1)[:,None]
x = x / x.std(1)[:, None]
x = (x > 0.2).astype(float)
x = torch.tensor(x, dtype=torch.float )
y = torch.tensor(dict_digits["target"])

# Set parameters
learning_rate = 0.09
beta1 = 0.9
jit = False
enum_discrete = "parallel"
num_epochs = 10
cuda = False
batch_size = 100
sup_num = 200

# Prepare objects and training
# setup the optimizer
pyro.clear_param_store()
ss_vae = SSVAE(output_size=10, input_size=64, z_dim=2, hidden_layers=[10,])

adam_params = {"lr": learning_rate, "betas": (beta1, 0.999)}
optimizer = Adam(adam_params)

# set up the loss(es) for inference. wrapping the guide in config_enumerate builds the loss as a sum
# by enumerating each class label for the sampled discrete categorical distribution in the model
guide = config_enumerate(ss_vae.guide, enum_discrete, expand=True)
Elbo = JitTraceEnum_ELBO if jit else TraceEnum_ELBO
elbo = Elbo(max_plate_nesting=1, strict_enumeration_warning=False)
loss_basic = SVI(ss_vae.model, guide, optimizer, loss=elbo)

loss_aux = SVI(
            ss_vae.model_classify, ss_vae.guide_classify, optimizer, loss=elbo
        )

# Train, and check if anything is happened with training

saved_encoder_out = ss_vae.encoder_y(x)
layer_to_check = list(ss_vae.encoder_y.parameters())[0]

for i in range(10000):
    ix = torch.randint(low=0, high=x.shape[0], size=(500,))
    xs = x[ix,:]
    ys = one_hot(y[ix])
    loss_aux.step(xs)
    loss_basic.step(xs, ys)
    
trained_encoder_out = ss_vae.encoder_y(x)
trained_layer = list(ss_vae.encoder_y.parameters())[0]

print(torch.allclose(trained_encoder_out, saved_encoder_out))
print(torch.allclose(trained_layer, layer_to_check))