I am getting this invalid log probability tensor shape error:
ValueError: at site "prior_mean_D", invalid log_prob shape
Expected [], actual [9, 7, 9]
In the model()
function the following setup is used:
# Three parent discrete nodes
A = pyro.sample("prior_A",
dist.Dirichlet(concentration=torch.ones(9) / 9.))
B = pyro.sample("prior_B",
dist.Dirichlet(concentration=torch.ones(7) / 7.))
C = pyro.sample("prior_C",
dist.Dirichlet(concentration=torch.ones(9) / 9.))
# There is just one node that inherits from all 3. So, there are like
# 9 x 7 x 9 = 567 conditional distributions
# Set priors on the mean and scale for these conditional dists
prior_mean_D = pyro.sample('prior_mean_D',
dist.Normal(loc=torch.zeros([9, 7, 9]),
scale=torch.ones([9, 7, 9]) * 1000))
prior_std_D = pyro.sample('prior_std_D',
dist.Gamma(concentration=torch.ones([9, 7, 9]) * 0.5,
rate=torch.ones([9, 7, 9])))
# Now compute the likelihood
for i in pyro.plate("data_loop", len(data)):
row = data.iloc[i] # pandas dataframe
# Following are all scalars
A_Val = torch.tensor(row['A'])
B_val = torch.tensor(row['B'])
C_val = torch.tensor(row['C'])
D_val = torch.tensor(row['D'])
pyro.sample("obs_A_{}".format(i), dist.Categorical(prior_A), obs=A_val)
pyro.sample("obs_B_{}".format(i), dist.Categorical(prior_B), obs=B_val)
pyro.sample("obs_C_{}".format(i), dist.Categorical(prior_C), obs=C_val)
pyro.sample("obs_D_{}".format(i),
dist.Normal(loc=prior_mean_D[A_val, B_val, C_val],
scale=prior_std_D[A_val, B_val, C_val]),
obs=D_val)
In the guide function I have something like:
A_dir = pyro.param("A_dir",
torch.ones(9) / 9.0, constraint=constraints.simplex)
A = pyro.sample("prior_A", dist.Dirichlet(concentration=A_dir))
B_dir = pyro.param("B_dir",
torch.ones(9) / 9.0, constraint=constraints.simplex)
B = pyro.sample("prior_B", dist.Dirichlet(concentration=B_dir))
C_dir = pyro.param("C_dir",
torch.ones(9) / 9.0, constraint=constraints.simplex)
C = pyro.sample("prior_C", dist.Dirichlet(concentration=C_dir))
# Now for the conditional D, I have:
D_loc_loc = pyro.param('D_loc_loc', torch.zeros([9, 7, 9]))
D_loc_scl = pyro.param('D_scl', torch.ones([9, 7, 9]) * 1000, constraint=constraints.positive)
D_mean = pyro.sample('prior_mean_D',
dist.Normal(loc=D_loc_loc, scale=D_loc_scl))
D_scl_loc = pyro.param('D_scl_loc',
torch.ones([9, 7, 9]) * 0.5, constraint=constraints.positive)
D_scl_scl = pyro.param('Dscl_scl',
torch.ones([9, 7, 9]), constraint=constraints.positive)
D_std = pyro.sample('prior_std_D,
dist.Gamma(D_scl_loc, D_scl_scl))
Now, I try and optimize the parameters as:
svi = SVI(model,
guide,
optim.Adam({"lr": .05}),
loss=Trace_ELBO())
for i in range(2):
elbo = svi.step(frame)
logging.info("Elbo loss: {}".format(elbo))
This is where it hits the invalid log_prob_shape
error. I feel that perhaps my plate stuff or how I have setup these conditional distributions is not optimal.
I have tried adding these event()
statements bu that did not help. I am pretty sure I did not understand their usage correctly.
Would be very grateful of any hint/suggestion on this.