NaN error using numpyro.factor

Hi,

I am using numpyro.factor in my guide as follows:

def guide(data):
      x_loc = numpyro.param("x_loc",lambda rng_key: jax.random.uniform(rng_key, shape=(30,)))

      numpyro.factor("penalty_1", jax.numpy.mean(x_loc ** (2)))
      numpyro.factor("penalty_2", repulsion_regularizer(x_loc))

      x_scale = numpyro.param("x_scale", jax.numpy.ones((30,)))
      x = numpyro.sample("x", dist.Normal(x_loc, x_scale).to_event(1))

When doing SVI, I get NaN’s for x_loc and x_scale if I include penalty_2 in the guide above.

I define repulsion_regularizer as follows:

def repulsion_regularizer(x):
      x = x.reshape(-1, 3)
      dist_mat = pdist_squareform(x=x, y=x)
      return dist_mat.sum()

with,

def euclidean_distance(x: np.array, y: np.array) -> float:
    """Taken from `jax-kern` module"""
    return np.sqrt(np.sum((x - y) ** 2))

def distmat(func: Callable, x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Taken from `jax-kern` module"""
    return jax.vmap(lambda x1: jax.vmap(lambda y1: func(x1, y1))(y))(x)

def pdist_squareform(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Taken from `jax-kern` module"""
    return distmat(euclidean_distance, x, y)

I need some help figuring out what I am doing wrong here.

Thank you,
Atharva

for one thing you shouldn’t be using numpy ops in your guide (np → jnp).

possibly your sqrt needs to be regularized or clamped to be positive or the like, e.g. sqrt(... + 1.0e-8)

1 Like

Adding 1.0e-8 worked. Thanks!