Hierarchical model guide with plate

Hi all, I’ve written a model that runs beautifully with NUTS in numpyro, but I’m trying to write a guide to get it running using SVI in pyro and I’ve been struggling. The model estimates scores based on football match results by estimating the attacking and defensive abilities of each team:

def model(home_id, away_id, score1_obs=None, score2_obs=None):
    # priors
    mu_att = pyro.sample("mu_att", dist.Normal(0.0, 1.0))
    sd_att = pyro.sample("sd_att", dist.StudentT(3.0, 0.0, 2.5))
    mu_def = pyro.sample("mu_def", dist.Normal(0.0, 1.0))
    sd_def = pyro.sample("sd_def", dist.StudentT(3.0, 0.0, 2.5))

    home = pyro.sample("home", dist.Normal(0.0, 1.0))  # home advantage

    nt = len(np.unique(home_id))

    # team-specific model parameters
    with pyro.plate("plate_teams", nt):
        attack = pyro.sample("attack", dist.Normal(mu_att, sd_att))
        defend = pyro.sample("defend", dist.Normal(mu_def, sd_def))

    # likelihood
    theta1 = torch.exp(home + attack[home_id] - defend[away_id])
    theta2 = torch.exp(attack[away_id] - defend[home_id])

    with pyro.plate("data", len(home_id)):
        pyro.sample("s1", dist.Poisson(theta1), obs=score1_obs)
        pyro.sample("s2", dist.Poisson(theta2), obs=score2_obs)

I’ve attempted to write the guide as:

def guide(home_id, away_id, score1_obs=None, score2_obs=None):
    # priors
    mu_locs = pyro.param("mu_loc", torch.tensor(0.0).expand(3))
    mu_scales = pyro.param(
        "mu_scale", torch.tensor(0.1).expand(3), constraint=constraints.positive
    )

    sd_dfs = pyro.param(
        "sd_df", torch.tensor(2.0).expand(3), constraint=constraints.positive
    )
    sd_scales = pyro.param(
        "sd_scale", torch.tensor(0.1).expand(3), constraint=constraints.positive
    )

    pyro.sample("mu_att", dist.Normal(mu_locs[0], mu_scales[0]))
    pyro.sample("mu_def", dist.Normal(mu_locs[1], mu_scales[1]))

    pyro.sample("sd_att", dist.StudentT(sd_dfs[0], torch.tensor(0.0), sd_scales[0]))
    pyro.sample("sd_def", dist.StudentT(sd_dfs[1], torch.tensor(0.0), sd_scales[1]))

    pyro.sample("home", dist.Normal(mu_locs[2], mu_scales[2]))  # home advantage

    nt = len(np.unique(home_id))

    mu_team_locs = pyro.param("mu_team_loc", torch.tensor(0.0).expand(2))
    mu_team_scales = pyro.param(
        "mu_team_scale", torch.tensor(0.1).expand(2), constraint=constraints.positive
    )

    # team-specific model parameters
    with pyro.plate("plate_teams", nt):
        pyro.sample("attack", dist.Normal(mu_team_locs[0], mu_team_scales[0]))
        pyro.sample("defend", dist.Normal(mu_team_locs[1], mu_team_scales[1]))

following the rule of writing a pyro.param for everything that isn’t observed. However, I get the error:

ValueError: The parameter scale has invalid values
   Trace Shapes:     
    Param Sites:     
   Sample Sites:     
     mu_att dist    |
           value    |
     sd_att dist    |
           value    |
     mu_def dist    |
           value    |
     sd_def dist    |
           value    |
       home dist    |
           value    |
plate_teams dist    |
           value 20 |

I’ve tried many different shapes for the attack and defend parameters in the plate (nt = 20 here), but with no luck. Any help with this guide is greatly appreciated, especially with the HalfStudentT distribution (do I need a pyro.param() on the scale as well?). If there is an appropriate AutoGuide, that would also be great. Cheers

Full reproducible code here.

I’m confused: StudentT has support on the whole real line, so I’m not sure how you got HMC to work, unless you got lucky and never hit negative values?
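A minimal sketch of what goes wrong, using plain torch.distributions (the seed and sample size are arbitrary choices): roughly half of all StudentT(3, 0, 2.5) draws are negative, and a negative value used as a Normal scale fails argument validation with a ValueError like the one above (the exact message text depends on the torch version).

```python
import torch
from torch.distributions import Normal, StudentT

torch.manual_seed(0)

# Draws from the prior the model puts on the team-level standard deviation
sd_draws = StudentT(3.0, 0.0, 2.5).sample((10_000,))
frac_negative = (sd_draws < 0).float().mean().item()
print(f"fraction of negative 'sd' draws: {frac_negative:.3f}")  # close to 0.5

# Any negative draw is an invalid scale for the team-level Normal
bad_sd = sd_draws[sd_draws < 0][0]
try:
    Normal(0.0, bad_sd, validate_args=True)
    raised = False
except ValueError as err:
    raised = True
    print("ValueError:", err)
```

This is why an unconstrained StudentT guide (or prior) on a scale parameter breaks SVI as soon as a negative value is sampled.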

Shouldn’t mu_team_locs etc. be 20 x 2 dimensional or the like?

Yes, perhaps I got lucky with where the chains started. Does pyro/numpyro have support for the Half-Student-t (and if so, how would you write a guide for it?)

I’ve tried replacing the plate section of the guide with

    nt = len(np.unique(home_id))

    mu_team_locs = pyro.param("mu_team_loc", torch.tensor(0.0).expand((nt, 2)))
    mu_team_scales = pyro.param(
        "mu_team_scale", torch.tensor(0.1).expand((nt, 2)), constraint=constraints.positive
    )

    # team-specific model parameters
    with pyro.plate("plate_teams", nt):
        pyro.sample("attack", dist.Normal(mu_team_locs[:, 0], mu_team_scales[:, 0]))
        pyro.sample("defend", dist.Normal(mu_team_locs[:, 1], mu_team_scales[:, 1]))

and the same error remains. I also find 20x2 a bit confusing, because the model has attack = pyro.sample("attack", dist.Normal(mu_att, sd_att)), i.e., all the attack parameters come from the same Normal distribution with one mean parameter and one sd parameter, not 20.

How about something like this?

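import numpy as np
import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(home_id, away_id, score1_obs=None, score2_obs=None):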
    # priors
    mu_att = pyro.sample("mu_att", dist.Normal(0.0, 1.0))
    sd_att = pyro.sample("sd_att", dist.LogNormal(0.0, 1.0))
    mu_def = pyro.sample("mu_def", dist.Normal(0.0, 1.0))
    sd_def = pyro.sample("sd_def", dist.LogNormal(0.0, 1.0))

    home = pyro.sample("home", dist.Normal(0.0, 1.0))  # home advantage
    nt = len(np.unique(home_id))

    with pyro.plate("plate_teams", nt):
        attack = pyro.sample("attack", dist.Normal(mu_att, sd_att))
        defend = pyro.sample("defend", dist.Normal(mu_def, sd_def))

    theta1 = torch.exp(home + attack[home_id] - defend[away_id])
    theta2 = torch.exp(attack[away_id] - defend[home_id])

    with pyro.plate("data", len(home_id)):
        pyro.sample("s1", dist.Poisson(theta1), obs=score1_obs)
        pyro.sample("s2", dist.Poisson(theta2), obs=score2_obs)

def guide(home_id, away_id, score1_obs=None, score2_obs=None):
    mu_locs = pyro.param("mu_loc", torch.tensor(0.0).expand(5))
    mu_scales = pyro.param(
        "mu_scale", torch.tensor(0.1).expand(5), constraint=constraints.positive
    )

    pyro.sample("mu_att", dist.Normal(mu_locs[0], mu_scales[0]))
    pyro.sample("mu_def", dist.Normal(mu_locs[1], mu_scales[1]))
    pyro.sample("sd_att", dist.LogNormal(mu_locs[2], mu_scales[2]))
    pyro.sample("sd_def", dist.LogNormal(mu_locs[3], mu_scales[3]))
    pyro.sample("home", dist.Normal(mu_locs[4], mu_scales[4]))  # home advantage

    nt = len(np.unique(home_id))

    # one loc/scale per team, sized from nt rather than a hard-coded 20
    mu_team_locs = pyro.param("mu_team_loc", torch.tensor(0.0).expand(2, nt))
    mu_team_scales = pyro.param(
        "mu_team_scale", torch.tensor(0.1).expand(2, nt), constraint=constraints.positive
    )

    with pyro.plate("plate_teams", nt):
        pyro.sample("attack", dist.Normal(mu_team_locs[0], mu_team_scales[0]))
        pyro.sample("defend", dist.Normal(mu_team_locs[1], mu_team_scales[1]))


svi = SVI(model=model, guide=guide, optim=Adam({"lr": 0.005}), loss=Trace_ELBO())
pyro.clear_param_store()  # clear global parameter cache
pyro.set_rng_seed(1)

num_iterations = 1000
advi_loss = []
for j in range(num_iterations):
    loss = svi.step(
        home_id=torch.tensor(train["Home_id"]),
        away_id=torch.tensor(train["Away_id"]),
        score1_obs=torch.tensor(train["score1"]),
        score2_obs=torch.tensor(train["score2"]),
    )
    advi_loss.append(loss)
    if j % 100 == 0:
        print("[iteration %4d] loss: %.4f" % (j + 1, loss))

Thank you so much for the suggestion – it now runs. I’ve set it up as below:

def model(home_id, away_id, score1_obs=None, score2_obs=None):
    # priors
    mu_att = pyro.sample("mu_att", dist.Normal(0.0, 1.0))
    sd_att = pyro.sample("sd_att", dist.StudentT(3.0, 0.0, 2.5))
    mu_def = pyro.sample("mu_def", dist.Normal(0.0, 1.0))
    sd_def = pyro.sample("sd_def", dist.StudentT(3.0, 0.0, 2.5))

    home = pyro.sample("home", dist.Normal(0.0, 1.0))  # home advantage

    nt = len(np.unique(home_id))

    # team-specific model parameters
    with pyro.plate("plate_teams", nt):
        attack = pyro.sample("attack", dist.Normal(mu_att, sd_att))
        defend = pyro.sample("defend", dist.Normal(mu_def, sd_def))

    # likelihood
    theta1 = torch.exp(home + attack[home_id] - defend[away_id])
    theta2 = torch.exp(attack[away_id] - defend[home_id])

    with pyro.plate("data", len(home_id)):
        pyro.sample("s1", dist.Poisson(theta1), obs=score1_obs)
        pyro.sample("s2", dist.Poisson(theta2), obs=score2_obs)


def guide(home_id, away_id, score1_obs=None, score2_obs=None):
    mu_locs = pyro.param("mu_loc", torch.tensor(0.0).expand(5))
    mu_scales = pyro.param(
        "mu_scale", torch.tensor(0.1).expand(5), constraint=constraints.positive
    )

    pyro.sample("mu_att", dist.Normal(mu_locs[0], mu_scales[0]))
    pyro.sample("mu_def", dist.Normal(mu_locs[1], mu_scales[1]))
    pyro.sample("sd_att", dist.LogNormal(mu_locs[2], mu_scales[2]))
    pyro.sample("sd_def", dist.LogNormal(mu_locs[3], mu_scales[3]))
    pyro.sample("home", dist.Normal(mu_locs[4], mu_scales[4]))  # home advantage

    nt = len(np.unique(home_id))

    mu_team_locs = pyro.param("mu_team_loc", torch.tensor(0.0).expand(2)) 
    mu_team_scales = pyro.param(
        "mu_team_scale",
        torch.tensor(0.1).expand(2),
        constraint=constraints.positive,
    )

    with pyro.plate("plate_teams", nt):
        pyro.sample("attack", dist.Normal(mu_team_locs[0], mu_team_scales[0]))
        pyro.sample("defend", dist.Normal(mu_team_locs[1], mu_team_scales[1]))

Two things:

  • Here, I’ve used mu_team_locs = pyro.param("mu_team_loc", torch.tensor(0.0).expand(2)) rather than mu_team_locs = pyro.param("mu_team_loc", torch.tensor(0.0).expand(2, nt)). I’m confused why they both ran. In my head, the attack nodes come from the same normal – attack = pyro.sample("attack", dist.Normal(mu_att, sd_att)) – and hence should only need one pyro.param, not nt (20). Do you know why they both ran and which is correct?
  • Here, using a StudentT prior with a LogNormal guide prevents negative values. If possible, I would like to use a HalfStudentT instead, because that is how I’ve set up the model in other PPLs and I’d like to compare them directly. Does Pyro support this? I found this post #274 (referencing the Stan post where I also chose the HalfStudentT prior) and was wondering whether this has been developed since.

Cheers

I believe you can get a HalfStudentT using the TransformedDistribution(base_dist, AbsTransform()) pattern.

Both variational assumptions are valid; one is just more flexible. The more flexible one (a separate loc and scale for each team) is what you want, because the posterior means etc. of the different teams’ attacks will all differ unless some accidental feature of the data produces a tie. Both versions run because a scalar loc and scale are broadcast across the nt-sized plate, which gives every team the same variational distribution.

Sorry if I have implemented this naively, but I simply changed pyro.sample("sd_def", dist.StudentT(3.0, 0.0, 2.5)) to pyro.sample("sd_def", dist.TransformedDistribution(dist.StudentT(3.0, 0.0, 2.5), dist.transforms.AbsTransform())), which produced the expected positive values when called on its own. However, when I ran the following with SVI:

def model(home_id, away_id, score1_obs=None, score2_obs=None):
    # priors
    mu_att = pyro.sample("mu_att", dist.Normal(0.0, 1.0))
    sd_att = pyro.sample(
        "sd_att",
        dist.TransformedDistribution(
            dist.StudentT(3.0, 0.0, 2.5), dist.transforms.AbsTransform()
        ),
    )
    mu_def = pyro.sample("mu_def", dist.Normal(0.0, 1.0))
    sd_def = pyro.sample(
        "sd_def",
        dist.TransformedDistribution(
            dist.StudentT(3.0, 0.0, 2.5), dist.transforms.AbsTransform()
        ),
    )

    home = pyro.sample("home", dist.Normal(0.0, 1.0))  # home advantage

    nt = len(np.unique(home_id))

    # team-specific model parameters
    with pyro.plate("plate_teams", nt):
        attack = pyro.sample("attack", dist.Normal(mu_att, sd_att))
        defend = pyro.sample("defend", dist.Normal(mu_def, sd_def))

    # likelihood
    theta1 = torch.exp(home + attack[home_id] - defend[away_id])
    theta2 = torch.exp(attack[away_id] - defend[home_id])

    with pyro.plate("data", len(home_id)):
        pyro.sample("s1", dist.Poisson(theta1), obs=score1_obs)
        pyro.sample("s2", dist.Poisson(theta2), obs=score2_obs)


def guide(home_id, away_id, score1_obs=None, score2_obs=None):
    mu_locs = pyro.param("mu_loc", torch.tensor(0.0).expand(5))
    mu_scales = pyro.param(
        "mu_scale", torch.tensor(0.1).expand(5), constraint=constraints.positive
    )

    pyro.sample("mu_att", dist.Normal(mu_locs[0], mu_scales[0]))
    pyro.sample("mu_def", dist.Normal(mu_locs[1], mu_scales[1]))
    pyro.sample("sd_att", dist.LogNormal(mu_locs[2], mu_scales[2]))
    pyro.sample("sd_def", dist.LogNormal(mu_locs[3], mu_scales[3]))
    pyro.sample("home", dist.Normal(mu_locs[4], mu_scales[4]))  # home advantage

    nt = len(np.unique(home_id))

    mu_team_locs = pyro.param("mu_team_loc", torch.tensor(0.0).expand(2))  # , nt))
    mu_team_scales = pyro.param(
        "mu_team_scale",
        torch.tensor(0.1).expand(2),  # , nt),
        constraint=constraints.positive,
    )

    with pyro.plate("plate_teams", nt):
        pyro.sample("attack", dist.Normal(mu_team_locs[0], mu_team_scales[0]))
        pyro.sample("defend", dist.Normal(mu_team_locs[1], mu_team_scales[1]))


svi = SVI(model=model, guide=guide, optim=Adam({"lr": 0.001}), loss=Trace_ELBO())

pyro.clear_param_store()  # clear global parameter cache
pyro.set_rng_seed(1)

num_iterations = 1000
advi_loss = []
for j in range(num_iterations):
    # calculate the loss and take a gradient step
    loss = svi.step(
        home_id=torch.tensor(train["Home_id"]),
        away_id=torch.tensor(train["Away_id"]),
        score1_obs=torch.tensor(train["score1"]),
        score2_obs=torch.tensor(train["score2"]),
    )
    advi_loss.append(loss)
    if j % 100 == 0:
        print("[iteration %4d] loss: %.4f" % (j + 1, loss))

I got the error:

# grab a trace from the generator
...
model_trace, guide_trace = get_importance_trace(
...
    214                 if "log_prob" not in site:
    215                     try:
--> 216                         log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
    217                     except ValueError as e:
    218                         _, exc_value, traceback = sys.exc_info()

~/opt/anaconda3/envs/bayesian/lib/python3.8/site-packages/torch/distributions/transformed_distribution.py in log_prob(self, value)
    143             x = transform.inv(y)
    144             event_dim += transform.domain.event_dim - transform.codomain.event_dim
--> 145             log_prob = log_prob - _sum_rightmost(transform.log_abs_det_jacobian(x, y),
    146                                                  event_dim - transform.domain.event_dim)
    147             y = x

~/opt/anaconda3/envs/bayesian/lib/python3.8/site-packages/torch/distributions/transforms.py in log_abs_det_jacobian(self, x, y)
    179         Computes the log det jacobian `log |dy/dx|` given input and output.
    180         """
--> 181         raise NotImplementedError
    182 
    183     def __repr__(self):

NotImplementedError: 

Any idea if this is fixable?

Update: this also fails in numpyro when I update the model with the HalfStudentT prior above (rather than StudentT), so I must have implemented it incorrectly. Any help is much appreciated.

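I think you probably need to define a custom subclass of AbsTransform that includes the following (AbsTransform itself does not implement log_abs_det_jacobian, hence the NotImplementedError):

def log_abs_det_jacobian(self, x, y):
    return torch.zeros_like(x)

That should presumably work for your purposes.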


Great, thank you. For future readers (unless HalfStudentT exists by now), in Pyro:

import torch
import pyro
import pyro.distributions as dist

class cAbsTransform(dist.transforms.AbsTransform):
    def log_abs_det_jacobian(self, x, y):
        return torch.zeros_like(x)

pyro.sample("a", dist.TransformedDistribution(dist.StudentT(3.0, 0.0, 2.5), cAbsTransform()))

and for numpyro

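import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

class cAbsTransform(dist.transforms.AbsTransform):
    def log_abs_det_jacobian(self, x, y, intermediates=None):
        return jnp.zeros_like(x)

# rng_key must be passed by keyword: the third positional argument of numpyro.sample is obs
numpyro.sample("a", dist.TransformedDistribution(dist.StudentT(3.0, 0.0, 2.5), cAbsTransform()), rng_key=rng_key)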

I have one last question. I need two things from the model:

  1. Estimates for the attack and defend parameters from each team (in MCMC numpyro, I simply run fit = mcmc.get_samples() and get fit["attack"] and fit["defend"])
  2. Predictions for future games, given the teams (in numpyro, I would do predictive = Predictive(model, fit, return_sites=["s1", "s2"]) and predicted_score = predictive(random.PRNGKey(0), home_id=predict["Home_id"].values, away_id=predict["Away_id"].values))

In SVI for pyro, this is not obvious. I can get the guide parameters using pyro.get_param_store().items(), but when I try Predictive(model, guide, num_samples=2000) I get 'function' object has no attribute 'items', presumably referring to the guide. How am I able to implement 1 and 2 for SVI in pyro?

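Just a quick note: with a zero log-Jacobian the log probabilities are unnormalized (the density only integrates to 1/2 over the positive half-line). Unless you are happy to work with that, it is better to use FoldedDistribution or TruncatedDistribution (not available yet in Pyro).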
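To make “unnormalized” concrete, a stdlib-only numerical check, assuming the prior used in this thread (df = 3, scale = 2.5): with a zero log-Jacobian, StudentT + AbsTransform evaluates the plain Student-t density on y > 0, which integrates to 1/2, whereas a proper half-Student-t (folded) density is twice that. The integration helper is a hypothetical illustration, not a library function.

```python
import math


def student_t_pdf(y, df, loc, scale):
    """Density of a location-scale Student-t distribution."""
    z = (y - loc) / scale
    norm = math.gamma((df + 1) / 2) / (
        math.sqrt(df * math.pi) * math.gamma(df / 2) * scale
    )
    return norm * (1 + z * z / df) ** (-(df + 1) / 2)


def zero_jacobian_pdf(y, df=3.0, scale=2.5):
    """What StudentT + AbsTransform with a zero log-Jacobian evaluates for y > 0."""
    return student_t_pdf(y, df, 0.0, scale)


def half_student_t_pdf(y, df=3.0, scale=2.5):
    """Properly normalized half-Student-t: twice the Student-t density for y >= 0."""
    return 2.0 * student_t_pdf(y, df, 0.0, scale)


def integrate_positive(f, scale=2.5, n=2000):
    """Integrate f over [0, inf) via y = scale * tan(u) and Simpson's rule in u (n even)."""
    h = (math.pi / 2) / n

    def g(u):
        if u >= math.pi / 2:  # the transformed integrand tends to 0 here for df > 1
            return 0.0
        return f(scale * math.tan(u)) * scale / math.cos(u) ** 2

    total = g(0.0) + g(math.pi / 2)
    for i in range(1, n):
        total += (4 if i % 2 else 2) * g(i * h)
    return total * h / 3


mass_zero_jac = integrate_positive(zero_jacobian_pdf)
mass_half_t = integrate_positive(half_student_t_pdf)
print(f"zero-Jacobian version integrates to {mass_zero_jac:.6f}")
print(f"half-Student-t integrates to {mass_half_t:.6f}")
print(f"constant log-density gap: {math.log(mass_half_t / mass_zero_jac):.6f}")  # log 2
```

Because the gap is a constant log 2 per site, it does not change gradients in SVI or the shape of an MCMC posterior; it only shifts reported log densities and ELBO values.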


Thanks @fehiepsi, the full numpyro code is now here for future users.

If anyone has insight into points 1 and 2 above to help me finish the Pyro SVI equivalent, that would be much appreciated.

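posterior_fn = Predictive(guide, num_samples=2000)
posterior = posterior_fn(**model_kwargs)

Then 2 should work much as you’ve written it: Predictive(model, posterior_samples=posterior, return_sites=["s1", "s2"]), called with the new home_id/away_id and without a PRNG key, since Pyro’s Predictive doesn’t take one. Your original Predictive(model, guide, num_samples=2000) failed because the second positional argument of Pyro’s Predictive is posterior_samples; pass the guide as guide=guide instead.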