Parameter estimation for a categorical distribution

Hi, guys. I'm trying to use MLE to learn the parameters of a categorical distribution, i.e. the probability of each category. However, SVI runs very slowly and even fails with an OOM error.
In fact, I think SVI is not necessary, since counting the number of samples in each category directly gives the optimal solution. Is there any method in Pyro that can do this?

BTW, my dataset contains 4 million samples with 1 feature, and there are 3000 categories in total for this feature.

Hi @lyy, can you provide a simple code example for your model and inference task? I agree that SVI seems like overkill for a simple counting problem where torch.scatter_add() would suffice. If your model gets more complex and starts to warrant SVI, I would recommend data subsampling.
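
For reference, here is a minimal sketch of the counting approach in plain PyTorch. It relies on the fact that the MLE of a categorical distribution is just the per-category count divided by the total number of samples. The function name `categorical_mle` and the random `labels` tensor are hypothetical stand-ins, and I'm assuming your feature is already encoded as integer category ids in `[0, num_categories)`.

```python
import torch


def categorical_mle(labels: torch.Tensor, num_categories: int) -> torch.Tensor:
    """Closed-form MLE of categorical probabilities: count / total.

    `labels` is assumed to be a 1-D int64 tensor of category ids.
    """
    counts = torch.zeros(num_categories, dtype=torch.float64)
    ones = torch.ones(labels.shape[0], dtype=torch.float64)
    # scatter_add_ adds ones[i] into counts[labels[i]] for every sample i.
    counts.scatter_add_(0, labels, ones)
    return counts / counts.sum()


# Hypothetical stand-in data; replace with the real integer-encoded feature.
torch.manual_seed(0)
labels = torch.randint(0, 3000, (100_000,))
probs = categorical_mle(labels, 3000)

# Tiny hand-checkable case: two 0s, one 1, one 2.
small = categorical_mle(torch.tensor([0, 0, 1, 2]), 3)
```

This is a single pass over the data with no gradients involved, so 4 million samples should not be a problem memory-wise. Categories that never appear get probability 0; if that matters for you, adding a small pseudo-count before normalizing is one option.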