Pyro -> numpyro for gaussian inference

I’m transitioning some code from pyro to numpyro and I’ve run into an issue with the sqrt function. Below are two simple models that infer the parameters of a normal distribution from some data; both are fit using MCMC. The pyro code (modelNormal) works perfectly, but the numpyro code (modelNormalNP) gives me an error on the last line, at torch.sqrt(sigma). If I remove the sqrt, the numpyro code runs without issue. I tried numpy.sqrt and still got an error:

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part

Any help is welcome :)

 import pyro
 import typing
 import numpy
 import numpyro
 import numpyro.distributions as dist
 import pyro.distributions as dist_pyro
 import torch

 # numpyro version: fails on the torch.sqrt call in the last line
 def modelNormalNP(goal, params, validate_args=True):
     sigma = numpyro.sample("sigma", dist.InverseGamma(0.1, 10, validate_args=validate_args))
     sigma += 1e-10
     mu = numpyro.sample("mu", dist.Normal(0.0, 1.0, validate_args=validate_args))
     with numpyro.plate("samples", len(goal)):
         numpyro.sample("goal", dist.Normal(mu, torch.sqrt(sigma), validate_args=validate_args), obs=goal)

 # pyro version: works as expected
 def modelNormal(goal, params, validate_args=True):
     sigma = pyro.sample("sigma", dist_pyro.InverseGamma(0.1, 10, validate_args=validate_args))
     sigma += 1e-10
     mu = pyro.sample("mu", dist_pyro.Normal(0.0, 1.0, validate_args=validate_args))
     with pyro.plate("samples", len(goal)):
         pyro.sample("goal", dist_pyro.Normal(mu, torch.sqrt(sigma), validate_args=validate_args), obs=goal)

You can’t arbitrarily mix pyro and numpyro. Pyro uses pytorch and numpyro uses jax, and jax arrays and pytorch tensors are not interchangeable. You need to choose between {pyro, pytorch} and {numpyro, jax}.

@martinjankowiak Thanks for the reply. I had understood the non-mixing part: each of the two models calls only pyro or only numpyro. What I haven’t understood is how to call the sqrt function on an inferred variable in numpyro. In pyro it works without issue.


See the jax docs: jax.numpy provides the array functions (sqrt included) to use inside a numpyro model.
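As a rough sketch of what that means in practice (assuming only that jax is installed; the helper name scale_from_variance is made up for illustration): numpyro runs the model through jax transformations such as jit and grad, so any math applied to a sampled value has to be a jax.numpy function that can handle traced values.

```python
import jax
import jax.numpy as jnp


def scale_from_variance(variance):
    # Hypothetical helper: convert a variance to a standard deviation
    # with the JAX version of sqrt, which accepts traced values.
    return jnp.sqrt(variance)


# Works on a concrete value ...
sd = scale_from_variance(4.0)

# ... under jit, where `variance` is an abstract tracer ...
sd_jit = jax.jit(scale_from_variance)(4.0)

# ... and under grad, which gradient-based samplers such as NUTS rely on:
# d/dv sqrt(v) = 1 / (2 * sqrt(v)), so the derivative at v = 4 is 0.25.
dsd = jax.grad(scale_from_variance)(4.0)

print(float(sd), float(sd_jit), float(dsd))  # 2.0 2.0 0.25
```

A torch function cannot take part in this tracing, which is why swapping the sqrt call matters even though the rest of the model only uses numpyro.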

Your problem is in the last line of the numpyro model, where torch.sqrt is applied to a value sampled by numpyro.

When using numpyro, you need to use the jax.numpy equivalents instead:

import jax.numpy as jnp

         numpyro.sample("goal", dist.Normal(mu, jnp.sqrt(sigma), validate_args=validate_args), obs=goal)

thx works!