Hi there,
Here is the working code for a model and guide for a logistic mixed-effects model. The data and linear predictor eta
are calculated in “vectorized” form, where y is a vector, and X and Z are matrices. I played around with `pyro.plate`
versus `.expand`, which is how I got it working. I am curious whether this is the best approach for dealing with shapes in models like these.
Thanks.
class SixCities:
    """Logistic mixed-effects model with a matching variational guide.

    Structure, as coded in :meth:`model`::

        zeta   ~ Normal(0, 10)
        scale  = exp(-zeta)
        b_j    ~ Normal(0, scale)     one per column of Z
        beta_k ~ Normal(0, 10)        one per column of X
        y_i    ~ Bernoulli(logits=(X @ beta + Z @ b)_i)

    :meth:`guide` places an independent Normal with learnable location and
    positive scale on each of ``zeta``, ``beta`` and ``b``.

    The data are handled in vectorized form: ``y`` is indexed along its
    first dimension, and ``X`` and ``Z`` are 2-D with one row per
    observation and one column per coefficient.
    """

    def __init__(self, data):
        """Store the response and the two design matrices.

        Args:
            data: Mapping with the keys ``'y'``, ``'X'`` and ``'Z'``.
        """
        self.y = data['y']
        self.X = data['X']
        self.Z = data['Z']

    def _data(self, y, X, Z):
        """Return ``(y, X, Z)``, using the stored value for any ``None``."""
        return (
            self.y if y is None else y,
            self.X if X is None else X,
            self.Z if Z is None else Z,
        )

    @staticmethod
    def _linear_predictor(X, Z, beta, b):
        """Compute ``X @ beta + Z @ b`` as a column of shape ``(..., n, 1)``.

        ``unsqueeze(-1)`` turns only the trailing coefficient dimension into
        a column, so leading batch dimensions on ``beta`` / ``b`` survive
        the matmul. ``reshape(-1, 1)`` would instead fold them into the
        coefficient axis. For 1-D ``beta`` / ``b`` the result is identical
        to the ``reshape(-1, 1)`` form.
        """
        return X @ beta.unsqueeze(-1) + Z @ b.unsqueeze(-1)

    def model(self, y, X, Z):
        """Pyro model: priors on ``zeta``, ``b``, ``beta`` and Bernoulli likelihood.

        Args:
            y: Observed responses, or ``None`` to use the data given at
                construction.
            X: Design matrix for ``beta``, or ``None`` to use the stored one.
            Z: Design matrix for ``b``, or ``None`` to use the stored one.

        Records the deterministic sites ``'scale'`` and ``'eta'`` in
        addition to the sample sites.
        """
        y, X, Z = self._data(y, X, Z)

        zeta = pyro.sample('zeta', dist.Normal(0, 10))
        scale = pyro.deterministic('scale', torch.exp(-zeta))

        with pyro.plate('b_plate', Z.shape[1]):
            b = pyro.sample('b', dist.Normal(0, scale))
        with pyro.plate('beta_plate', X.shape[1]):
            beta = pyro.sample('beta', dist.Normal(0, 10))

        eta = pyro.deterministic('eta', self._linear_predictor(X, Z, beta, b))

        with pyro.plate('y_plate', y.shape[0]):
            # Drop the trailing singleton so the logits line up with y_plate.
            pyro.sample('y', dist.Bernoulli(logits=eta.squeeze(-1)), obs=y)

    def guide(self, y, X, Z):
        """Variational guide: independent Normals for ``zeta``, ``beta``, ``b``.

        Takes the same arguments as :meth:`model`. Only the column counts of
        ``X`` and ``Z`` are used, to size the variational parameters.
        """
        _, X, Z = self._data(y, X, Z)
        n_beta = X.shape[1]
        n_b = Z.shape[1]
        positive = dist.constraints.positive

        zeta_loc = pyro.param('zeta_loc', torch.zeros(1))
        zeta_scale = pyro.param('zeta_scale', torch.ones(1), constraint=positive)
        beta_loc = pyro.param('beta_loc', torch.zeros(n_beta))
        beta_scale = pyro.param('beta_scale', torch.ones(n_beta), constraint=positive)
        b_loc = pyro.param('b_loc', torch.zeros(n_b))
        b_scale = pyro.param('b_scale', torch.ones(n_b), constraint=positive)

        pyro.sample('zeta', dist.Normal(zeta_loc, zeta_scale))
        with pyro.plate('beta_plate', n_beta):
            pyro.sample('beta', dist.Normal(beta_loc, beta_scale))
        with pyro.plate('b_plate', n_b):
            pyro.sample('b', dist.Normal(b_loc, b_scale))