Fitted values do not respect the specified constraints in the guide

I have a model, which I define as follows:

# Imports used throughout this post
import pandas as pd
import torch
import pyro
import pyro.distributions as dist
from pyro import optim
from pyro.infer import SVI, Trace_ELBO
from torch.distributions import constraints

def model(data: pd.DataFrame):
    prior_gene = pyro.sample("prior_gene", 
                             dist.Dirichlet(concentration=torch.ones(9) / 9.))
    prior_group = pyro.sample("prior_group", 
                              dist.Dirichlet(concentration=torch.ones(7) / 7.))
    prior_omim = pyro.sample("prior_omim", 
                             dist.Dirichlet(concentration=torch.ones(9) / 9.))
    
    prior_mean_pheno_x = pyro.sample('prior_mean_pheno_x',
                                     dist.Normal(loc=torch.zeros([9, 7, 9]), 
                                                 scale=torch.ones([9, 7, 9]) * 10).independent(3))
    
    prior_std_pheno_x = pyro.sample('prior_std_pheno_x',
                                    dist.HalfCauchy(torch.ones([9, 7, 9])).independent(3))

    for i in pyro.plate("data_loop", len(data)):
        row = data.iloc[i]
        
        current_gene = row['gene']
        current_group = row['group']
        current_omim = row['omim_id']
        
        pyro.sample("obs_genes_{}".format(i), dist.Categorical(prior_gene), obs=torch.tensor(current_gene))
        pyro.sample("obs_group_{}".format(i), dist.Categorical(prior_group), obs=torch.tensor(current_group))
        pyro.sample("obs_omims_{}".format(i), dist.Categorical(prior_omim), obs=torch.tensor(current_omim))
        
        # Conditional nodes
        pyro.sample("obs_pheno_x_{}".format(i), 
                    dist.Normal(loc=prior_mean_pheno_x[current_gene, current_group, current_omim],
                                scale=prior_std_pheno_x[current_gene, current_group, current_omim]),
                    obs=torch.tensor(row['pheno_x']))

Now I specify the corresponding guide function as:

def guide(data: pd.DataFrame):
    gene_dir = pyro.param("gene_dir", 
                          torch.ones(9) / 9.0, constraint=constraints.simplex)
    prior_gene = pyro.sample("prior_gene", dist.Dirichlet(concentration=gene_dir))
    
    group_dir = pyro.param("group_dir",
                           torch.ones(7) / 7.0, constraint=constraints.simplex)
    prior_group = pyro.sample("prior_group", dist.Dirichlet(concentration=group_dir))
    
    omim_dir = pyro.param("omim_dir",
                          torch.ones(9) / 9.0, constraint=constraints.simplex)
    prior_omim = pyro.sample("prior_omim", 
                             dist.Dirichlet(concentration=omim_dir))
    
    pheno_x_loc_loc = pyro.param('pheno_x_loc_loc', torch.zeros([9, 7, 9]))
    pheno_x_loc_scl = pyro.param('pheno_x_scl', torch.ones([9, 7, 9]) * 200, constraint=constraints.positive)
    prior_mean_pheno_x = pyro.sample('prior_mean_pheno_x', 
                                     dist.Normal(loc=pheno_x_loc_loc, scale=pheno_x_loc_scl).independent(3))
    
    pheno_x_scale = pyro.param('pheno_x_scl_scl', torch.ones([9, 7, 9]) * 2, constraint=constraints.positive)
    prior_std_pheno_x = pyro.sample('prior_std_pheno_x', 
                                    dist.Chi2(pheno_x_scale).independent(3))

So, in the guide I have put a simplex constraint on the parameters of the Dirichlet priors that feed the categorical distributions.

I finally fit the model as:

svi = SVI(model,
          guide,
          optim.Adam({"lr": .0005, "betas": (0.93, 0.999)}),
          loss=Trace_ELBO(max_iarange_nesting=1))
for i in range(num_iters):
    elbo = svi.step(frame)
    print("Elbo loss: {}".format(elbo))

My loss is completely haywire: it starts high and oscillates wildly around a very large value.

Also, when I look at the fitted parameters with:

print(pyro.get_param_store().named_parameters())

I get something like:

dict_items([('gene_dir', tensor([-2.1962, -2.1982, -2.1965, -2.1982, -2.1981, -2.1980, -2.1981, -2.1978,
        -2.1965], requires_grad=True)), ('group_dir', tensor([-1.9457, -1.9454, -1.9469, -1.9449, -1.9469, -1.9461, -1.9449],
       requires_grad=True)), ('omim_dir', tensor([-2.1981, -2.1964, -2.1963, -2.1966, -2.1982, -2.1970, -2.1965, -2.1977,
        -2.1982], requires_grad=True)), ('pheno_x_loc_loc', ....

I was expecting these parameters to all be positive and sum to 1, due to the simplex constraint. I am not sure why that constraint is not respected.

Hi, named_parameters() returns unconstrained parameter values, as described in its documentation. To get the constrained value of a parameter, call pyro.param(name).
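
For example, something like this sketch (using the parameter names from your guide) should print positive vectors that sum to 1:

for name in ["gene_dir", "group_dir", "omim_dir"]:
    value = pyro.param(name)  # constrained value, which lies on the simplex
    print(name, value, value.sum().item())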

I’m not sure why your model isn’t working, but setting pyro.enable_validation(True) and going through our list of SVI tips and tricks is usually a good place to start. For example, tip #7 suggests reducing the initial values of guide scale parameters (your pheno_x_scl is initialized at 200, which is very large relative to the prior), and tip #6 suggests starting with an autoguide rather than a custom guide; an autoguide will also generally choose sensible initial parameter values automatically.
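
A rough sketch of that combination, assuming a Pyro version that provides pyro.infer.autoguide (the module path may differ in older releases):

pyro.enable_validation(True)  # catch shape and support errors early

from pyro.infer.autoguide import AutoDiagonalNormal

# Mean-field normal guide over all latent sites on their unconstrained supports;
# init_scale keeps the initial guide scales small, in the spirit of tip #7.
auto_guide = AutoDiagonalNormal(model, init_scale=0.1)

svi = SVI(model,
          auto_guide,
          optim.Adam({"lr": 0.0005, "betas": (0.93, 0.999)}),
          loss=Trace_ELBO())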

You should also consider speeding up your model by vectorizing your data plate, i.e. using with pyro.plate(...) as a context manager instead of looping with for i in pyro.plate(...).
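
Here is a sketch of what that could look like, assuming the gene, group and omim_id columns already contain integer codes and pheno_x is a float column (it also uses .to_event, the newer name for .independent):

def vectorized_model(data: pd.DataFrame):
    prior_gene = pyro.sample("prior_gene", dist.Dirichlet(torch.ones(9) / 9.))
    prior_group = pyro.sample("prior_group", dist.Dirichlet(torch.ones(7) / 7.))
    prior_omim = pyro.sample("prior_omim", dist.Dirichlet(torch.ones(9) / 9.))

    prior_mean_pheno_x = pyro.sample(
        "prior_mean_pheno_x",
        dist.Normal(torch.zeros([9, 7, 9]), torch.ones([9, 7, 9]) * 10).to_event(3))
    prior_std_pheno_x = pyro.sample(
        "prior_std_pheno_x",
        dist.HalfCauchy(torch.ones([9, 7, 9])).to_event(3))

    # Turn the columns into index tensors once, instead of row-by-row lookups
    genes = torch.tensor(data["gene"].values)
    groups = torch.tensor(data["group"].values)
    omims = torch.tensor(data["omim_id"].values)
    pheno_x = torch.tensor(data["pheno_x"].values, dtype=torch.float)

    # One vectorized plate over all rows instead of a Python loop
    with pyro.plate("data", len(data)):
        pyro.sample("obs_genes", dist.Categorical(prior_gene), obs=genes)
        pyro.sample("obs_group", dist.Categorical(prior_group), obs=groups)
        pyro.sample("obs_omims", dist.Categorical(prior_omim), obs=omims)
        pyro.sample("obs_pheno_x",
                    dist.Normal(prior_mean_pheno_x[genes, groups, omims],
                                prior_std_pheno_x[genes, groups, omims]),
                    obs=pheno_x)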