Using Pyro for data assimilation in combustion

Hey everyone,

I hope you’re doing well. I’ll keep this message concise.

I’m trying to use Pyro for data assimilation. I start from a GPR model trained on CFD data with the openmeasure library. I then want to use Pyro to optimise the GPR parameters, with experimental data as the observations; I have 8 experimental images. In the Pyro model I call the GPR to predict my field and then compare the prediction with the experimental data.

However, I’m not sure whether Pyro is actually doing what I intend.

Can you please help me?
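
For context, each element of all_exp_data (one dict per experimental image) looks roughly like this; the values below are only placeholders to show the structure the model expects:

example_entry = {
    "T":     1800.0,                       # temperature for this condition (placeholder)
    "XN2":   0.70,                         # N2 mole fraction (placeholder)
    "X_exp": np.linspace(0.0, 0.05, 64),   # experimental grid, x coordinates
    "Y_exp": np.linspace(0.0, 0.10, 128),  # experimental grid, y coordinates
    "I_exp": np.zeros((128, 64)),          # measured OH* intensity image (placeholder shape)
}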

The code is as follows:

import numpy as np
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoMultivariateNormal
from pyro.optim import Adam, ClippedAdam
from scipy.interpolate import griddata
import matplotlib.pyplot as plt

# `gpr` (the trained openmeasure GPR), `xyzi_pos` (its grid coordinates) and `all_exp_data` are defined elsewhere in my script.

def gpr_predict_on_gpr_grid(T, XN2, Cmix, C2epsilon):
    """Return the reconstructed OH* field on the native GPR grid as a torch tensor."""
    # Condition as NumPy (the GPR expects NumPy input)
    condition = np.array([T, XN2, Cmix, C2epsilon]).reshape(1, -1)

    # GPR prediction (NumPy)
    modal_pred, _ = gpr.predict(condition)

    # Reconstruction (NumPy)
    spatial_field = gpr.reconstruct(modal_pred).flatten()

    # Convert to torch for the Pyro model
    return torch.as_tensor(spatial_field, dtype=torch.float32, device=device)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
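
Just to show what the wrapper returns, a quick sanity check with placeholder condition values would look like this (not part of the actual pipeline):

# Sanity check of the GPR wrapper (placeholder condition values)
field = gpr_predict_on_gpr_grid(T=1800.0, XN2=0.7, Cmix=0.3, C2epsilon=2.0)
print(field.shape, field.dtype)   # 1-D torch tensor on `device`, one value per GPR grid point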


def pyro_model_full(data_list):
    n_conditions = len(data_list)

    # Global parameters
    sigma = pyro.sample("sigma", dist.Uniform(0.1, 5.0))
    alpha = pyro.sample("alpha", dist.Uniform(1e10, 1e12))
    beta  = pyro.sample("beta",  dist.Uniform(-5.0, 5.0))
    #beta  = pyro.sample("beta", dist.Uniform(0.3, 3.0))

    Cmix      = pyro.sample("Cmix", dist.Uniform(0.01, 0.6))
    C2epsilon = pyro.sample("C2epsilon", dist.Uniform(1.5, 3.0))
    # Plain Python floats for the NumPy-based GPR call
    # (note: .detach() cuts these values off from the autograd graph)
    Cmix_val = Cmix.detach().item()
    C2epsilon_val = C2epsilon.detach().item()



    # ---------------------------------------------------------------
    # Build experimental and GPR intensity lists in a single loop
    # ---------------------------------------------------------------
    I_exp_list = []
    OH_list = []

    for data in data_list:
        T     = float(data["T"])
        XN2   = float(data["XN2"])
        X_exp = np.asarray(data["X_exp"], float)
        Y_exp = np.asarray(data["Y_exp"], float)
        I_exp = np.asarray(data["I_exp"], float)

        I_exp_list.append(I_exp.ravel())

        # raw GPR output (small magnitude)
        OH_gpr = gpr_predict_on_gpr_grid(T, XN2, Cmix_val, C2epsilon_val)

        # Interpolate the raw GPR field onto the experimental grid
        Yg, Xg = np.meshgrid(Y_exp, X_exp, indexing="ij")

        OH_grid = griddata(
            (xyzi_pos[:, 0], xyzi_pos[:, 1]),
            OH_gpr.cpu().numpy(),      # griddata needs NumPy; the tensor may live on the GPU
            (Yg, Xg),
            method="linear",
            fill_value=0.0
        )

        OH_list.append(OH_grid.ravel())

    # ---------------------------------------------------------------
    # Vectorize into tensors
    # ---------------------------------------------------------------
    I_exp_tensor = torch.tensor(np.stack(I_exp_list), dtype=torch.float32, device=device)
    OH_tensor    = torch.tensor(np.stack(OH_list),    dtype=torch.float32, device=device)

    # predicted intensity (vectorized)
    I_pred = alpha * OH_tensor + beta
    # ---- power-law mapping ----
    # eps = 1e-20  # avoid 0**beta
    # # 1) force non-negative base (physically OH* ≥ 0 anyway)
    # OH_tensor = torch.clamp(OH_tensor, min=0.0)

    # # 2) compute power-law mapping
    # I_pred = alpha * torch.pow(OH_tensor + eps, beta)

    # # 3) just in case anything weird slipped through
    # I_pred = torch.nan_to_num(I_pred, nan=0.0, posinf=0.0, neginf=0.0)


    # ---------------------------------------------------------------
    # Likelihood across conditions
    # ---------------------------------------------------------------
    with pyro.plate("conditions", n_conditions):
        pyro.sample(
            "obs",
            dist.Normal(I_pred, sigma).to_event(1),
            obs=I_exp_tensor
        )


print("✓ Model ready.")



pyro.clear_param_store()
# pyro.set_rng_seed(42)
# torch.manual_seed(42)
# np.random.seed(42)

#guide = AutoNormal(lambda data: pyro_model_full(data))
#guide = AutoNormal(pyro_model_full)
guide = AutoMultivariateNormal(pyro_model_full)
#guide = AutoDiagonalNormal(pyro_model_full)

optimizer = Adam({"lr": 1e-3})
##optimizer = AdagradRMSProp({"lr": 1e-3, "eta": 1.0, "tolerance_grad": 1e-8})
optimizer = ClippedAdam({"lr": 5e-4, "clip_norm": 5.0})
#svi = SVI(lambda data: pyro_model_full(data), guide, optimizer, loss=Trace_ELBO())
svi = SVI(pyro_model_full, guide, optimizer, loss=Trace_ELBO())


num_steps = 8000
losses = []

print(f"Running {num_steps} iterations on {device}...")

for step in range(num_steps):
    loss = svi.step(all_exp_data)
    losses.append(loss)
    
    if step % 500 == 0:
        print(f"  Step {step:>4}: ELBO = {loss:>12.4f}")
    # if step == 10:
    #     print("\nParameters in param store:", list(pyro.get_param_store().keys()))


print("✓ Inference complete!")

plt.figure(); plt.plot(losses); plt.title("SVI convergence"); plt.yscale("log")
plt.xlabel("iter"); plt.ylabel("ELBO loss"); plt.show()