Hi there! I’m using Pyro to build a Gaussian mixture model. The model definition seems to work, but when I run SVI for posterior inference, I get an error about a tensor shape. I added a print to inspect the tensor shapes, and (surprisingly) they are not the same each time. The problem occurs here:
theta = pyro.sample("theta", HalfCauchy(scale=torch.ones(d)).to_event(1))
This random variable is defined within a plate. I want theta to be a (batchsize, d) tensor, with batch_shape=batchsize given by the plate and event_shape=d. When I render my model with
model = model_lkj
pyro.render_model(model, model_args=(data,), render_params=True)
The printed shapes are:
(torch.Size([10, 2]), torch.Size([10, 2, 2]))
This is just as expected. However, when I started training with SVI, the program raised an error:
---> 29 Omega = torch.bmm(theta.sqrt().diag_embed(), omega)
RuntimeError: batch1 must be a 3D tensor
And the shapes printed this time are:
(torch.Size([10]), torch.Size([10, 2, 2]))
I’m confused by this error and have no idea where I went wrong. How can I fix this? Thanks for any advice!
My model (with the print line):
T = 10

def model_lkj(data):
    alpha = pyro.param("alpha", torch.tensor([1.0]))
    with pyro.plate("sticks", T-1):
        beta = pyro.sample("beta", Beta(1, alpha))
    with pyro.plate("component", T):
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(d), 5 * torch.eye(d)))
        theta = pyro.sample("theta", HalfCauchy(scale=torch.ones(d)).to_event(1))
        omega = pyro.sample('omega', LKJCholesky(d, concentration=1))
    print((theta.shape, omega.shape))
    Omega = torch.bmm(theta.sqrt().diag_embed(), omega)
    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(mix_weights(beta)))
        pyro.sample("obs", MultivariateNormal(mu[z], scale_tril=Omega[z]), obs=data)
And my guide:
def guide_lkj(data):
    kappa = pyro.param('kappa', lambda: Uniform(0, 2).sample([T-1]), constraint=constraints.positive)
    tau = pyro.param('tau', lambda: MultivariateNormal(torch.zeros(d), 3 * torch.eye(d)).sample([T]))
    phi = pyro.param('phi', lambda: Dirichlet(1/T * torch.ones(T)).sample([N]), constraint=constraints.simplex)
    a = pyro.param('a', lambda: Uniform(0, 2).sample([T]), constraint=constraints.positive)
    b = pyro.param('b', lambda: Uniform(0, 2).sample([T]), constraint=constraints.positive)
    c = pyro.param('c', lambda: Gamma(1, 1).sample([T]))
    with pyro.plate("sticks", T-1):
        q_beta = pyro.sample("beta", Beta(torch.ones(T-1), kappa))
    with pyro.plate("component", T):
        q_mu = pyro.sample("mu", MultivariateNormal(tau, torch.eye(d)))
        q_theta = pyro.sample("theta", Gamma(a, b))
        q_omega = pyro.sample('omega', LKJCholesky(d, concentration=c))
    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(phi))
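
As a sanity check on the shapes at the theta site, here is a minimal sketch in plain torch.distributions rather than my actual Pyro code. T = 10 and d = 2 are read off the printed shapes. Using Independent as a stand-in for .to_event(1), emulating the plate with sample([T]), and filling a and b with ones are all assumptions made only for illustration. It suggests the model's prior gives a (T, d) sample while Gamma(a, b) with (T,)-shaped parameters gives a (T,) sample, which would match the two printed shapes.

```python
import torch
from torch.distributions import Gamma, HalfCauchy, Independent

T, d = 10, 2  # assumed from the printed shapes in the post

# Model side: Independent(..., 1) is used here as a plain-PyTorch stand-in
# for .to_event(1) (an assumption for this sketch, not the Pyro code itself).
prior = Independent(HalfCauchy(scale=torch.ones(d)), 1)
theta_model = prior.sample([T])  # sample([T]) emulates the plate of size T

# Guide side: parameters of shape (T,), mirroring Gamma(a, b) in the guide.
a = torch.ones(T)  # placeholder values, not the learned parameters
b = torch.ones(T)
posterior = Gamma(a, b)
theta_guide = posterior.sample()

print(prior.batch_shape, prior.event_shape)
print(posterior.batch_shape, posterior.event_shape)
print(theta_model.shape, theta_guide.shape)

# diag_embed maps (T, d) to (T, d, d), but maps (T,) to a 2D (T, T) tensor,
# which torch.bmm does not accept as its first argument.
diag_model = theta_model.sqrt().diag_embed()
diag_guide = theta_guide.sqrt().diag_embed()
print(diag_model.shape, diag_guide.shape)
```

If this sketch reflects what happens during SVI, the theta value used in the model at training time would have the guide's shape rather than the prior's.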
And my training loop:
model = model_lkj
guide = guide_lkj
optim = Adam({'lr': 0.01, 'betas': [0.9, 0.999]})
svi = SVI(model, guide, optim, loss=Trace_ELBO())
losses = []

def train(num_iterations):
    pyro.clear_param_store()
    for j in tqdm(range(num_iterations)):
        loss = svi.step(data)
        losses.append(loss)

def truncate(alpha, centers, weights):
    threshold = alpha**-1 / 100.
    true_centers = centers[weights > threshold]
    true_weights = weights[weights > threshold] / torch.sum(weights[weights > threshold])
    return true_centers, true_weights

alpha = 0.1
train(1000)