I am trying to write a class to do simple linear regression with numpyro
. I think I am missing some constructor or an obvious python syntax. The code attached below through the following Assertion Error:
AssertionError Traceback (most recent call last)
/tmp/ipykernel_7172/331421733.py in <module>
1 # SVI for model without subsample
2 model_ini = NumpyroModel(loc=10)
----> 3 guide = autoguide.AutoNormal(model_ini.model(), init_loc_fn=init_to_sample)
4 optimizer = numpyro.optim.Adam(step_size=1e-3)
5 svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
/tmp/ipykernel_7172/1090656190.py in model(self, x, y)
5 def model(self, x=None, y=None):
6 loc = self.loc
----> 7 sigma = numpyro.sample("sigma", dist.HalfCauchy(scale=10))
8 intercept = numpyro.sample("Intercept", dist.Normal(loc=0, scale=20))
9 w = numpyro.sample("w", dist.Normal(loc=0, scale=20))
176 # if there are no active Messengers, we just draw a sample and return it as expected:
177 if not _PYRO_STACK:
--> 178 return fn(rng_key=rng_key, sample_shape=sample_shape)
179
180 if obs_mask is not None:
302 if sample_intermediates:
303 return self.sample_with_intermediates(key, *args, **kwargs)
--> 304 return self.sample(key, *args, **kwargs)
305
306 def to_event(self, reinterpreted_batch_ndims=None):
369
370 def sample(self, key, sample_shape=()):
--> 371 assert is_prng_key(key)
372 return jnp.abs(self._cauchy.sample(key, sample_shape))
Here is my code. Thanks in advance.
import numpy as np
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Predictive, SVI, autoguide, init_to_sample, NUTS, MCMC
from numpyro.infer import Trace_ELBO
numpyro.set_platform('cpu')
rng_key, rng_key_predict = random.split(random.PRNGKey(0))
size = 200
true_intercept = 1
true_slope = 2
x = jnp.linspace(0, 1, size).reshape(-1,1)
# y = intercept + w*x
true_regression_line = true_intercept + true_slope * x
# add noise
y = true_regression_line + random.normal(rng_key, shape=(size,1))
class NumpyroModel(object):
def __init__(self, loc):
self.loc = loc
def model(self, x=None, y=None):
loc = self.loc
sigma = numpyro.sample("sigma", dist.HalfCauchy(scale=10))
intercept = numpyro.sample("Intercept", dist.Normal(loc=0, scale=20))
w = numpyro.sample("w", dist.Normal(loc=0, scale=20))
with numpyro.plate("data", x.shape[0]):
numpyro.sample("obs", dist.Normal(intercept+w*x, scale=sigma).to_event(1), obs=y)
model_instance = NumpyroModel(loc=10)
guide = autoguide.AutoNormal(model_instance.model(), init_loc_fn=init_to_sample)
optimizer = numpyro.optim.Adam(step_size=1e-3)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
svi_result = svi.run(rng_key, 10000, x, y)