Evaluation of Generative Model: Predictive Likelihood and Point Predictions

First of all, thank you for developing this great library and for the strong community support! I built a generative latent variable model:

    import torch
    import pyro
    import pyro.distributions as dist
    import pyro.distributions.transforms as transforms
    from pyro.infer import config_enumerate
    from pyro.ops.indexing import Vindex

    @config_enumerate
    def model(data=None):
        # self.n_c is the number of latent classes (the model is defined as a method of a model class)

        # ordered per-class means for the Gaussian site G
        base_G_c = dist.Normal(torch.ones(self.n_c), torch.ones(self.n_c))
        lam_G_c = pyro.sample("lam_G_c", dist.TransformedDistribution(base_G_c, [transforms.OrderedTransform()]))

        # ordered per-class rates for the Poisson site Q
        base_Q_c = dist.Gamma(torch.ones(self.n_c), torch.ones(self.n_c))
        lam_Q_c = pyro.sample("lam_Q_c", dist.TransformedDistribution(base_Q_c, [transforms.OrderedTransform()]))

        # ordered per-class success probabilities for the Binomial site T, squashed to (0, 1)
        base_T_c = dist.Normal(torch.ones(self.n_c) * (-1), torch.ones(self.n_c) * 1)
        lam_T_c = pyro.sample("lam_T_c", dist.TransformedDistribution(base_T_c, [transforms.OrderedTransform(), transforms.SigmoidTransform()]))

        # mixture weights over the latent classes
        pi_Z_c = pyro.sample("pi_Z_c", dist.Dirichlet(torch.ones(self.n_c) / self.n_c))

        with pyro.plate('data_plate', data["mask"]["Q"].shape[0]):

            # per-datapoint class assignment, enumerated out during inference
            Z = pyro.sample('Z', dist.Categorical(pi_Z_c), infer={"enumerate": "parallel"})

            # observed sites; .mask() switches off the likelihood where data are missing
            G = pyro.sample('G', dist.Normal(Vindex(lam_G_c)[Z], torch.ones(1)).mask(data["mask"]["G"]), obs=data["data"]["G"])
            Q = pyro.sample('Q', dist.Poisson(Vindex(lam_Q_c)[Z]).mask(data["mask"]["Q"]), obs=data["data"]["Q"])
            T = pyro.sample('T', dist.Binomial(probs=Vindex(lam_T_c)[Z], total_count=data["n_k"]["T"] - 1).mask(data["mask"]["T"]), obs=data["data"]["T"])
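
For reference, the `data` argument is a nested dict along these lines (a sketch of my layout; the concrete shapes and the value of n_k below are placeholders):

    import torch

    N = 100  # number of data points (placeholder)
    data = {
        "data": {"G": torch.randn(N), "Q": torch.zeros(N), "T": torch.zeros(N)},   # observed values per site
        "mask": {"G": torch.ones(N).bool(), "Q": torch.ones(N).bool(), "T": torch.ones(N).bool()},  # True = observed
        "n_k":  {"T": 5},   # e.g. 5 levels for T, so total_count = n_k["T"] - 1
    }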

The model works fine and I can infer the posterior params via MCMC. Now I am trying to evaluate the model based on (A) its predictive likelihood and (B) its point predictions on held-out data for “G”, “Q” and “T”.

(A) Predictive Likelihood_____________

Since the Predictive class does not return traces anymore, and Predictive.get_vectorized_trace throws the error “Shape mismatch inside plate(‘data_plate’) at site Z dim -1” (any ideas how to resolve this?), I am resorting to

conditioned_model = poutine.condition(model, mcmc_posterior_params)
trace = poutine.trace(conditioned_model).get_trace(data)
trace.compute_log_prob()

where mcmc_posterior_params is, e.g., the mean of the MCMC posterior samples (I am aware of this discussion). However, I am worried that this is not doing what I want: I would like to obtain the log_prob for each observed site while conditioning on the other sites, e.g. log p(G* | Q*, T*, thetas) in the generative model. Do replay and block, or infer_discrete, help here?
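
Concretely, what I read off that trace is roughly the following (a sketch, reusing trace from the snippet above):

    # per-site log-probabilities populated by trace.compute_log_prob()
    for site in ["G", "Q", "T"]:
        print(site, trace.nodes[site]["log_prob"].sum().item())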

(B) Point Predictions_____________

Inspired by several tutorials, I would like to evaluate the generative model by predicting one observed site, e.g. G = None, while passing only the values of the observed sites “T” and “Q”. This can be seen as an imputation task, so I set impute_data["G"]["data"] = None and impute_data["G"]["mask"] = [False, False, ...] while keeping the true values in Q and T, e.g. impute_data["Q"]["data"] = [0, 0, 0, 1, ...]. Next, I am doing

Predictive(model, mcmc_posterior_params)(impute_data)

and take the mean or mode over the predicted values per site and sample (MAP-style). At best, this results in predictions that merely match a majority-vote baseline. Do you have any insights on how to get this right? I assume I somehow need to condition on the observed sites and params again.
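
In code, the point prediction step looks roughly like this (a sketch; impute_data is the dict described above and G is the site being imputed):

    from pyro.infer import Predictive

    samples = Predictive(model, mcmc_posterior_params)(impute_data)
    g_hat = samples["G"].mean(dim=0)   # mean over posterior-predictive samples as the point prediction
    # (for the discrete sites Q and T I take the mode instead of the mean)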

Since these discussions and tutorials like the Baseball one did not help me, I would very much appreciate your help! My case is slightly different from discriminative settings such as logistic regression: in a generative model, the data features “X” (“at_bats” in the tutorial) are generated, observed sample sites and prediction targets at the same time.

It seems that your likelihoods are independent of each other (conditioned on the latent values), so if you already know the thetas, you can calculate the log likelihood of G* directly (unless you don’t trust the theta samples and want to run inference again to find p(thetas | Q*, T*, G, Q, T), then use those new thetas to get the log likelihood of G*).

It seems that Predictive does not support infer_discrete yet. You will need to use infer_discrete to get samples for Z, then merge them with the theta samples (obtained with NUTS). Then I guess you can use Predictive (making sure to remove config_enumerate and set infer={} at the Z site).


Thank you so much for your help @fehiepsi!

Would you mind sharing short code snippets to illustrate your suggestions?

I am aware that G is conditionally independent of Q and T given Z. However, I am looking for a way to set Z given my inferred thetas as well as Q and T, and then predict G. Such a model presumably performs better than a baseline without the observed sites Q and T, where G is conditioned only on Z. We need the knowledge of at least one other observed variable to predict data points of G. However, I am not sure whether rerunning inference with the thetas as well as Q* and T* is the way to go. Would you agree?

I think you can use infer_discrete to get Z values: there are a couple of examples on the tutorials page: Search — Pyro Tutorials 1.8.4 documentation

I think the code in the predict function of the epidemiology example is closest to your problem. Does infer_discrete work for your model?


Thanks for bearing with me @fehiepsi!

For instance, the following code is working for me and improves the log_prob_sum over the version without infer_discrete:

conditioned_model = poutine.condition(model, mcmc_posterior_params)
infer_discrete_conditioned_model = infer_discrete(conditioned_model, first_available_dim=-2, temperature=1)
trace = poutine.trace(infer_discrete_conditioned_model).get_trace(heldout_data)

Still, I am not sure whether this is doing the right thing. Nothing changes when I add the inferred Z sites to mcmc_posterior_params in the conditioning.

I think it is doing the right thing with the heldout_data that you described previously (which is pretty neat to me):
impute_data["G"]["data"] = None and impute_data["G"]["mask"] = False (or an array of False)
Just to be sure, you can build a toy example for a model (with 1 discrete latent and two simple likelihoods) to test the masking logic :wink:

It is better to print out trace.format_shapes() to see if we are getting the right shapes for all the sites. Sometimes broadcasting can do the wrong thing if the model is not written correctly.
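
For example, with the trace from your snippet above:

    # `trace` is the result of poutine.trace(...).get_trace(heldout_data)
    print(trace.format_shapes())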


@Nik Looking at the epidemiology example again, it seems that you’ll need to add a particle plate to the conditioned_model:

particle_plate = pyro.plate("particles", num_samples, dim=-2)

and set first_available_dim=-3. Anyway, it is best to use format_shapes to make sure that you are getting the right trace. MCMC samples might not have the desired shapes for “vectorized” conditioning, e.g. lam_G_c might need to have shape (num_samples, 1) rather than (num_samples,), so you might need to reshape them.
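
A sketch of the kind of reshaping I mean (this assumes all conditioned sites are global, so a singleton plate dim is inserted right after the sample dim):

    # (num_samples,) -> (num_samples, 1), and (num_samples, n_c) -> (num_samples, 1, n_c)
    params = {k: v.unsqueeze(1) for k, v in mcmc.get_samples().items()}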


Thank you very much for following up @fehiepsi! Unfortunately, I still could not get a toy model to run as I expect it to.

The idea of wrapping the model in another plate to vectorise over the posterior samples is actually also what Predictive does internally:
vectorize = pyro.plate("_num_predictive_samples", num_samples, dim=-max_plate_nesting-1)

Unfortunately, I do not manage to find the right shapes: simply reshaping the latents, e.g. lam_G_c, into (num_samples, 1) does not work. Would you mind providing a short code snippet, please? :cry:

@Nik I am not sure what issue you ran into. Here is a simple piece of code:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import config_enumerate, infer_discrete

@config_enumerate
def model():
    x = pyro.sample("x", dist.Normal(0, 1))
    with pyro.plate("N", 10):
        c = pyro.sample("c", dist.Categorical(logits=torch.ones(3)))
        y = pyro.sample("y", dist.Normal(c, 1), obs=torch.ones(10))
    return x, c, y

particles = pyro.plate("M", 20, dim=-2)
x, c, y = pyro.poutine.trace(particles(model))()
print(x.shape, c.shape, y.shape)
conditioned = pyro.condition(particles(model), data={"x": x})
x, c, y = infer_discrete(conditioned, first_available_dim=-3)()
print(c.shape)

Note that the shape of x here is (20, 1), but if you get it from MCMC, it will have shape (20,). That’s why you need to reshape.

Shouldn’t you need to reshape that site to (num_samples, 1, n_c)? I think you should also use Vindex(lam_foo)[..., z]. Again, I’m not sure what problem you are running into. It would be nice if you could provide some reproducible code that’s not working.
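
To illustrate those two points (a sketch, keeping the placeholder name lam_foo):

    from pyro.ops.indexing import Vindex

    # insert a singleton particle dim between the sample dim and the event dim
    lam_foo = mcmc.get_samples()["lam_foo"].unsqueeze(1)   # (num_samples, n_c) -> (num_samples, 1, n_c)
    loc = Vindex(lam_foo)[..., z]                          # broadcasts against z, e.g. of shape (num_samples, N)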

Hi @fehiepsi,

Again, thank you so much for your perseverance! I have now incorporated the (num_samples, 1, n_c) reshaping as well as the particle_plate.

Further, I have now prepared a fully self-contained Jupyter notebook which outlines the task and problems I am facing:
https://github.com/niklasstoehr/pyro_model/blob/master/Generative%20Discrete%20Latent%20Variable%20Model.ipynb

First, I generate data using the gq_model and some known true_params. The generated data G and Q are thus correlated, which is important since we later want to infer one from the other.

The MCMC summary shows that the correct posterior params are inferred! Next, I compute the predictive likelihood. To do so, I first infer P(Z | Q, T, params) by setting G to None. Then I infer P(G | Z, params) and compute the exponentiated log likelihood as well as point predictions via MSE and F1.

I compare the results against g_model and q_model, which only have access to either G or Q. Everything goes to plan if these models perform worse in the “imputation” task.

You can try using the true params (expanded to the shape of the posterior samples) to verify the conditional logic first:

    params = mcmc.get_samples()
    params = {k: true_params[k].expand(v.shape) for k, v in params.items()}

It could be that the posterior samples are not good. You can try increasing the number of MCMC warmup/samples to (500, 500). If things still do not work, maybe there is a problem with the approach…
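
That is, something like this (a sketch; the kernel choice and the data name are placeholders):

    from pyro.infer import MCMC, NUTS

    mcmc = MCMC(NUTS(model), num_samples=500, warmup_steps=500)
    mcmc.run(train_data)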


@fehiepsi and I have concluded that this problem seems to be resolved! Please feel free to reach out should you have any questions!

import copy
import torch
import pyro
import pyro.poutine as poutine
from pyro.infer import infer_discrete
from sklearn.metrics import mean_squared_error, f1_score


def impute_data(data, sites=[]):

    ### returns a copy of the data with the given sites removed, so that they get imputed
    impute_data = copy.deepcopy(data)
    for k in sites:
        impute_data['data'][k] = None ## remove the observed values at this site
        impute_data['mask'][k] = torch.zeros(impute_data["mask"][k].shape[0]).bool() ## set the data mask to False
    return impute_data


def evaluate_point_predictions(pred_sites, gt_data):

    ### compares point predictions (mean over posterior samples) against the ground-truth data
    for k in gt_data["data"].keys():
        if k in pred_sites.keys():

            hat_data = torch.mean(pred_sites[k], dim = 0) ## average predicted values over posterior samples

            mse = mean_squared_error(gt_data["data"][k].type(torch.float), hat_data.type(torch.float))
            f1 = f1_score(gt_data["data"][k].type(torch.int), hat_data.type(torch.int), average='weighted')
            print(f"{str(k)}: mse {mse}, weighted f1 {f1}")


def compute_exp_pred_lik(post_loglik):

    ### pointwise predictive density: mean over posterior samples, then geometric mean over data points
    sample_mean_exp_n = torch.mean(torch.exp(post_loglik), 0)
    exp_log_lik = torch.exp(torch.mean(torch.log(sample_mean_exp_n), axis=0))
    #exp_log_density[k] = (post_loglik[k].logsumexp(0) - math.log(post_loglik[k].shape[0])).sum().item() ## alternative: lppd via logsumexp
    return exp_log_lik.item()



def evaluate_pred_lik(mcmc, model, data, sites = ["G", "Q"]):

    ### computes the predictive likelihood of each held-out site given the remaining observed sites
    params = mcmc.get_samples()    
    num_samples = list(params.values())[0].shape[0] 
    sample_plate = pyro.plate("samples", num_samples, dim=-2)
    
    pred_sites = dict()
    
    for site in sites: ## loop through observed sites

        ## infer P(Z | remaining observed sites, params)______
        infer_z_model = poutine.condition(model, params)
        infer_z_model = sample_plate(infer_z_model)
        infer_z_model = infer_discrete(infer_z_model, first_available_dim=-3, temperature=1)
        
        imputed_data = impute_data(data, sites=[site]) ## mask out the site to be predicted
        impute_trace = poutine.trace(infer_z_model).get_trace(imputed_data)
        Z = impute_trace.nodes["Z"]["value"]
        Z_params = {"Z": Z, **params}

        ## infer P(site | Z, params)______
        infer_site_model = poutine.condition(model, Z_params)
        infer_site_model = sample_plate(infer_site_model)
        #infer_site_model = infer_discrete(infer_site_model, first_available_dim=-3, temperature=1)
        trace = poutine.trace(infer_site_model).get_trace(data) ## evaluate on the held-out data passed in
        trace.compute_log_prob()
        
        exp_pred_lik = compute_exp_pred_lik(trace.nodes[site]["log_prob"])
        print(f"{site}: exp_pred_lik {exp_pred_lik}")
            
        pred_sites[site] = impute_trace.nodes[site]["value"]
        
    return pred_sites
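
For completeness, a hypothetical way to call these helpers (assuming mcmc, model and a held-out test_data dict as in the notebook):

    pred_sites = evaluate_pred_lik(mcmc, model, test_data, sites=["G", "Q"])
    evaluate_point_predictions(pred_sites, test_data)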