Wrapping a model (and possibly other inference tools, such as a guide) in a class object

I am trying to write a class to do simple linear regression with NumPyro. I think I am missing a constructor or some obvious piece of Python syntax. The code below throws the following AssertionError:

AssertionError                            Traceback (most recent call last)
/tmp/ipykernel_7172/331421733.py in <module>
      1 # SVI for model without subsample
      2 model_ini = NumpyroModel(loc=10)
----> 3 guide = autoguide.AutoNormal(model_ini.model(), init_loc_fn=init_to_sample)
      4 optimizer = numpyro.optim.Adam(step_size=1e-3)
      5 svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

/tmp/ipykernel_7172/1090656190.py in model(self, x, y)
      5     def model(self, x=None, y=None):
      6         loc = self.loc
----> 7         sigma = numpyro.sample("sigma", dist.HalfCauchy(scale=10))
      8         intercept = numpyro.sample("Intercept", dist.Normal(loc=0, scale=20))
      9         w = numpyro.sample("w", dist.Normal(loc=0, scale=20))


    176     # if there are no active Messengers, we just draw a sample and return it as expected:
    177     if not _PYRO_STACK:
--> 178         return fn(rng_key=rng_key, sample_shape=sample_shape)
    179 
    180     if obs_mask is not None:

    302         if sample_intermediates:
    303             return self.sample_with_intermediates(key, *args, **kwargs)
--> 304         return self.sample(key, *args, **kwargs)
    305 
    306     def to_event(self, reinterpreted_batch_ndims=None):

    369 
    370     def sample(self, key, sample_shape=()):
--> 371         assert is_prng_key(key)
    372         return jnp.abs(self._cauchy.sample(key, sample_shape))

Here is my code. Thanks in advance.

import numpy as np
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive, SVI, autoguide, init_to_sample, NUTS, MCMC
from numpyro.infer import Trace_ELBO
numpyro.set_platform('cpu')
# Give every consumer of randomness its own key: JAX PRNG keys must never be
# reused, and `rng_key` is also passed to `svi.run` further down.
rng_key, rng_key_predict, rng_key_data = random.split(random.PRNGKey(0), 3)

# Synthetic data: y = true_intercept + true_slope * x + standard-normal noise.
size = 200
true_intercept = 1
true_slope = 2

# Column vector of inputs, shape (size, 1).
x = jnp.linspace(0, 1, size).reshape(-1, 1)
# y = intercept + w*x
true_regression_line = true_intercept + true_slope * x
# add noise, drawn with a dedicated key so it is independent of inference
y = true_regression_line + random.normal(rng_key_data, shape=(size, 1))

class NumpyroModel(object):
    """Bayesian simple linear regression ``y ~ Normal(intercept + w * x, sigma)``.

    Wrapping the model in a class lets configuration (here ``loc``) live on
    the instance. Hand the *bound method* ``instance.model`` (no parentheses)
    to guides and to ``SVI``. Calling ``instance.model()`` directly runs
    ``numpyro.sample`` with no active handler and no PRNG key, which fails
    with an ``is_prng_key`` AssertionError.
    """

    def __init__(self, loc):
        # NOTE(review): ``loc`` is stored but not read by ``model``;
        # presumably intended as a prior location - confirm before relying on it.
        self.loc = loc

    def model(self, x=None, y=None):
        """NumPyro model function for the regression.

        Args:
            x: inputs; the data plate size is taken from ``x.shape[0]``
                (this script builds x with shape (N, 1)).
            y: observed targets shaped like ``intercept + w * x``, or ``None``
                so that ``obs`` is sampled instead of conditioned on.

        Raises:
            ValueError: if ``x`` is None, because the plate size depends on it.
        """
        if x is None:
            raise ValueError("NumpyroModel.model requires `x` to size the data plate")
        sigma = numpyro.sample("sigma", dist.HalfCauchy(scale=10))
        intercept = numpyro.sample("Intercept", dist.Normal(loc=0, scale=20))
        w = numpyro.sample("w", dist.Normal(loc=0, scale=20))
        with numpyro.plate("data", x.shape[0]):
            # to_event(1) reinterprets the trailing axis of each row as an event dim
            numpyro.sample(
                "obs",
                dist.Normal(intercept + w * x, scale=sigma).to_event(1),
                obs=y,
            )

model_instance = NumpyroModel(loc=10)
# Pass the bound method itself. `model_instance.model()` would execute the
# model immediately, outside any NumPyro handler, where `numpyro.sample` has
# no PRNG key and raises an AssertionError from `is_prng_key`.
guide = autoguide.AutoNormal(model_instance.model, init_loc_fn=init_to_sample)
optimizer = numpyro.optim.Adam(step_size=1e-3)
# SVI must receive the same model callable the guide was built from
# (the bare name `model` is not defined at module level).
svi = SVI(model_instance.model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(rng_key, 10000, x, y)

should be

# Pass the bound method (no parentheses): calling `model()` here runs the model
# immediately, outside any NumPyro handler, which raised the AssertionError.
guide = autoguide.AutoNormal(model_instance.model, init_loc_fn=init_to_sample)

(there may be other issues)

1 Like

Thanks, your solution did the trick for now.