Creating a basic multivariable linear regression, why am I not getting the correct coefficients?

I'm learning to use Pyro. I've created a multivariable linear regression model, generated data from known coefficients, and checked whether the model recovers them. With my current model, the posterior means of the coefficients are not close to the values used to generate the data. Small tweaks to the prior distributions and their parameters don't seem to help.

``````
import arviz as az
import numpy as np
import pyro
import pyro.distributions as dist
from pyro import infer
import torch

# Make my data
rng = np.random.default_rng(4)
n_columns = 3
n_rows = 1_000
x = rng.random((n_rows, n_columns))
coefficients_array = rng.integers(low=-3, high=3, size=n_columns)
# array([-1,  2, -3])
e = rng.normal(size=n_rows)
y = x @ coefficients_array + e

def model() -> None:
    n_instances = n_rows
    n_features = n_columns + 1  # one extra column for the intercept
    coefficients = pyro.sample(
        "coefficients",
        dist.Normal(torch.zeros(n_features), 3 * torch.ones(n_features)).to_event(1),
    )
    # Append a column of ones so the last coefficient acts as the intercept
    feature_tensor = torch.from_numpy(
        np.hstack((x, np.ones((n_rows, 1)))).astype(np.float32)
    )
    mean_target = feature_tensor @ coefficients
    target_sd = pyro.sample(
        "e",
        dist.HalfCauchy(2),
    )
    target_tensor = torch.from_numpy(
        y.astype(np.float32)
    )

    with pyro.plate("data", n_instances):
        pyro.sample(
            "y",
            dist.Normal(mean_target, target_sd),
            obs=target_tensor,
        )

pyro.set_rng_seed(42)
nuts_kernel = infer.NUTS(
    model,
    jit_compile=True,
    max_tree_depth=5,
)
mcmc = infer.MCMC(
    nuts_kernel,
    num_samples=4_000,
    warmup_steps=1_000,
    num_chains=1,
)
mcmc.run()
posterior_samples = mcmc.get_samples()
inference_data = az.from_pyro(mcmc)
inference_data["posterior"].mean(dim=["chain", "draw"])
# Data variables:
# coefficients (coefficients_dim_0) float32  -0.9336 0.5098 -0.9214 -0.2346
# These don't match! expected -1, 2, -3, and 0

``````
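
As a sanity check that does not depend on MCMC at all, an ordinary least-squares fit on data generated the same way should land near the true coefficients. The sketch below is only an illustration: the seed is an arbitrary assumption (not the one used above), the coefficients are hardcoded from the comment in the post, and the names `design` and `ols_estimate` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)  # assumed seed, not the one from the post
n_rows, n_columns = 1_000, 3
true_coefficients = np.array([-1.0, 2.0, -3.0])  # values reported in the post

# Same generating process as above: uniform features plus unit-variance noise.
x = rng.random((n_rows, n_columns))
y = x @ true_coefficients + rng.normal(size=n_rows)

# Append a column of ones so the last coefficient is the intercept.
design = np.hstack((x, np.ones((n_rows, 1))))
ols_estimate, *_ = np.linalg.lstsq(design, y, rcond=None)
print(ols_estimate)
```

If the least-squares estimates sit close to -1, 2, -3, and 0 while the posterior means do not, the data and the model structure are probably fine and the sampler settings are the more likely culprit.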

OK, this seems to be a NUTS configuration issue on my end. If I remove both

``````
    adapt_step_size=False,
``````

from `NUTS`, then I get closer to the expected means.
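
One plausible reading of why this helps: with step-size adaptation turned off, the sampler keeps its initial step size (believed to default to 1 in Pyro's `NUTS`, which is worth confirming in the docs), while 1,000 observations give a posterior whose standard deviation is on the order of 0.1. The toy leapfrog integrator below is plain Python, not Pyro's implementation; the function name, the 1-D Gaussian target, and the specific numbers are all illustrative assumptions.

```python
def leapfrog_energy_error(step_size: float, n_steps: int, target_sd: float) -> float:
    """Absolute change in total energy after leapfrog integration.

    Toy 1-D target Normal(0, target_sd) with unit mass; illustrative only.
    """
    def grad_potential(q: float) -> float:
        return q / target_sd**2

    def energy(q: float, p: float) -> float:
        return 0.5 * (q / target_sd) ** 2 + 0.5 * p**2

    q, p = target_sd, 1.0  # start one standard deviation from the mode
    initial_energy = energy(q, p)

    p -= 0.5 * step_size * grad_potential(q)  # initial half step for momentum
    for step in range(n_steps):
        q += step_size * p  # full step for position
        # Full momentum step, except a half step at the very end.
        scale = 1.0 if step < n_steps - 1 else 0.5
        p -= scale * step_size * grad_potential(q)
    return abs(energy(q, p) - initial_energy)


posterior_sd = 0.1  # assumed scale, not measured from the model above
error_fixed = leapfrog_energy_error(step_size=1.0, n_steps=10, target_sd=posterior_sd)
error_small = leapfrog_energy_error(step_size=0.05, n_steps=10, target_sd=posterior_sd)
print(f"step_size=1.0:  energy error {error_fixed:.3g}")
print(f"step_size=0.05: energy error {error_small:.3g}")
```

A very large energy error means nearly every proposed trajectory is rejected or flagged as divergent, which would be consistent with the off-target means. Leaving adaptation on lets warmup shrink the step size until the trajectories are stable.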