Invalid log_prob shape

Hi, I keep getting an error for invalid log_prob shapes, but I’m not seeing the mismatch when I print out tensor shapes using poutine.trace. I’ve attached a toy example of the code producing the error below. I tried using plates as well (commented out below), but that didn’t work either. Thank you for any insights!

import numpy as np
import torch
import torch.distributions as tdist
from torch.distributions import constraints
import pyro
import pyro.optim as poptim
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO

pyro.enable_validation(True)
pyro.set_rng_seed(1)

def model():

    # declare plates
    #ind_plt = pyro.plate('inds', 50)

    #with ind_plt:
    #    d = dist.Dirichlet(torch.ones([50, 10]))
    #    theta_b = pyro.sample("theta_b", d) # [n_inds, k_b]
    d = dist.Dirichlet(torch.ones([50,10]))
    theta_b = pyro.sample("theta_b", d)
    print('model: ', theta_b.shape)
    print(d.batch_shape)
    print(d.event_shape)
    print(d.event_dim)
    print(d.log_prob(theta_b).shape)

##################################################################

def guide():

    # declare plates
    #ind_plt = pyro.plate('inds', 50)

    omega_b = pyro.param("omega_b", (torch.ones([50, 10])/10), constraint=constraints.simplex)
    #with ind_plt:
    #    theta_b = pyro.sample("theta_b", dist.Delta(omega_b).to_event(1)) # [n_inds, k_b]
    d = dist.Delta(omega_b).to_event(1)
    theta_b = pyro.sample("theta_b", d)
    print('guide ', theta_b.shape)
    print(d.batch_shape)
    print(d.event_shape)
    print(d.event_dim)
    print(d.log_prob(theta_b).shape)

##################################################################
trace = pyro.poutine.trace(model).get_trace()
trace.compute_log_prob()
print(trace.format_shapes())

trace = pyro.poutine.trace(guide).get_trace()
trace.compute_log_prob()
print(trace.format_shapes())

pyro.clear_param_store()

opt = poptim.Adam({"lr": 0.01})

svi = SVI(model, guide, opt, loss=Trace_ELBO())

for epoch in range(2):
    svi.step()

there are different ways you might do this. one is to use to_event(1) in the model and to_event(2) in the guide.