DP process for multinomial distribution

Hi
I recently wanted to use Pyro to infer the parameters of a Dirichlet process (DP) mixture model in which the observed data come from a mixture of multinomial distributions. I tried to adapt the model shown in the example, but I got the following error:

ValueError: Error while computing log_prob at site 'obs':
The right-most size of value must match event_shape: torch.Size([1113]) vs torch.Size([71]).
Trace Shapes:            
 Param Sites:            
Sample Sites:            
    beta dist   49 |     
        value   49 |     
     log_prob   49 |     
   alpha dist   50 |   71
        value   50 |   71
     log_prob   50 |     
       z dist 1113 |     
        value 1113 |     
     log_prob 1113 |     
     obs dist 1113 |   71
        value      | 1113

Below is my code:

import torch
import torch.nn.functional as F
import pyro
from pyro.distributions import Beta, Categorical, Dirichlet, Multinomial, Uniform
from torch.distributions import constraints

def mix_weights(beta):
    # truncated stick-breaking: turn T-1 Beta draws into T mixture weights
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)

def model(data_class):
    # stick-breaking proportions
    with pyro.plate("beta_plate", T-1):
        beta = pyro.sample("beta", Beta(1, 1))

    # one category-probability vector per cluster
    with pyro.plate("alpha_plate", T):
        alpha = pyro.sample("alpha", Dirichlet(1/Nct * torch.ones(Nct)))

    # cluster assignment and observed counts for each data point
    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(mix_weights(beta)))
        pyro.sample("obs", Multinomial(N, alpha[z]), obs=data_class)

def guide(data):
    # variational parameters for beta, alpha and z
    kappa = pyro.param('kappa', lambda: Uniform(0, 2).sample([T-1]), constraint=constraints.positive)
    tau = pyro.param('tau', lambda: Multinomial(Nct, (torch.ones(Nct)/Nct)).sample(torch.Size([T])), constraint=constraints.positive)
    phi = pyro.param('phi', lambda: Dirichlet(1/T * torch.ones(T)).sample([N]), constraint=constraints.simplex)

    with pyro.plate("beta_plate", T-1):
        q_beta = pyro.sample("beta", Beta(torch.ones(T-1), kappa))

    with pyro.plate("alpha_plate", T):
        q_alpha = pyro.sample("alpha", Dirichlet(tau+1))

    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(phi))

where N is the number of data points, T is the number of clusters (the truncation level of the stick-breaking construction), and Nct is the number of categories:

N = 1113
T = 50
Nct = 71
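For readers less familiar with the truncated stick-breaking construction used in mix_weights, here is a small standalone check. It is a sketch that uses only torch and assumes the same T as above; it shows that T-1 Beta draws become T non-negative mixture weights that sum to one.

```python
import torch
import torch.nn.functional as F

def mix_weights(beta):
    # pi_k = beta_k * prod_{j<k} (1 - beta_j); the last weight takes the remaining stick
    beta1m_cumprod = (1 - beta).cumprod(-1)
    return F.pad(beta, (0, 1), value=1) * F.pad(beta1m_cumprod, (1, 0), value=1)

T = 50
beta = torch.distributions.Beta(1.0, 1.0).sample(torch.Size([T - 1]))
weights = mix_weights(beta)
print(weights.shape)         # torch.Size([50])
print(float(weights.sum()))  # close to 1.0, up to floating-point error
```

The sum telescopes: each weight is the fraction broken off what is left of the stick, and the padded final entry is whatever remains after T-1 breaks.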

I would really appreciate your help!

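Hi @MorSauron. What is the shape of data_class? From your code it looks like it should be (1113, 71), but it is (1113,), and that's why you get the error: the right-most dimension of the observed value must match the event shape of the Multinomial, which is (71,).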
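To see the shape rule concretely, here is a minimal sketch in plain torch. The sizes are the ones from the question, but the uniform probabilities and the total_count of 100 are made-up values, not the user's data. A Multinomial whose probs have shape (N, Nct) has batch_shape (N,) and event_shape (Nct,), so the observed counts must have shape (N, Nct): one row of Nct counts per data point.

```python
import torch
from torch.distributions import Multinomial

N, Nct = 1113, 71
probs = torch.full((N, Nct), 1.0 / Nct)  # one probability vector per data point
d = Multinomial(total_count=100, probs=probs)
print(d.batch_shape, d.event_shape)      # torch.Size([1113]) torch.Size([71])

counts = d.sample()                      # shape (1113, 71), each row sums to 100
print(d.log_prob(counts).shape)          # torch.Size([1113])

# A value of shape (1113,) does not end in the event shape (71,), so validation fails
raised = False
try:
    d.log_prob(torch.zeros(N))
except ValueError as err:
    raised = True
    print(err)
```

With argument validation enabled (the default unless Python runs with -O), the last call reproduces the "right-most size of value must match event_shape" message from the question.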

Hi @ordabayev, thanks for your help!
I reorganized my data, and it is now a matrix with 5 rows (samples) and 71 columns (categories). However, I still get an error. Could you please tell me how to solve it? I would really appreciate your help!

ValueError: Error while computing log_prob at site 'obs':
Expected value argument (Tensor of shape (5, 71)) to be within the support (Multinomial()) of the distribution Multinomial(), but found invalid values:
tensor([[  8,  17,  17,  10, 129,  64,  17,  34,   1,  12,  56,   1,  34,   1,
           4, 110,   2,   0, 364,  18,  17,   6,   0,  23,   5,   0,   1,   2,
           0,   3,   0,   1,   0,   2,   2,   0,   6,   0,   1,   0,   0,   9,
           1,   2, 109,   3,   7,   7,   1,   0,   0,   2,   0,   0,   0,   1,
           0,   1,   0,   1,   0,   0,   0,   0,   0,   0,   0,   0,   1,   0,
           0],
        [ 92, 515, 326,  66,  91,  93, 131,  11,  77,  74, 340,   6,  89,  50,
          69, 562,  42,   6,  95,  26, 207,  28,   1,  84,  57, 753, 177,  98,
           1,  27,  17, 490,  96,  43, 109,   1,   7,   0,   0,   0,   1,  10,
           0,   0,   4,   0,   2,   1,   5,   2,   0,   4,   4,   0,   2,   1,
           0,  12,   1,   7,  17,   0,   0,   3,   2,   5,   2,   0,   7,  21,
          12],
        [ 79,   1, 133,  24,  25,  90,  14,   7,  71, 121,  89,  29,   2,   3,
          28, 442,   0,   1,  20,  18, 125,   4,  42, 112,   2,   8, 254,  41,
           1,   0,   0,   9,   0,   0,  46,   1,  67,   0,  84,   3,   1,  42,
          37,   3,  71,  71,  79, 117,  84,   2,   0,   3,   0,   0,   5,  89,
           9,   3,   0,   2,   0,   0,   0,   1,   0,  31, 118,   7,  19,   1,
           0],
        [  0,   0,   1,   1,   2,   2,   1,   1,   1,   3,   7,   2, 108,   0,
           1,  17,   0,   0,  56,   9,   2,   1,   7,   4,   0,   1,  11,   5,
           0,   0,   0,   1,   1,   0, 291,   4,  10,  14,   5,   0,   2,   3,
           6,   0,  69,   0,  20, 319,  16,  67,   0,   0,   0,   0,   1,  41,
           1,   0,   0,   3,   0,   0,   0,   0,   0,   6, 128,   3,  11,   0,
           0],
        [127,  90, 402,  18,  69,  40,  27,   4, 170, 185, 107,   9,   9,  22,
         171, 304,  78,  49,  78,  57, 324, 181,  18,  71,  17, 113, 586, 193,
           1,   0,   0, 156,  35,  21, 157,   0,  55,   0,   9,   4,   5,  66,
          18,   0,  30,   2,  26,   5,  12,   5,   0,   4,   1,   1,   1,  11,
           3,  25,   7,  10,   0,   0,   0,   2,   2,   9,   3,   1,   7,   6,
           6]])
Trace Shapes:        
 Param Sites:        
Sample Sites:        
    beta dist 49 |   
        value 49 |   
     log_prob 49 |   
   alpha dist 50 | 71
        value 50 | 71
     log_prob 50 |   
       z dist  5 |   
        value  5 |   
     log_prob  5 |   
     obs dist  5 | 71
        value  5 | 71

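When log_prob is computed, the distribution validates that the value lies within its support. In particular, for Multinomial(N, alpha[z]) the counts in each row of data_class must sum to at most N: data_class.sum(-1) <= N. This condition comes from the support of the Multinomial distribution, constraints.multinomial(total_count), which requires non-negative counts whose sum does not exceed total_count.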
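The support check can be reproduced directly. This is a minimal sketch in plain torch with made-up probabilities and counts: with total_count=10, a row whose counts sum to more than 10 falls outside the support, while raising total_count to at least the largest row sum makes every row valid.

```python
import torch
from torch.distributions import Multinomial

probs = torch.tensor([0.2, 0.3, 0.5])
counts = torch.tensor([[4., 3., 2.],    # row sum 9
                       [10., 5., 5.]])  # row sum 20

small = Multinomial(total_count=10, probs=probs)
# support is constraints.multinomial(total_count): counts >= 0 and sum <= total_count
print(small.support.check(counts))  # first row valid, second row exceeds total_count

big = Multinomial(total_count=int(counts.sum(-1).max()), probs=probs)
print(big.support.check(counts))    # both rows valid
```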

Also note that for the Multinomial distribution, the value of total_count is not used in the log_prob calculation (see the torch.distributions page of the PyTorch 2.2 documentation).
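A quick way to confirm this, as a sketch in plain torch with made-up probabilities and counts: two Multinomial distributions that differ only in total_count give the same log_prob for the same counts, because the normalizing term is computed from value.sum(-1). Validation is switched off on the first one so that its too-small total_count does not trigger the support check.

```python
import torch
from torch.distributions import Multinomial

probs = torch.tensor([0.2, 0.3, 0.5])
counts = torch.tensor([10., 5., 5.])  # sums to 20

# total_count=1 is too small for these counts; validate_args=False skips the support check
lp_small = Multinomial(total_count=1, probs=probs, validate_args=False).log_prob(counts)
lp_large = Multinomial(total_count=1000, probs=probs).log_prob(counts)
print(lp_small.item(), lp_large.item())  # the two values are identical
```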

So you can simply set total_count to a large enough number (at least the largest row sum, data_class.sum(-1).max()) or pass validate_args=False to Multinomial.

Thanks a lot! My model runs perfectly now!