ValueError: at site "prior_mean_D", invalid log_prob shape

I am getting this invalid log probability tensor shape error:

ValueError: at site "prior_mean_D", invalid log_prob shape
Expected [], actual [9, 7, 9]

In the model() function the following setup is used:

# Three parent discrete nodes
prior_A = pyro.sample("prior_A",
                      dist.Dirichlet(concentration=torch.ones(9) / 9.))
prior_B = pyro.sample("prior_B",
                      dist.Dirichlet(concentration=torch.ones(7) / 7.))
prior_C = pyro.sample("prior_C",
                      dist.Dirichlet(concentration=torch.ones(9) / 9.))

# There is just one node that inherits from all 3. So, there are like
# 9 x 7 x 9 = 567 conditional distributions

# Set priors on the mean and scale for these conditional dists
prior_mean_D = pyro.sample('prior_mean_D',
                           dist.Normal(loc=torch.zeros([9, 7, 9]), 
                                       scale=torch.ones([9, 7, 9]) * 1000))
prior_std_D = pyro.sample('prior_std_D', 
                          dist.Gamma(concentration=torch.ones([9, 7, 9]) * 0.5,
                                     rate=torch.ones([9, 7, 9])))

# Now compute the likelihood
for i in pyro.plate("data_loop", len(data)):
        row = data.iloc[i]  # pandas dataframe
        # Following are all scalars
        A_val = torch.tensor(row['A'])
        B_val = torch.tensor(row['B'])
        C_val = torch.tensor(row['C'])
        D_val = torch.tensor(row['D'])

        pyro.sample("obs_A_{}".format(i), dist.Categorical(prior_A), obs=A_val)
        pyro.sample("obs_B_{}".format(i), dist.Categorical(prior_B), obs=B_val)
        pyro.sample("obs_C_{}".format(i), dist.Categorical(prior_C), obs=C_val)
        pyro.sample("obs_D_{}".format(i),
                    dist.Normal(loc=prior_mean_D[A_val, B_val, C_val],
                                scale=prior_std_D[A_val, B_val, C_val]),
                    obs=D_val)

In the guide function I have something like:

A_dir = pyro.param("A_dir", 
                   torch.ones(9) / 9.0, constraint=constraints.simplex)
A = pyro.sample("prior_A", dist.Dirichlet(concentration=A_dir))

B_dir = pyro.param("B_dir",
                   torch.ones(7) / 7.0, constraint=constraints.simplex)
B = pyro.sample("prior_B", dist.Dirichlet(concentration=B_dir))

C_dir = pyro.param("C_dir", 
                   torch.ones(9) / 9.0, constraint=constraints.simplex)
C = pyro.sample("prior_C", dist.Dirichlet(concentration=C_dir))

# Now for the conditional D, I have:

D_loc_loc = pyro.param('D_loc_loc', torch.zeros([9, 7, 9]))
D_loc_scl = pyro.param('D_loc_scl', torch.ones([9, 7, 9]) * 1000, constraint=constraints.positive)
D_mean = pyro.sample('prior_mean_D', 
                     dist.Normal(loc=D_loc_loc, scale=D_loc_scl))
D_scl_loc = pyro.param('D_scl_loc', 
                       torch.ones([9, 7, 9]) * 0.5, constraint=constraints.positive)
D_scl_scl = pyro.param('D_scl_scl',
                       torch.ones([9, 7, 9]), constraint=constraints.positive)
D_std = pyro.sample('prior_std_D',
                    dist.Gamma(D_scl_loc, D_scl_scl))

Now, I try and optimize the parameters as:

svi = SVI(model,
          guide,
          optim.Adam({"lr": .05}),
          loss=Trace_ELBO())  # loss is not shown in the post; Trace_ELBO is assumed

for i in range(2):
    elbo = svi.step(frame)
    print("Elbo loss: {}".format(elbo))

This is where it hits the invalid log_prob shape error. I suspect that my use of plate, or the way I have set up these conditional distributions, is not right.

I have tried adding these event() statements, but that did not help. I am pretty sure I did not understand their usage correctly.

I would be very grateful for any hints or suggestions on this.

I don’t see where you define prior_D, but that’s what you sample from in the model. Should you change

                    dist.Normal(loc=prior_D[A_val, B_val, C_val],
                                scale=prior_D[A_val, B_val, C_val]),

to instead sample like this?

                    dist.Normal(loc= prior_mean_D[A_val, B_val, C_val],
                                scale=prior_std_D[A_val, B_val, C_val]),
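
As a quick sanity check of what that indexing returns, here is a minimal sketch in plain torch. The tensor contents and the batched index vectors `A_idx`, `B_idx`, `C_idx` are made-up stand-ins for illustration, not values from the original model.

```python
import torch

# Stand-in for the sampled [9, 7, 9] tensor of conditional means
# (hypothetical values: each entry holds its own flat index).
prior_mean_D = torch.arange(9 * 7 * 9, dtype=torch.float32).reshape(9, 7, 9)

# Scalar (0-dim) index tensors, as built from one data row in the model.
A_val, B_val, C_val = torch.tensor(2), torch.tensor(5), torch.tensor(8)

# Picks the single conditional mean for this (A, B, C) combination.
loc = prior_mean_D[A_val, B_val, C_val]
print(loc.shape)  # empty shape: one scalar per row

# The same lookup for several rows at once, using index vectors.
A_idx = torch.tensor([0, 2, 8])
B_idx = torch.tensor([0, 5, 6])
C_idx = torch.tensor([0, 8, 8])
batch_loc = prior_mean_D[A_idx, B_idx, C_idx]
print(batch_loc.shape)  # one mean per row
```

So the per-row `loc` and `scale` passed to `dist.Normal` are scalars, which suggests the likelihood sites are fine and the shape problem lies with the [9, 7, 9] prior sites.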

I also think there’s an error in your guide: you appear to sample D_std using D_std. I do not think this is intentional.

Ah sorry…this was my copy and paste and formatting error. It is indeed sampled like that. I have fixed those errors.

Looks like adding independent(3) to the distribution objects in the model and guide was the solution.
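
A minimal sketch of why this resolves the shape error. It uses plain `torch.distributions` as a stand-in for the Pyro distribution objects, which is an assumption made so the snippet runs without Pyro installed. The `prior_mean_D` site sits outside any plate, so a log_prob with an empty batch shape is expected, while a Normal with [9, 7, 9] parameters returns one log_prob per element. Declaring the three dimensions as event dimensions sums the log_prob over them.

```python
import torch
from torch.distributions import Independent, Normal

# Same parameter shapes as the prior on the conditional means.
base = Normal(loc=torch.zeros(9, 7, 9), scale=torch.ones(9, 7, 9) * 1000)
x = base.sample()

# Without event dims: one log_prob per element, batch shape [9, 7, 9].
print(base.batch_shape, base.log_prob(x).shape)

# Reinterpret all 3 batch dims as event dims: a single joint log_prob.
joint = Independent(base, 3)
print(joint.batch_shape, joint.event_shape, joint.log_prob(x).shape)

# The joint log_prob is just the sum of the elementwise log_probs.
assert torch.allclose(joint.log_prob(x), base.log_prob(x).sum())
```

In Pyro the same wrapping is written `dist.Normal(...).independent(3)`; newer releases spell it `.to_event(3)`. The call has to be applied to the matching sites in both the model and the guide.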