Hey everyone,
I hope you’re doing well.
I will be concise in this message.
I’m trying to use Pyro for data assimilation. Originally, I have a GPR model trained on CFD data using the openmeasure library. I then want to use Pyro to optimise my GPR parameters, with experimental data as the observations. I use 8 experimental images. In the Pyro model, I call the GPR to predict my field and then compare it with the experimental data.
However, I don’t know whether Pyro is doing what I intend it to do.
Can you please help me?
The code is as follows:
def gpr_predict_on_gpr_grid(T, XN2, Cmix, C2epsilon):
    """Return the reconstructed OH* field on the native GPR grid as a torch tensor.

    Args:
        T, XN2, Cmix, C2epsilon: scalar inputs, packed in this order into a
            single-row condition array for ``gpr.predict``.

    Returns:
        A ``torch.float32`` tensor on ``device`` holding the flattened output
        of ``gpr.reconstruct``.

    Note:
        The whole computation runs in NumPy, so the returned tensor carries no
        autograd history with respect to the inputs.
    """
    # Condition as NumPy (the GPR expects NumPy), reshaped to a single row.
    condition = np.array([T, XN2, Cmix, C2epsilon]).reshape(1, -1)
    # GPR prediction (NumPy); the second return value is not used here.
    modal_pred, _ = gpr.predict(condition)
    # Reconstruction (NumPy), flattened.
    spatial_field = gpr.reconstruct(modal_pred).flatten()
    # Convert to torch for the Pyro model.
    return torch.as_tensor(spatial_field, dtype=torch.float32, device=device)
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def pyro_model_full(data_list):
    """Pyro model comparing GPR-predicted fields with experimental images.

    Samples five global latent sites ("sigma", "alpha", "beta", "Cmix",
    "C2epsilon") from uniform priors. For every entry of ``data_list`` it
    evaluates the GPR at ``(T, XN2, Cmix, C2epsilon)``, interpolates the
    result onto the experimental grid, applies the affine map
    ``alpha * OH + beta`` and scores the experimental intensities under a
    Normal likelihood with scale ``sigma``.

    Args:
        data_list: sequence of mappings, each providing the keys ``"T"``,
            ``"XN2"``, ``"X_exp"``, ``"Y_exp"`` and ``"I_exp"``. All
            ``"I_exp"`` arrays must flatten to the same length, because
            they are stacked into one 2-D tensor.

    Raises:
        ValueError: from ``np.stack`` when ``data_list`` is empty or the
            flattened images differ in length.
    """
    n_conditions = len(data_list)

    # Global parameters.
    sigma = pyro.sample("sigma", dist.Uniform(0.1, 5.0))
    alpha = pyro.sample("alpha", dist.Uniform(1e10, 1e12))
    # An alternative prior Uniform(0.3, 3.0) for beta was tried together with
    # the (disabled) power-law mapping mentioned further down.
    beta = pyro.sample("beta", dist.Uniform(-5.0, 5.0))
    Cmix = pyro.sample("Cmix", dist.Uniform(0.01, 0.6))
    C2epsilon = pyro.sample("C2epsilon", dist.Uniform(1.5, 3.0))

    # NOTE(review): `.detach().item()` turns the sampled values into plain
    # Python floats, and the GPR prediction plus the `griddata` interpolation
    # run in NumPy/SciPy, so autograd cannot propagate the likelihood back to
    # Cmix / C2epsilon. With a gradient-based ELBO optimiser this likely means
    # the data never informs these two sites -- confirm, and consider a
    # torch-native (differentiable) surrogate for the GPR and interpolation.
    Cmix_val = Cmix.detach().item()
    C2epsilon_val = C2epsilon.detach().item()

    # ---------------------------------------------------------------
    # Build experimental and GPR intensity lists in a single loop
    # ---------------------------------------------------------------
    I_exp_list = []
    OH_list = []
    for data in data_list:
        T = float(data["T"])
        XN2 = float(data["XN2"])
        X_exp = np.asarray(data["X_exp"], float)
        Y_exp = np.asarray(data["Y_exp"], float)
        I_exp = np.asarray(data["I_exp"], float)
        I_exp_list.append(I_exp.ravel())

        # Raw GPR output (small magnitude), returned as a tensor on `device`.
        OH_gpr = gpr_predict_on_gpr_grid(T, XN2, Cmix_val, C2epsilon_val)
        # SciPy needs a host NumPy array; a CUDA tensor cannot be converted
        # implicitly, so move it to the CPU explicitly first.
        OH_gpr_np = OH_gpr.detach().cpu().numpy()

        # Interpolate the raw values onto the experimental grid.
        # NOTE(review): the query points are passed as (Y, X), so column 0 of
        # `xyzi_pos` must hold the same coordinate as `Y_exp` -- confirm.
        Yg, Xg = np.meshgrid(Y_exp, X_exp, indexing="ij")
        OH_grid = griddata(
            (xyzi_pos[:, 0], xyzi_pos[:, 1]),
            OH_gpr_np,
            (Yg, Xg),
            method="linear",
            fill_value=0.0,  # points outside the GPR grid's convex hull
        )
        OH_list.append(OH_grid.ravel())

    # ---------------------------------------------------------------
    # Vectorize into tensors: one row per condition
    # ---------------------------------------------------------------
    I_exp_tensor = torch.tensor(
        np.stack(I_exp_list), dtype=torch.float32, device=device
    )
    OH_tensor = torch.tensor(
        np.stack(OH_list), dtype=torch.float32, device=device
    )

    # Keep the likelihood arithmetic on a single device when `device` is CUDA
    # (no-ops when the sampled values already live there).
    sigma_d = sigma.to(device)
    alpha_d = alpha.to(device)
    beta_d = beta.to(device)

    # Predicted intensity (vectorized affine mapping).
    # A power-law alternative, alpha * (clamp(OH, min=0) + eps) ** beta followed
    # by nan_to_num, was tried and is currently disabled.
    I_pred = alpha_d * OH_tensor + beta_d

    # ---------------------------------------------------------------
    # Likelihood across conditions
    # ---------------------------------------------------------------
    # `to_event(1)` treats the flattened pixels of one image as a single
    # event; the plate indexes the conditions (rows).
    with pyro.plate("conditions", n_conditions):
        pyro.sample(
            "obs",
            dist.Normal(I_pred, sigma_d).to_event(1),
            obs=I_exp_tensor,
        )
print("✓ Model ready.")

# Start from an empty parameter store so earlier runs do not leak in.
pyro.clear_param_store()

# Uncomment to make runs reproducible:
# pyro.set_rng_seed(42)
# torch.manual_seed(42)
# np.random.seed(42)

# Full-covariance Gaussian guide over all latent sites of the model.
# Other guides tried: AutoNormal, AutoDiagonalNormal.
guide = AutoMultivariateNormal(pyro_model_full)

# Only one optimizer is used; the earlier Adam({"lr": 1e-3}) instance was
# overwritten before use and has been dropped.
# Another option tried: AdagradRMSProp({"lr": 1e-3, "eta": 1.0, "tolerance_grad": 1e-8}).
optimizer = ClippedAdam({"lr": 5e-4, "clip_norm": 5.0})

svi = SVI(pyro_model_full, guide, optimizer, loss=Trace_ELBO())

num_steps = 8000
losses = []
print(f"Running {num_steps} iterations on {device}...")
for step in range(num_steps):
    # `svi.step` returns the loss estimate for this step (the quantity being
    # minimised), which is what gets logged and plotted below.
    loss = svi.step(all_exp_data)
    losses.append(loss)
    if step % 500 == 0:
        print(f" Step {step:>4}: ELBO = {loss:>12.4f}")
    # Debug aid: after a few steps, list(pyro.get_param_store().keys()) shows
    # which guide parameters were registered.
print("✓ Inference complete!")

# NOTE(review): a log y-axis cannot display non-positive loss values --
# confirm the loss stays positive for this model.
plt.figure()
plt.plot(losses)
plt.title("SVI convergence")
plt.yscale("log")
plt.xlabel("iter")
plt.ylabel("ELBO loss")
plt.show()