Problem with tensor shape: inconsistent results

Hi there! I’m using Pyro to build a Gaussian mixture model. The model definition seems to work, but when I run SVI for posterior inference, I get an error about a tensor shape. I printed the shape of the tensor, and (surprisingly) it is not the same in every setting. The problem occurs here:

 theta = pyro.sample("theta", HalfCauchy(scale=torch.ones(d)).to_event(1))

This random variable is defined within a plate. I want theta to be a (batchsize, d) tensor, with batch_shape=(batchsize,) given by the plate and event_shape=(d,). When I render my model with

model = model_lkj
pyro.render_model(model, model_args=(data,), render_params=True)

The shapes that print shows are:

(torch.Size([10, 2]), torch.Size([10, 2, 2]))

This is just as expected. However, when I started training with SVI, the program raised an error:

---> 29 Omega = torch.bmm(theta.sqrt().diag_embed(), omega)
RuntimeError: batch1 must be a 3D tensor

And the shapes printed are:

(torch.Size([10]), torch.Size([10, 2, 2]))
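The error message follows directly from these shapes. The sketch below uses plain torch with stand-in tensors (theta_ok, theta_bad, and an identity omega are hypothetical placeholders, not samples from the model) to show that diag_embed on a 1D tensor produces a 2D matrix, which torch.bmm rejects.

```python
import torch

T, d = 10, 2  # assumed values, chosen to match the printed shapes

# Stand-in for the batch of Cholesky factors, shape (T, d, d).
omega = torch.eye(d).repeat(T, 1, 1)

theta_ok = torch.ones(T, d)  # the shape seen when rendering
theta_bad = torch.ones(T)    # the shape seen inside svi.step()

# diag_embed turns the last dimension into a diagonal matrix.
print(theta_ok.sqrt().diag_embed().shape)   # 3D: one d x d matrix per component
print(theta_bad.sqrt().diag_embed().shape)  # 2D: a single T x T matrix

# The 3D case works with bmm.
Omega = torch.bmm(theta_ok.sqrt().diag_embed(), omega)

# The 2D case raises the RuntimeError from the traceback.
error_message = None
try:
    torch.bmm(theta_bad.sqrt().diag_embed(), omega)
except RuntimeError as err:
    error_message = str(err)
print(error_message)
```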

I’m confused by this error and have no idea where I screwed up. How can I fix this? Thanks for any advice!

My model (with a line of print):

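# Assumed imports. T, d, N, data and mix_weights are defined elsewhere (not shown).
import torch
import pyro
from pyro.distributions import Beta, Categorical, HalfCauchy, LKJCholesky, MultivariateNormal

def model_lkj(data):
    alpha = pyro.param("alpha", torch.tensor([1.0]))
    # Stick-breaking fractions for the mixture weights.
    with pyro.plate("sticks", T-1):
        beta = pyro.sample("beta", Beta(1, alpha))

    # Per-component mean, scales, and Cholesky factor of the correlation matrix.
    with pyro.plate("component", T):
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(d), 5 * torch.eye(d)))
        theta = pyro.sample("theta", HalfCauchy(scale=torch.ones(d)).to_event(1))
        omega = pyro.sample('omega', LKJCholesky(d, concentration=1))
        # Debug print: expected (T, d) for theta and (T, d, d) for omega.
        print((theta.shape, omega.shape))
        # scale_tril of each component: diag(sqrt(theta)) @ omega.
        Omega = torch.bmm(theta.sqrt().diag_embed(), omega)

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(mix_weights(beta)))
        pyro.sample("obs", MultivariateNormal(mu[z], scale_tril=Omega[z]), obs=data)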

And my guide:

def guide_lkj(data):
    kappa = pyro.param('kappa', lambda: Uniform(0, 2).sample([T-1]), constraint=constraints.positive)
    tau = pyro.param('tau', lambda: MultivariateNormal(torch.zeros(d), 3 * torch.eye(d)).sample([T]))
    phi = pyro.param('phi', lambda: Dirichlet(1/T * torch.ones(T)).sample([N]), constraint=constraints.simplex)
    a = pyro.param('a', lambda: Uniform(0, 2).sample([T]), constraint=constraints.positive)
    b = pyro.param('b', lambda: Uniform(0, 2).sample([T]), constraint=constraints.positive)
    c = pyro.param('c', lambda: Gamma(1, 1).sample([T]))

    with pyro.plate("sticks", T-1):
        q_beta = pyro.sample("beta", Beta(torch.ones(T-1), kappa))
    with pyro.plate("component", T):
        q_mu = pyro.sample("mu", MultivariateNormal(tau, torch.eye(d)))
        q_theta = pyro.sample("theta", Gamma(a, b))
        q_omega = pyro.sample('omega', LKJCholesky(d, concentration=c))
    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(phi))

And my training loop:

model = model_lkj
guide = guide_lkj
optim = Adam({'lr': 0.01, 'betas': [0.9, 0.999]})
svi = SVI(model, guide, optim, loss=Trace_ELBO())
losses = []

def train(num_iterations):
    for j in tqdm(range(num_iterations)):
        loss = svi.step(data)
        losses.append(loss)  # record the loss so the losses list is actually filled

def truncate(alpha, centers, weights):
    threshold = alpha**-1 / 100.
    true_centers = centers[weights > threshold]
    true_weights = weights[weights > threshold] / torch.sum(weights[weights > threshold])
    return true_centers, true_weights

alpha = 0.1

Besides, the model runs fine with MCMC. During MCMC, print shows the same shape for theta as in rendering (torch.Size([10, 2])), which also differs from what I got in SVI.
So I’m wondering, what’s happening in svi.step() that makes the model behave differently?

Oh, I actually solved this! The problem is not in my model but in my guide:
the guide samples theta with an invalid shape of torch.Size([10]), while the model expects torch.Size([10, 2]). That is the cause.