MLE/MAP in numpyro

In pyro, you can compute the MLE for a given model and data by using a model that contains parameters, a null (empty) guide, and SVI, as in this tutorial. When I tried the same approach in numpyro, I got an error: TypeError: unsupported operand type(s) for *: ‘dict’ and ‘DeviceArray’. It seems to happen because a model containing parameters violates an assumption in numpyro’s SVI class, whereas in pyro this works.

Is there a straightforward way to use numpyro to compute the MLE (and MAP), or is this unsupported?

I think your error may be unrelated to MLE/MAP. I suggest you post a code snippet.

For example, here’s an example of a MAP estimate.

Thank you. Here is complete code that reproduces the problem I see, adapted from the tutorial I linked above.

```python
import numpyro.distributions.constraints as constraints
import numpyro as pyro
import numpy as np
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from jax import random

# coin-flip data: 6 heads followed by 4 tails
data = np.zeros(10)
data[0:6] = 1.0

def model_mle(data):
    # note that we need to include the interval constraint;
    # in original_model() this constraint appears implicitly in
    # the support of the Beta distribution.
    f = pyro.param("latent_fairness", np.array(0.5),
                   constraint=constraints.unit_interval)
    print("f:", f)
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Bernoulli(f), obs=data)

# optional: draws the model graph (requires graphviz)
pyro.render_model(model_mle, model_args=(data,), render_distributions=True, render_params=True)

def guide_mle(data):
    # no latent sample sites, so the guide is empty
    pass

def train(model, guide, lr=0.005, n_steps=201):
    #pyro.clear_param_store()

    adam_params = {"lr": lr}
    # source of the TypeError: numpyro's Adam expects a step size,
    # not a pyro-style dict of optimizer arguments
    adam = pyro.optim.Adam(adam_params)
    svi = SVI(model, guide, adam, loss=Trace_ELBO())
    rng_key = random.PRNGKey(0)
    svi_state = svi.init(rng_key, data)

    for step in range(n_steps):
        svi_state, loss = svi.update(svi_state, data)
        if step % 50 == 0:
            print('[iter {}]  loss: {:.4f}'.format(step, loss))

train(model_mle, guide_mle)
```


PS: I installed the default numpyro using pip

Maybe you need to pass the learning rate directly (e.g. lr=0.1) rather than a dict? I’m not sure.

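Wow, I wish I had asked earlier. For the benefit of anyone reading this, the required change is below. numpyro.optim.Adam takes the step size itself as its first argument (step_size), whereas pyro.optim.Adam takes a dict of optimizer arguments. The dict therefore ended up in a multiplication with an array, which is what raised the TypeError.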

```python
...
def train(model, guide, lr=0.1, n_steps=201):
    # was: pyro.optim.Adam({"lr": lr}); numpyro's Adam takes the step size directly
    adam = pyro.optim.Adam(step_size=lr)
    svi = SVI(model, guide, adam, loss=Trace_ELBO())
    ...
```

I would recommend using optax rather than numpyro.optim, which is built on the older jax optimizers (jax.example_libraries.optimizers). You can find optax usage here and there in the numpyro examples.