Need help with modeling

I have to model the following situation:

  1. 10 samples are drawn from each of two 2D multivariate normal distributions, each with its own mean (A_true, B_true) and variance. We now have two clusters of 10 dots each; we will call them A and B.

  2. The first dot in A (A1) and the first dot in B (B1) are interpreted as a path from A to B. This gives a total of 10 paths: A1→B1, A2→B2, …, A10→B10.

  3. The observer observes each path through camera noise, so the observer learns the following set of 10 vectors: {(A1,B1)+cn1, (A2,B2)+cn2, …, (A10,B10)+cn10}. Note that A1 and B1 receive the same camera noise cn1 (and likewise for every other path).

  4. In short, the observer observes {A1+cn1, A2+cn2, …} and {B1+cn1, B2+cn2, …}.

  5. The camera noise is large, so recovering B_true just from the mean of {B1+cn1, B2+cn2, …} is hard. But since the observer has learned the vector from A1+cn1 to B1+cn1 (in which the shared camera noise cancels), the observer can infer B_true once A_assume is given.

  6. A_assume is a single sample from the same 2D multivariate normal distribution as A1…A10.

  7. Given A_assume, the observer should infer B_true.

The model should accept the following 3 variables: {A1+cn1, A2+cn2, …}, {B1+cn1, B2+cn2, …}, A_assume
I want to avoid explicitly cancelling out the camera noise by computing (A1+cn1) - (B1+cn1) in the code. Instead, I want the model to add the same camera noise to both ends of each path, exploit that fact, and infer B_true.

I have tried a lot of code, but I cannot get it to work ;-;
Any help or guide will be greatly appreciated. Thanks.

The following is the model I wrote, but it produces this warning:
RuntimeWarning: trying to observe a value outside of inference at A_assume_obs

def model(noisy_A, noisy_B, A_assume_obs):
    """First attempt at the generative model for the noisy paths A_i -> B_i.

    Every path i gets latent endpoints A_i ~ N(A_mean, A_cov) and
    B_i ~ N(B_mean, B_cov), plus one camera-noise vector cn_i that is added to
    BOTH endpoints. The observations are A_i + cn_i and B_i + cn_i.

    Args:
        noisy_A: observed path starts (A_i + cn_i). ``noisy_A.shape[0]`` sets
            the number of paths.
        noisy_B: observed path ends (B_i + cn_i), aligned row-for-row with
            ``noisy_A``.
        A_assume_obs: a single observed point tied to ``A_mean``.
    """
    # Cluster A: unknown centre and isotropic variance (scalar * I).
    A_mean = pyro.sample("A_mean", dist.MultivariateNormal(torch.zeros(2), torch.eye(2)))
    A_cov = pyro.sample("A_cov", dist.InverseGamma(3.0, 2.0)) * torch.eye(2)

    # Cluster B: B_mean is the quantity to infer (B_true).
    B_mean = pyro.sample("B_mean", dist.MultivariateNormal(torch.zeros(2), torch.eye(2)))
    B_cov = pyro.sample("B_cov", dist.InverseGamma(3.0, 2.0)) * torch.eye(2)

    # Camera noise. noisy_A only constrains A_mean + cn_mean, so the
    # A_assume_obs site below is what separates A_mean from cn_mean.
    cn_mean = pyro.sample("cn_mean", dist.MultivariateNormal(torch.zeros(2), torch.eye(2)))
    cn_cov = pyro.sample("cn_cov", dist.InverseGamma(3.0, 2.0)) * torch.eye(2)

    # A covariance matrix must be positive definite. An element-wise
    # ``positive`` constraint is the wrong constraint for a matrix, and it
    # would map the zero off-diagonal entries of this initial value to
    # log(0) = -inf in unconstrained space.
    observation_noise = pyro.param(
        "observation_noise",
        lambda: torch.tensor(0.000005) * torch.eye(2),
        constraint=constraints.positive_definite,
    )

    # With a variance this small, observing A_assume_obs effectively pins
    # A_mean to it.
    # NOTE(review): Pyro warns "trying to observe a value outside of inference"
    # whenever an observed site runs outside an inference handler (e.g. when
    # the model is called directly rather than through SVI). That is presumably
    # the source of the reported warning, not a modelling error.
    pyro.sample("A_assume_obs", dist.MultivariateNormal(A_mean, observation_noise), obs=A_assume_obs)

    with pyro.plate("all_samples", noisy_A.shape[0]):
        A = pyro.sample("A", dist.MultivariateNormal(A_mean, A_cov))
        B = pyro.sample("B", dist.MultivariateNormal(B_mean, B_cov))
        cn = pyro.sample("cn", dist.MultivariateNormal(cn_mean, cn_cov))

        # The same cn is added to both endpoints of each path; this shared
        # noise is what links B back to A during inference.
        pyro.sample("Ac", dist.MultivariateNormal(A + cn, observation_noise), obs=noisy_A)
        pyro.sample("Bc", dist.MultivariateNormal(B + cn, observation_noise), obs=noisy_B)

Welp, I solved it😅

def model(noisy_A, noisy_B, A_assume_obs):
    """Generative model that uses ``A_assume_obs`` as the observed centre of cluster A.

    For each path i, the model draws the endpoints A_i and B_i and a single
    camera-noise vector cn_i. It then scores the observations A_i + cn_i and
    B_i + cn_i, so both ends of a path share the same noise draw. The latent
    ``B_mean`` site is the estimate of B_true.

    Args:
        noisy_A: observed path starts (A_i + cn_i); ``len(noisy_A)`` sets the
            number of paths.
        noisy_B: observed path ends (B_i + cn_i), aligned row-for-row with
            ``noisy_A``.
        A_assume_obs: 2D point used as the observed value of the ``A_mean`` site.
    """
    identity = torch.eye(2)
    origin = torch.tensor([0., 0.])

    # Noise levels are fixed constants rather than learned params or latent
    # sites: a tiny observation jitter and a broad isotropic camera-noise prior.
    obs_cov = torch.tensor(0.0001) * identity
    noise_loc = origin
    noise_cov = identity * 10

    a_centre = pyro.sample("A_mean", dist.MultivariateNormal(origin, identity), obs=A_assume_obs)
    a_spread = identity * pyro.sample("A_cov", dist.InverseGamma(3.0, 2.0))

    b_centre = pyro.sample("B_mean", dist.MultivariateNormal(origin, identity))
    b_spread = identity * pyro.sample("B_cov", dist.InverseGamma(3.0, 2.0))

    with pyro.plate("data", len(noisy_A)):
        start = pyro.sample("A", dist.MultivariateNormal(a_centre, a_spread))
        end = pyro.sample("B", dist.MultivariateNormal(b_centre, b_spread))
        shake = pyro.sample("cn", dist.MultivariateNormal(noise_loc, noise_cov))

        # One shared camera-noise draw per path, applied to both endpoints.
        pyro.sample("Acn", dist.MultivariateNormal(start + shake, obs_cov), obs=noisy_A)
        pyro.sample("Bcn", dist.MultivariateNormal(end + shake, obs_cov), obs=noisy_B)



def guide(noisy_A, noisy_B, A_assume_obs):
    """Variational guide for the model whose latent sites are A_cov, B_mean,
    B_cov and, per path, A, B and cn (``A_mean`` is observed there).

    - ``A_cov`` and ``B_cov`` are drawn from the same InverseGamma(3, 2) used
      as their prior, so they have no learnable parameters.
    - ``B_mean`` gets a point estimate (``Delta``) at the learnable ``B_loc``;
      after inference, ``pyro.param("B_loc")`` is the estimate of B_true.
    - ``cn`` gets one learnable Gaussian (``cn_mean``, ``cn_cov``) shared by
      every path.

    Site ``A`` has no guide distribution. During inference Pyro therefore
    draws it from the model's own prior, and Pyro may warn that ``A`` is
    present in the model but not in the guide.
    NOTE(review): confirm that leaving ``A`` unguided is intended. Also note
    that a shared ``cn`` loc cannot give each path its own camera-noise
    posterior.

    Args:
        noisy_A: observed path starts; only ``len(noisy_A)`` is used here.
        noisy_B: unused; kept so the signature matches the model's.
        A_assume_obs: unused; kept so the signature matches the model's.
    """
    # A_mean is observed in the model, so it needs no guide site.
    pyro.sample("A_cov", dist.InverseGamma(3.0, 2.0))

    B_loc = pyro.param("B_loc", lambda: torch.tensor([0.0, 0.0]), constraint=constraints.real)
    pyro.sample("B_mean", dist.Delta(B_loc).to_event(1))
    B_cov = pyro.sample("B_cov", dist.InverseGamma(3.0, 2.0)) * torch.eye(2)

    cn_mean = pyro.param("cn_mean", torch.tensor([0., 0.]))
    # A covariance matrix needs a positive-definite constraint. An element-wise
    # ``positive`` constraint sends the zero off-diagonals of the identity
    # initial value to log(0) = -inf in unconstrained space.
    cn_cov = pyro.param("cn_cov", torch.eye(2), constraint=constraints.positive_definite)

    with pyro.plate("data", len(noisy_A)):
        pyro.sample("cn", dist.MultivariateNormal(cn_mean, cn_cov))
        pyro.sample("B", dist.MultivariateNormal(B_loc, B_cov))