Hi
Recently I want to use pyro to infer the parameters within a DP process where the observed data is from a mixture of multinomial distribution. I tried to adjust the model which was shown in the example, however I got the following errors:
ValueError: Error while computing log_prob at site 'obs':
The right-most size of value must match event_shape: torch.Size([1113]) vs torch.Size([71]).
Trace Shapes:
Param Sites:
Sample Sites:
beta dist 49 |
value 49 |
log_prob 49 |
alpha dist 50 | 71
value 50 | 71
log_prob 50 |
z dist 1113 |
value 1113 |
log_prob 1113 |
obs dist 1113 | 71
value | 1113
Below is my code
def mix_weights(beta):
beta1m_cumprod = (1 - beta).cumprod(-1)
return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)
def model(data_class):
with pyro.plate("beta_plate", T-1):
beta = pyro.sample("beta", Beta(1, 1))
with pyro.plate("alpha_plate", T):
alpha = pyro.sample("alpha", Dirichlet(1/Nct * torch.ones(Nct)))
with pyro.plate("data", N):
z = pyro.sample("z", Categorical(mix_weights(beta)))
pyro.sample("obs", Multinomial(N, alpha[z]), obs=data_class)
def guide(data):
kappa = pyro.param('kappa', lambda: Uniform(0, 2).sample([T-1]), constraint=constraints.positive)
tau = pyro.param('tau', lambda: Multinomial(Nct, (torch.ones(Nct)/Nct)).sample_n(T), constraint=constraints.positive)
phi = pyro.param('phi', lambda: Dirichlet(1/T * torch.ones(T)).sample([N]), constraint=constraints.simplex)
with pyro.plate("beta_plate", T-1):
q_beta = pyro.sample("beta", Beta(torch.ones(T-1), kappa))
with pyro.plate("alpha_plate", T):
q_alpha = pyro.sample("alpha", Dirichlet(tau+1))
with pyro.plate("data", N):
z = pyro.sample("z", Categorical(phi))
where N is the data length, T is the number of clusters and Nct is the number of categories.
N = 1113
T = 50
Nct = 71
Very appreciate your help!