Hi! I’m currently constructing a Gaussian mixture model whose component covariance parameters are sampled from another mixture. To realize this idea, I used two sets of discrete variables for indexing and relied on enumeration for inference. SVI works well, and convergence is reached.
However, I run into errors when I try to perform classification:
ValueError: Error while packing tensors at site 'inv_Sigma':
Invalid tensor shape.
Allowed dims: -1
Actual shape: (7, 19)
u dist 19 |
value 7 1 |
log_prob 7 19 |
inv_Sigma dist 7 19 | 2 2
value 19 | 2 2
log_prob 7 19 |
where 7 and 19 are the number of parameter mixture components and sample mixture components respectively.
It seems to me that this extra dimension of size 7 at dim=-2 is a result of enumeration. However, what I’m trying to do is simply get a MAP estimate or prediction. So I was wondering: what’s wrong with my classification function?
# Record the trained global latents once, then replay them into the model so
# that infer_discrete only has to resolve the discrete assignment sites.
guide_trace = poutine.trace(global_guide).get_trace(data) # record the globals
trained_model = poutine.replay(model, trace=guide_trace) # replay the globals

def classifier(data, temperature=0): #random assignment if temperature=1
    """Return the inferred assignment ``z`` for each row of ``data``.

    temperature=0 gives the MAP assignment; temperature=1 samples from the
    discrete posterior.

    NOTE(review): the model enumerates TWO discrete sites (``u`` and ``z``),
    so ``first_available_dim=-2`` leaves room for only one enumeration dim
    to the left of the plates — presumably this needs to be further left
    (e.g. -3, i.e. beyond max_plate_nesting); confirm against the model.
    NOTE(review): ``inv_Sigma`` is a continuous site whose parameters depend
    on the enumerated ``u``; infer_discrete only infers discrete sites, so
    ``inv_Sigma`` must already carry a value from the replayed guide trace.
    Verify that ``global_guide`` actually samples ``inv_Sigma`` — otherwise
    the enum dim of ``u`` (size 7) leaks into its shape, which matches the
    packing error reported above.
    """
    inferred_model = infer_discrete(trained_model, temperature=temperature,
    first_available_dim=-2)
    trace = poutine.trace(inferred_model).get_trace(data) #error raised here
    return trace.nodes["z"]["value"]
Here is my model, if it provides any information:
# DPMM Model in pyro
T_comp = 7     # truncation level of the covariance-parameter mixture
T_mix = 19     # truncation level of the observation (sample) mixture
alpha_0 = 0.1  # initial DP concentration for the mixture sticks
gamma_0 = 0.1  # initial DP concentration for the component sticks
def mix_weights(beta):
    """Stick-breaking construction: turn T-1 Beta draws into T mixture weights.

    Weight k is beta[k] times the stick length remaining after the first k
    breaks; the final weight absorbs whatever stick is left.
    """
    one = torch.ones(beta.shape[:-1] + (1,))
    sticks = torch.cat([beta, one], dim=-1)                   # last weight uses the whole remainder
    remaining = torch.cat([one, (1 - beta).cumprod(-1)], dim=-1)  # stick left before each break
    return sticks * remaining
@config_enumerate
def refresh(data=None):
    """Truncated DPMM whose per-cell covariance prior is itself a mixture.

    Two discrete indices are enumerated in parallel:
      * ``u`` (per mixture cell) picks one of T_comp covariance components;
      * ``z`` (per datum) picks one of T_mix mixture cells.

    NOTE(review): the plate nesting below is reconstructed from a paste that
    lost its indentation — the plates are assumed to be sequential
    (un-nested) at function level; confirm against the original source.
    Relies on module-level globals ``d`` (feature dim) and ``N`` (number of
    data points) — not visible here.
    """
    # DP concentrations and the diagonal variances of the mean prior are
    # learned as params (MAP) rather than sampled.
    alpha = pyro.param("alpha", torch.tensor([alpha_0]))
    gamma = pyro.param("gamma", torch.tensor([gamma_0]))
    var = pyro.param("var", lambda:Chi2(torch.ones(d)).to_event(1).sample([T_mix]), constraint=constraints.positive)
    # Stick-breaking weights over the T_comp covariance components.
    with pyro.plate("component_sticks", T_comp-1):
        beta_comp = pyro.sample("beta_comp", Beta(1, gamma))
    with pyro.plate("component", T_comp):
        # component of prior for covariance
        nu = pyro.sample("nu", Uniform(d+2, 2*d+2))
        theta = pyro.sample("theta", Chi2(df=torch.ones(d)*(d+1)).to_event(1))
        omega = pyro.sample('omega', LKJCholesky(d, concentration=1))
        # Per-component Cholesky scale: sqrt(theta) on the diagonal times the
        # LKJ correlation factor.
        Omega = torch.bmm(theta.sqrt().diag_embed(), omega)
    # Stick-breaking weights over the T_mix observation mixture cells.
    with pyro.plate("mixture_sticks", T_mix-1):
        beta_mix = pyro.sample("beta_mix", Beta(1, alpha))
    with pyro.plate("mixture", T_mix):
        # u: which covariance component each mixture cell draws its
        # precision from (enumerated in parallel).
        u = pyro.sample("u", Categorical(mix_weights(beta_comp)), infer={'enumerate': 'parallel'})
        inv_Sigma = pyro.sample("inv_Sigma", Wishart(df=nu[u], scale_tril=Omega[u]))
        mu = pyro.sample("mu", MultivariateNormal(torch.zeros(d), scale_tril=torch.bmm(var.sqrt().diag_embed(), torch.eye(d).repeat(T_mix,1,1))))
    with pyro.plate("data", N):
        # z: cluster assignment per datum (enumerated in parallel).
        z = pyro.sample("z", Categorical(mix_weights(beta_mix)), infer={'enumerate': 'parallel'})
        pyro.sample("obs", MultivariateNormal(mu[z], precision_matrix=inv_Sigma[z]), obs=data)