LDA model with continuous values

Hi, I want to change the lda.py to the situation where the observations are continuous. The changes are that I replace Catagorical distribution with Gaussian distribution. Therefore, the parameters that I have to learn is topic weights, mean of each Gaussian, and var of each Gaussian.
I have finished the code, but I find it is unstable and will lead to “nan” loss if the data is large.
Could anyone find the problem?

from __future__ import absolute_import, division, print_function
import argparse
import functools
import logging

import torch
from torch import nn
from torch.distributions import constraints

import pyro
import pyro.distributions as dist
from pyro.infer import SVI, JitTraceEnum_ELBO, TraceEnum_ELBO
from pyro.optim import Adam
import numpy as np
logging.basicConfig(format='%(relativeCreated) 9d %(message)s', level=logging.INFO)


# This is a fully generative model of a batch of documents.
# data is a [num_words_per_doc, num_documents] shaped array of word values
# that all documents have the same number of words.
def model(data=None, args=None, batch_size=None):
    # Globals.
    with pyro.plate("topics", args.num_topics):
        topic_weights = pyro.sample("topic_weights", dist.Gamma(1. / args.num_topics, 1.))
        topic_mean = pyro.sample("topic_mean",dist.Gamma(1, 1.))
        topic_varriance = pyro.sample("topic_varriance",dist.Gamma(1, 1.))
    # Locals.
    with pyro.plate("documents", args.num_docs) as ind:
        if data is not None:
            with pyro.util.ignore_jit_warnings():
                assert data.shape == (args.num_words_per_doc, args.num_docs)

            data = data[:, ind]

        doc_topics = pyro.sample("doc_topics", dist.Dirichlet(topic_weights)) #sum of doc_topics is 1
        with pyro.plate("words", args.num_words_per_doc):

            word_topics = pyro.sample("word_topics", dist.Categorical(doc_topics),infer={"enumerate": "parallel"})
            
            data = pyro.sample("doc_words", dist.Normal(topic_mean[word_topics],topic_varriance[word_topics]),
                               obs=data)
    return topic_weights, topic_mean,topic_varriance, data


# We will use amortized inference of the local topic variables, achieved by a
# multi-layer perceptron. We'll wrap the guide in an nn.Module.
def make_predictor(args):
    layer_sizes = ([args.num_words_per_doc] +
                   [int(s) for s in args.layer_sizes.split('-')] +
                   [args.num_topics])
    logging.info('Creating MLP with sizes {}'.format(layer_sizes))
    layers = []
    for in_size, out_size in zip(layer_sizes, layer_sizes[1:]):
        layer = nn.Linear(in_size, out_size)
        layer.weight.data.normal_(0, 0.001)
        layer.bias.data.normal_(0, 0.001)
        layers.append(layer)
        layers.append(nn.Sigmoid()) # cannot guareetee the sum is 1.
    return nn.Sequential(*layers)


def parametrized_guide(predictor, data, args, batch_size=None):
    # Use a conjugate guide for global variables.
    topic_weights_posterior = pyro.param(
            "topic_weights_posterior",
            lambda: torch.ones(args.num_topics) / args.num_topics,
            constraint=constraints.positive)
    topic_mean_mean_posterior = pyro.param(
            "topic_mean_mean_posterior",
            lambda: torch.ones(args.num_topics) / args.num_topics,
            constraint=constraints.positive)
    topic_varriance_mean_posterior = pyro.param(
            "topic_varriance_mean_posterior",
            lambda: torch.ones(args.num_topics) / args.num_topics,
            constraint=constraints.positive)
    with pyro.plate("topics", args.num_topics):
        tw=pyro.sample("topic_weights", dist.Gamma(topic_weights_posterior, 1.))
        pyro.sample("topic_mean", dist.Gamma(topic_mean_mean_posterior,1))
        pyro.sample("topic_varriance", dist.Gamma(topic_varriance_mean_posterior,1))
    # Use an amortized guide for local variables.
    pyro.module("predictor", predictor)
    with pyro.plate("documents", args.num_docs, batch_size) as ind:
        doc_topics = predictor(data[:, ind].transpose(0, 1))
        pyro.sample("doc_topics", dist.Delta(doc_topics, event_dim=1))

def main(args):
    logging.info('Generating data')
    pyro.set_rng_seed(0)
    # We can generate synthetic data directly by calling the model.
    true_topic_weights, true_topic_mean,true_topic_varriance, data = model(args=args)

    # We'll train using SVI.
    logging.info('-' * 40)
    logging.info('Training on {} documents'.format(args.num_docs))
    predictor = make_predictor(args)
    guide = functools.partial(parametrized_guide, predictor)
    Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
    
    elbo = Elbo(max_plate_nesting=2)
    optim = Adam({'lr': args.learning_rate})
    svi = SVI(model, guide, optim, elbo)
    logging.info('Step\tLoss')
    if args.cuda=='store_true':
        self.cuda()
    for step in range(args.num_steps):
        if args.cuda=='store_true':
           data = data.cuda()
        loss = svi.step(data, args=args, batch_size=args.batch_size)
        if step % 10 == 0:
            logging.info('{: >5d}\t{}'.format(step, loss))
    loss = elbo.loss(model, guide, data, args=args)

    logging.info('final loss = {}'.format(loss))


if __name__ == '__main__':
    assert pyro.__version__.startswith('0.3.0')
    parser = argparse.ArgumentParser(description="Amortized Latent Dirichlet Allocation")
    parser.add_argument("-t", "--num-topics", default=5, type=int)
    parser.add_argument("-w", "--num-words", default=100, type=int)
    parser.add_argument("-d", "--num-docs", default=100, type=int)
    parser.add_argument("-wd", "--num-words-per-doc", default=100, type=int)
    parser.add_argument("-n", "--num-steps", default=1000, type=int)
    parser.add_argument("-l", "--layer-sizes", default="10-10")
    parser.add_argument("-lr", "--learning-rate", default=0.00001, type=float)
    parser.add_argument("-b", "--batch-size", default=100, type=int)
    parser.add_argument('--jit', action='store_true')
    parser.add_argument('--cuda', action='store_true', default=True, help='whether to use cuda')
    args = parser.parse_args()
    main(args)