I have a model, which I define as follows:
def model(data: pd.DataFrame):
    """Generative model.

    Dirichlet priors over three categorical variables (gene, group, omim),
    plus a Normal likelihood for ``pheno_x`` whose mean and std are looked up
    in a [9, 7, 9] grid indexed by the observed (gene, group, omim) triple.

    NOTE(review): the cardinalities 9/7/9 are hard-coded — presumably the
    number of distinct values in the corresponding data columns; confirm
    against the frame.

    :param data: frame with integer-coded columns 'gene', 'group',
        'omim_id' and a float column 'pheno_x'.
    """
    prior_gene = pyro.sample(
        "prior_gene", dist.Dirichlet(concentration=torch.ones(9) / 9.))
    prior_group = pyro.sample(
        "prior_group", dist.Dirichlet(concentration=torch.ones(7) / 7.))
    prior_omim = pyro.sample(
        "prior_omim", dist.Dirichlet(concentration=torch.ones(9) / 9.))
    # .to_event(3) replaces the deprecated .independent(3): the [9, 7, 9]
    # grid is declared as event dimensions of a single sample site.
    prior_mean_pheno_x = pyro.sample(
        'prior_mean_pheno_x',
        dist.Normal(loc=torch.zeros([9, 7, 9]),
                    scale=torch.ones([9, 7, 9]) * 10).to_event(3))
    prior_std_pheno_x = pyro.sample(
        'prior_std_pheno_x',
        dist.HalfCauchy(torch.ones([9, 7, 9])).to_event(3))
    # Sequential plate: one likelihood term per data row. This contributes
    # no vectorized plate nesting (relevant for Trace_ELBO settings).
    for i in pyro.plate("data_loop", len(data)):
        row = data.iloc[i]
        current_gene = row['gene']
        current_group = row['group']
        current_omim = row['omim_id']
        pyro.sample("obs_genes_{}".format(i),
                    dist.Categorical(prior_gene),
                    obs=torch.tensor(current_gene))
        pyro.sample("obs_group_{}".format(i),
                    dist.Categorical(prior_group),
                    obs=torch.tensor(current_group))
        pyro.sample("obs_omims_{}".format(i),
                    dist.Categorical(prior_omim),
                    obs=torch.tensor(current_omim))
        # Conditional node: pheno_x given the observed categorical triple.
        pyro.sample("obs_pheno_x_{}".format(i),
                    dist.Normal(
                        loc=prior_mean_pheno_x[current_gene, current_group, current_omim],
                        scale=prior_std_pheno_x[current_gene, current_group, current_omim]),
                    obs=torch.tensor(row['pheno_x']))
Now I specify the corresponding guide function as:
def guide(data: pd.DataFrame):
    """Variational guide matching ``model``.

    Each latent site gets a learnable variational distribution:
    Dirichlets for the three category priors, a Normal for the mean grid,
    and a Chi2 for the std grid (both [9, 7, 9], declared as event dims).

    Note: ``pyro.get_param_store().named_parameters()`` exposes the
    UNCONSTRAINED values (log-space for positives, pre-softmax for
    simplexes); read constrained values back with ``pyro.param(name)``.

    :param data: unused here; kept so guide and model share a signature,
        as SVI requires.
    """
    gene_dir = pyro.param("gene_dir",
                          torch.ones(9) / 9.0,
                          constraint=constraints.simplex)
    prior_gene = pyro.sample("prior_gene",
                             dist.Dirichlet(concentration=gene_dir))
    group_dir = pyro.param("group_dir",
                           torch.ones(7) / 7.0,
                           constraint=constraints.simplex)
    prior_group = pyro.sample("prior_group",
                              dist.Dirichlet(concentration=group_dir))
    omim_dir = pyro.param("omim_dir",
                          torch.ones(9) / 9.0,
                          constraint=constraints.simplex)
    prior_omim = pyro.sample("prior_omim",
                             dist.Dirichlet(concentration=omim_dir))
    pheno_x_loc_loc = pyro.param('pheno_x_loc_loc', torch.zeros([9, 7, 9]))
    # Param-store key 'pheno_x_scl' kept as-is (external interface), even
    # though the local name says loc_scl. NOTE(review): an initial scale of
    # 200 is very diffuse and a likely contributor to an unstable ELBO —
    # consider initializing closer to 1.
    pheno_x_loc_scl = pyro.param('pheno_x_scl',
                                 torch.ones([9, 7, 9]) * 200,
                                 constraint=constraints.positive)
    # .to_event(3) replaces the deprecated .independent(3).
    prior_mean_pheno_x = pyro.sample(
        'prior_mean_pheno_x',
        dist.Normal(loc=pheno_x_loc_loc, scale=pheno_x_loc_scl).to_event(3))
    pheno_x_scale = pyro.param('pheno_x_scl_scl',
                               torch.ones([9, 7, 9]) * 2,
                               constraint=constraints.positive)
    # Chi2 has positive support, matching the HalfCauchy prior's support.
    prior_std_pheno_x = pyro.sample(
        'prior_std_pheno_x',
        dist.Chi2(pheno_x_scale).to_event(3))
So, I have put a simplex constraint on the Dirichlet concentration parameters — the parameters that ultimately feed the categorical distributions.
I finally fit the model as:
# Fit with stochastic variational inference.
# Fixes vs. the original snippet:
#  - the model defined above is named `model`, not `model_delphi`;
#  - `max_iarange_nesting` is the deprecated keyword — current Pyro uses
#    `max_plate_nesting`; the sequential `pyro.plate` loop in the model adds
#    no vectorized nesting, so 0 is the correct value here.
svi = SVI(model,
          guide,
          optim.Adam({"lr": .0005, "betas": (0.93, 0.999)}),
          loss=Trace_ELBO(max_plate_nesting=0))
for i in range(num_iters):
    elbo = svi.step(frame)  # one gradient step; returns the ELBO loss
    print("Elbo loss: {}".format(elbo))
So, my loss is absolutely haywire. It starts high and wildly oscillates around a very high number.
Also, when I look at the fitted parameters with:
print(pyro.get_param_store().named_parameters())
I get something like:
dict_items([('gene_dir', tensor([-2.1962, -2.1982, -2.1965, -2.1982, -2.1981, -2.1980, -2.1981, -2.1978,
-2.1965], requires_grad=True)), ('group_dir', tensor([-1.9457, -1.9454, -1.9469, -1.9449, -1.9469, -1.9461, -1.9449],
requires_grad=True)), ('omim_dir', tensor([-2.1981, -2.1964, -2.1963, -2.1966, -2.1982, -2.1970, -2.1965, -2.1977,
-2.1982], requires_grad=True)), ('pheno_x_loc_loc', ....
I was expecting these parameters to all be positive and sum to 1, due to the simplex constraint. I am not sure why that constraint is not respected.