In Pyro, it is possible to use a model containing parameters, a null guide, and SVI to compute the MLE for a given model and data, per this tutorial. When I try the same approach in NumPyro, I get an error: TypeError: unsupported operand type(s) for *: 'dict' and 'DeviceArray'. It seems to happen because an assumption in NumPyro's SVI class is violated by a model containing parameters, whereas in Pyro this works.
Is there a straightforward way to use NumPyro to compute the MLE (and MAP), or is this unsupported?
Thank you. Here is complete code to reproduce the problem, adapted from the tutorial I linked above.
import numpyro.distributions.constraints as constraints
import numpyro as pyro
import numpy as np
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from jax import random

data = np.zeros(10)
data[0:6] = 1.0

def model_mle(data):
    # note that we need to include the interval constraint;
    # in original_model() this constraint appears implicitly in
    # the support of the Beta distribution.
    f = pyro.param("latent_fairness", np.array(0.5),
                   constraint=constraints.unit_interval)
    print("f:", f)
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Bernoulli(f), obs=data)

pyro.render_model(model_mle, model_args=(data,), render_distributions=True, render_params=True)

def guide_mle(data):
    pass

def train(model, guide, lr=0.005, n_steps=201):
    #pyro.clear_param_store()
    adam_params = {"lr": lr}
    adam = pyro.optim.Adam(adam_params)
    svi = SVI(model, guide, adam, loss=Trace_ELBO())
    rng_key = random.PRNGKey(0)
    svi_state = svi.init(rng_key, data)
    for step in range(n_steps):
        svi_state, loss = svi.update(svi_state, data)
        if step % 50 == 0:
            print('[iter {}] loss: {:.4f}'.format(step, loss))

train(model_mle, guide_mle)
I would recommend using optax (you can search for some examples here and there in NumPyro) rather than numpyro.optim, which depends on the old JAX optimizers.