GMM SVI with relaxed categorical latent variables

Hi, I am new to Pyro and want to infer the parameters of a GMM using SVI with relaxed categorical latent variables for the mixture assignments.
Here is my attempt:

import numpy as np
import scipy.stats
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoDelta
from pyro.optim import Adam

assert pyro.__version__.startswith('1.9.1')

K = 4  # fixed number of components

# ground-truth parameters for the synthetic data
real_assignment = torch.randint(high=K, size=(100,))
real_scales = torch.tensor(0.1)
real_locs = torch.tensor([0., 5., 10., 20.])

# scipy returns float64, so cast to float32 to match the model's tensors
data = torch.tensor(scipy.stats.norm.rvs(loc=real_locs[real_assignment], scale=real_scales)).float()

print("real_locs", real_locs)
print("real_scales", real_scales)


def model(i):
    # i is the SVI step index; it only matters if the temperature is annealed
    temperature = 0.0001  # also tried annealing: 0.1 / np.log2(i + 0.1)

    weights = pyro.sample("weights", dist.Dirichlet(0.5 * torch.ones(K)))
    # clamp (out of place) to avoid log(0) at very low temperatures
    weights = weights.clamp(min=1e-10)

    scale = pyro.sample("scale", dist.Uniform(0.0, 1.0))

    with pyro.plate("components", K):
        locs = pyro.sample("locs", dist.Uniform(0.0, 20.0))

    # my own Gumbel-Softmax implementation (the other variant I tried):
    # with pyro.plate("data", len(data)):
    #     G = pyro.sample("assignment", dist.Gumbel(loc=torch.zeros(K), scale=torch.ones(K)).to_event(1))
    #     one_hot_assignment = ((torch.log(weights) + G) / temperature).softmax(-1)
    #     int_assignment = (one_hot_assignment * torch.arange(K).expand(len(data), -1)).sum(dim=1).int()
    #     pyro.sample("obs", dist.Normal(locs[int_assignment], scale), obs=data)

    with pyro.plate("data", len(data)):
        one_hot_assignment = pyro.sample(
            "assignment",
            dist.RelaxedOneHotCategoricalStraightThrough(
                temperature=torch.tensor(temperature), probs=weights))
        # collapse the one-hot sample to an integer component index
        int_assignment = (one_hot_assignment * torch.arange(K).expand(len(data), -1)).sum(dim=1).int()
        pyro.sample("obs", dist.Normal(locs[int_assignment], scale), obs=data)

optim = Adam({"lr": 0.1, "betas": [0.8, 0.99]})

guide = AutoDelta(
    poutine.block(model, expose=["weights", "locs", "scale", "assignment"]))

# set up the inference algorithm
svi = SVI(model, guide, optim, loss=Trace_ELBO())

n_steps = 100
# do gradient steps
for i in range(1,n_steps+1):
    svi.step(i)
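# svi.step returns the loss (a Monte Carlo estimate of the negative ELBO);
# a variant of the loop that tracks convergence could look like:
#     loss = svi.step(i)
#     if i % 10 == 0:
#         print(f"step {i}: loss = {loss:.2f}")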

# grab the learned variational parameters
map_estimates = guide(n_steps)
weights = map_estimates["weights"]
locs = map_estimates["locs"]
scale = map_estimates["scale"]
assignment = map_estimates["assignment"]

print("weights", weights)
print("locs", locs)
print("scale", scale)
print("assignment", assignment)
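# For reference, the hard labels can be compared against the ground truth
# (up to a permutation of the component labels):
print("MAP labels ", assignment.argmax(-1))
print("true labels", real_assignment)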

I tried both Pyro's implementation from this article introducing the Gumbel-Softmax distribution and my own version (the commented-out block in the model), and I also tried decreasing the temperature over the course of SVI instead of keeping it fixed.
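Concretely, the annealing variant just swapped the fixed temperature in model for the decaying schedule from the comment, so the relaxation starts soft and sharpens as training progresses:

temperature = 0.1 / np.log2(i + 0.1)  # i is the step index passed to model(i)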

This is what the results look like:

The output looks weird to me. Can someone please help me understand what is happening here?