Thank you. Here is complete code that reproduces the problem I see, adapted from the tutorial I linked above:
import jax
import numpy as np
import numpyro as pyro
import numpyro.distributions as dist
import numpyro.distributions.constraints as constraints
from jax import random
from numpyro.infer import SVI, Trace_ELBO
# Ten float observations: six 1.0 values followed by four 0.0 values.
data = np.concatenate([np.ones(6), np.zeros(4)])
def model_mle(data):
    """Bernoulli likelihood for 0/1 observations with one learnable probability.

    ``latent_fairness`` is a ``param`` site (not a sampled latent), so with a
    guide that samples nothing the ELBO reduces to the log-likelihood and SVI
    performs maximum-likelihood estimation of it.

    Args:
        data: Array of 0/1 observations, one per ``plate`` element.
    """
    # note that we need to include the interval constraint;
    # in original_model() this constraint appears implicitly in
    # the support of the Beta distribution.
    f = pyro.param("latent_fairness", np.array(0.5),
                   constraint=constraints.unit_interval)
    # A plain print() here runs while JAX traces the model (e.g. under
    # jax.value_and_grad inside svi.update), so it shows a Tracer object
    # rather than a number. jax.debug.print prints the actual runtime value.
    jax.debug.print("f: {}", f)
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Bernoulli(f), obs=data)
# Render the model graph, with distribution and param annotations enabled.
pyro.render_model(
    model_mle,
    model_args=(data,),
    render_params=True,
    render_distributions=True,
)
def guide_mle(data):
    """Empty guide: the model has no unobserved sample sites to approximate.

    Because nothing is sampled here, SVI optimizes only the ``param`` sites
    declared in the model.
    """
def train(model, guide, lr=0.005, n_steps=201, data=None, seed=0):
    """Fit ``model``/``guide`` with NumPyro SVI and return the learned params.

    Args:
        model: NumPyro model callable taking the observations as its only
            positional argument.
        guide: Guide callable with the same signature as ``model``.
        lr: Adam step size.
        n_steps: Number of SVI update steps to run.
        data: Observations passed to ``model`` and ``guide``. Defaults to the
            module-level ``data`` array (what the original always used).
        seed: Seed for the PRNG key used to initialize SVI.

    Returns:
        dict mapping each param name (e.g. ``"latent_fairness"``) to its
        constrained value after the final step.
    """
    if data is None:
        # Backward-compatible default: the original implicitly read the
        # global `data`.
        data = globals()["data"]
    # NumPyro keeps no global param store (unlike Pyro's clear_param_store());
    # all optimizer/param state lives in the SVI state threaded below.
    # numpyro.optim.Adam takes the step size directly. Pyro's {"lr": ...}
    # dict would be treated as the step size and fail with a TypeError.
    optimizer = pyro.optim.Adam(step_size=lr)
    svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(seed), data)
    for step in range(n_steps):
        svi_state, loss = svi.update(svi_state, data)
        if step % 50 == 0:
            print(f"[iter {step}] loss: {loss:.4f}")
    # Return the learned values; the original discarded svi_state.
    return svi.get_params(svi_state)
# Run maximum-likelihood training with the default learning rate and step count.
train(model=model_mle, guide=guide_mle)
Edit: fixed a formatting mistake.