Is there a cleaner way to write the model and guide for a GLMM?

Hi there,

Here is working code for the model and guide of a logistic mixed-effects model. The data and the linear predictor `eta` are computed in vectorized form: `y` is a vector, and `X` and `Z` are matrices. I got it working by experimenting with `pyro.plate` versus `.expand`. I am curious whether this is the best approach for handling shapes in models like these.


class SixCities:
    """Logistic mixed-effects model (GLMM) with a mean-field Normal guide.

    Generative model:
        zeta ~ Normal(0, 10)
        b_j  ~ Normal(0, exp(-zeta))      for each column j of Z (random effects)
        beta_k ~ Normal(0, 10)            for each column k of X (fixed effects)
        eta  = X @ beta + Z @ b
        y_i  ~ Bernoulli(logits=eta_i)    for each row i

    The guide places an independent Normal on ``zeta`` and on every element
    of ``beta`` and ``b``, each with its own learnable loc and positive scale.
    """

    def __init__(self, data):
        """Store the default data used when ``model``/``guide`` get no arrays.

        Args:
            data: mapping with keys ``'y'`` (responses), ``'X'`` (fixed-effects
                design matrix) and ``'Z'`` (random-effects design matrix).
        """
        self.y = data['y']
        self.X = data['X']
        self.Z = data['Z']

    def _resolve_data(self, y=None, X=None, Z=None):
        """Return ``(y, X, Z)``, substituting the stored data for any ``None``."""
        return (
            self.y if y is None else y,
            self.X if X is None else X,
            self.Z if Z is None else Z,
        )

    @staticmethod
    def _linear_predictor(X, Z, beta, b):
        """Compute ``eta = X @ beta + Z @ b``.

        With ``beta`` and ``b`` as 1-D vectors, matrix-vector ``@`` already
        yields a 1-D vector with one entry per row of ``X``, so no
        reshape/flatten round trip is needed.
        """
        return X @ beta + Z @ b

    def model(self, y=None, X=None, Z=None):
        """Pyro model: priors on ``zeta``, ``b``, ``beta`` and a Bernoulli likelihood.

        Args:
            y: observed responses; defaults to the stored ``self.y``.
            X: fixed-effects design matrix; defaults to ``self.X``.
            Z: random-effects design matrix; defaults to ``self.Z``.
        """
        # Previously these arguments were ignored in favour of self.*; honour them now.
        y, X, Z = self._resolve_data(y, X, Z)
        zeta = pyro.sample('zeta', dist.Normal(0., 10.))
        # Random-effect scale is parameterised on the log scale via zeta.
        scale = pyro.deterministic('scale', torch.exp(-zeta))
        with pyro.plate('b_plate', Z.shape[1]):
            b = pyro.sample('b', dist.Normal(0., scale))
        with pyro.plate('beta_plate', X.shape[1]):
            beta = pyro.sample('beta', dist.Normal(0., 10.))
        eta = pyro.deterministic('eta', self._linear_predictor(X, Z, beta, b))
        with pyro.plate('y_plate', y.shape[0]):
            pyro.sample('y', dist.Bernoulli(logits=eta), obs=y)

    def guide(self, y=None, X=None, Z=None):
        """Mean-field Normal guide matching the sample sites of ``model``.

        Args:
            y: unused by the guide; accepted so the signature matches ``model``.
            X: fixed-effects design matrix (sets the size of ``beta``);
                defaults to ``self.X``.
            Z: random-effects design matrix (sets the size of ``b``);
                defaults to ``self.Z``.
        """
        _, X, Z = self._resolve_data(y, X, Z)
        n_fixed, n_random = X.shape[1], Z.shape[1]
        positive = dist.constraints.positive
        # Scalar params so the guide's zeta site has the same shape as the
        # model's scalar Normal(0, 10) prior.
        zeta_loc = pyro.param('zeta_loc', torch.tensor(0.))
        zeta_scale = pyro.param('zeta_scale', torch.tensor(1.), constraint=positive)
        beta_loc = pyro.param('beta_loc', torch.zeros(n_fixed))
        beta_scale = pyro.param('beta_scale', torch.ones(n_fixed), constraint=positive)
        b_loc = pyro.param('b_loc', torch.zeros(n_random))
        b_scale = pyro.param('b_scale', torch.ones(n_random), constraint=positive)
        pyro.sample('zeta', dist.Normal(zeta_loc, zeta_scale))
        with pyro.plate('beta_plate', n_fixed):
            pyro.sample('beta', dist.Normal(beta_loc, beta_scale))
        with pyro.plate('b_plate', n_random):
            pyro.sample('b', dist.Normal(b_loc, b_scale))