Categorical sample out-of-bounds

I’m trying to implement LDA directly and am not worried about performance at this point. With some extra prints for debugging, my model is:

import torch as t
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

def model(data):
    α = pyro.param("α", t.tensor(0.1), constraint=constraints.positive)
    β = [pyro.param(f"β_{z}", t.ones(V)/V, constraint=constraints.simplex) for z in range(K)]
    print(f"α={α}")
    for d in pyro.irange("documents", D):
        print(f"d={d}")
        θ = pyro.sample(f"θ_{d}", dist.Dirichlet(α * t.ones(K)))
        print(f"θ={θ}")
        data_d = data[d]
        for n in pyro.irange(f"loop_{d}", len(data_d)):
            print(f"n={n}")
            z = pyro.sample(f"z_{d},{n}", dist.Categorical(θ))
            print(f"z={z}")
            print(f"θ[z]={θ[z]}")
            pyro.sample(f"w_{d},{n}", dist.Categorical(β[z]), obs=data_d[n])

In the output, the Categorical seems to be sampling well outside the bounds:

α=0.09999999403953552
d=0
θ=tensor([ 0.4130,  0.2750,  0.1093,  0.1633,  0.0394])
n=0
z=4
θ[z]=0.03943357244133949
n=1
z=10

Is this a bug, or am I missing something?

EDIT:
I guess it’s not clear why this is out of bounds. In the sample z ~ Categorical(θ), θ has length 5. The first time through the n loop, we get z=4, which is fine. But the next time through we get z=10, even though θ has not changed. The subsequent call to print(f"θ[z]={θ[z]}") then fails with the out-of-bounds error.

Sorry, I fail to see how it is sampling out of bounds. From your snippet above:

θ=tensor([ 0.4130,  0.2750,  0.1093,  0.1633,  0.0394])
z=4  # sampled from {0, 1, 2, 3, 4} with the probabilities in θ above
θ[z]=0.03943357244133949  # the probability that 4 was sampled is θ[4] = 0.0394
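The reply's point can be checked without Pyro. This is a minimal stdlib sketch, assuming `random.choices` with weights as a stand-in for `dist.Categorical(θ)`: every draw from the θ printed above is an index in 0..4, and the empirical frequencies approach θ.

```python
import random
from collections import Counter

# θ as printed in the question: a distribution over K = 5 topics.
theta = [0.4130, 0.2750, 0.1093, 0.1633, 0.0394]


def sample_categorical(probs, n, rng):
    """Draw n indices from range(len(probs)) with the given probabilities."""
    return rng.choices(range(len(probs)), weights=probs, k=n)


rng = random.Random(0)
draws = sample_categorical(theta, 100_000, rng)
counts = Counter(draws)

# The support is {0, ..., len(theta) - 1}, so a value like 10 can never appear.
print(sorted(counts))
for k, p in enumerate(theta):
    print(f"z={k}: theta[z]={p:.4f}, empirical={counts[k] / len(draws):.4f}")
```

So if a z of 10 shows up, it cannot have come from this θ; it must have been produced somewhere else.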

I’ve added some detail; I hope this makes it clearer. Thanks!

Are you using SVI? If so, what does your guide look like? In SVI, the latent samples come from the guide, and the model is only used to compute log probabilities for them.
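A toy illustration of that mechanism, using only the standard library. `guide_sample_z`, `model_log_prob_z` and `replay_once` are hypothetical stand-ins for the guide's and model's `z` sites, not Pyro functions. The model never draws z itself; it only evaluates log θ[z] for whatever value the guide hands it. So if the guide's categorical covers more than K categories (15 here, chosen for illustration), the model's lookup can run off the end of θ.

```python
import math
import random

K = 5    # topics: length of θ in the model
V = 15   # a wider category count, used for the mismatched guide


def guide_sample_z(phi, rng):
    """Hypothetical guide step: draw z from Categorical(phi)."""
    return rng.choices(range(len(phi)), weights=phi, k=1)[0]


def model_log_prob_z(theta, z):
    """Hypothetical model step: score the guide's z as log θ[z].
    Raises IndexError when z is not one of the K topics."""
    return math.log(theta[z])


def replay_once(phi, theta, rng):
    z = guide_sample_z(phi, rng)           # the sample comes from the guide...
    return z, model_log_prob_z(theta, z)   # ...the model only scores it


theta = [0.4130, 0.2750, 0.1093, 0.1633, 0.0394]
rng = random.Random(0)

# Guide categorical with V entries: z ranges over 0..14, so many lookups fail.
wide_phi = [1 / V] * V
failures = 0
for _ in range(1000):
    try:
        replay_once(wide_phi, theta, rng)
    except IndexError:
        failures += 1

# Guide categorical with K entries: every z is a valid topic index.
matched_phi = [1 / K] * K
for _ in range(1000):
    z, lp = replay_once(matched_phi, theta, rng)
    assert 0 <= z < K

print(f"{failures} of 1000 replays with a length-{V} phi hit an out-of-range z")
```

With a uniform 15-way categorical, about two thirds of the draws land in 5..14, which is consistent with the z=10 in the question.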

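Oh I see, I guess the last line of my guide below is throwing things off: φ has length V = 15 rather than K = 5, so the guide can sample topic indices up to 14, and the model then evaluates θ[z] with those values. Thanks!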

D = 2    #documents
K = 5    #topics
V = 15   #vocabulary size

def guide(data):
    for d in pyro.irange("documents", D):
        γ = pyro.param(f"γ_{d}", t.ones(K), constraint=constraints.positive)
        θ = pyro.sample(f"θ_{d}", dist.Dirichlet(γ))
        data_d = data[d]
        for n in pyro.irange(f"loop_{d}", len(data_d)):
            φ = pyro.param(f"φ_{d},{n}", t.ones(V)/V, constraint=constraints.simplex)
            z = pyro.sample(f"z_{d},{n}", dist.Categorical(φ))