I’m porting some code from Pyro to NumPyro and have run into an issue with the sqrt function. Below are two simple models that infer the parameters of a normal distribution from some data; both are fit using MCMC. The Pyro model (modelNormal) works perfectly, but the NumPyro model (modelNormalNP) raises an error at the `torch.sqrt(sigma)` call on its last line. If I remove the sqrt, the NumPyro code runs without issue. I also tried `numpy.sqrt` and still got an error:
ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part
Any help is welcome.
import typing

import jax.numpy as jnp
import numpy
import numpyro
import numpyro.distributions as dist
import pyro
import pyro.distributions as dist_pyro
import torch
def modelNormalNP(goal, params, validate_args=True):
    """NumPyro model: Normal likelihood with sampled mean and variance.

    ``sigma`` is drawn from InverseGamma(0.1, 10) and is used as the
    *variance* of the likelihood (its square root is passed as the Normal
    scale). ``mu`` is drawn from a standard Normal prior.

    Args:
        goal: Observed data; conditioned on via the ``"goal"`` sample site
            inside a plate of size ``len(goal)``.
        params: Unused in this model; kept for signature parity with
            ``modelNormal``.
        validate_args: Forwarded to every distribution constructor.
    """
    sigma = numpyro.sample(
        "sigma", dist.InverseGamma(0.1, 10, validate_args=validate_args)
    )
    # Small jitter so the scale below never collapses to exactly zero.
    sigma = sigma + 1e-10
    mu = numpyro.sample("mu", dist.Normal(0.0, 1.0, validate_args=validate_args))
    with numpyro.plate("samples", len(goal)):
        # NumPyro samples are JAX arrays, and they are tracers under MCMC/jit.
        # Use jax.numpy here: torch.sqrt rejects JAX arrays, and numpy.sqrt
        # tries to turn the tracer into a concrete NumPy array, which fails.
        numpyro.sample(
            "goal",
            dist.Normal(mu, jnp.sqrt(sigma), validate_args=validate_args),
            obs=goal,
        )
def modelNormal(goal, params, validate_args=True):
    """Pyro model: Normal likelihood with sampled mean and variance.

    ``sigma`` is drawn from InverseGamma(0.1, 10) and is used as the
    *variance* of the likelihood (its square root is passed as the Normal
    scale). ``mu`` is drawn from a standard Normal prior.

    Args:
        goal: Observed data; conditioned on via the ``"goal"`` sample site
            inside a plate of size ``len(goal)``.
        params: Unused in this model; kept for signature parity with
            ``modelNormalNP``.
        validate_args: Forwarded to every distribution constructor.
    """
    sigma = pyro.sample(
        "sigma", dist_pyro.InverseGamma(0.1, 10, validate_args=validate_args)
    )
    # Small jitter so the scale below never collapses to exactly zero.
    sigma = sigma + 1e-10
    mu = pyro.sample("mu", dist_pyro.Normal(0.0, 1.0, validate_args=validate_args))
    with pyro.plate("samples", len(goal)):
        # Pyro samples are torch tensors, so torch.sqrt is the right op here.
        pyro.sample(
            "goal",
            dist_pyro.Normal(mu, torch.sqrt(sigma), validate_args=validate_args),
            obs=goal,
        )