More adventures with poutine.nan_mask

Hi, back again with a nan_mask question. This time I’m seeing unexpected SVI behavior for instances whose descendant observations are all nan’d out.

Toy model is below. Everything is drawn from a Beta distribution. C tracks the mean of A and B; D tracks C.

The simplest case showing the issue is num_BC = 2 with all D_1j observed as nan, e.g.
D = [[0.9, 0.9, 0.9, ..., 0.9], [nan, nan, nan, ..., nan]]. In that case A, B_0, and C_0 have descendant observations at 0.9, but B_1 and C_1 have only nan descendant observations. C_1 still has an indirect connection to the non-nan observations via C_1 <- A -> C_0 -> D_0j.

B_1, however, is d-separated due to the v-structure, so I would expect SVI to just draw it from the prior. I’ve confirmed this is the case when A is removed from the model (see the sketch after the MWE below). But when A is present, the prediction for B_1 clearly deviates from the prior mean of ≈0.77 (a Beta(α, β) has mean α / (α + β), so Beta(1, 0.3) gives 1 / 1.3 ≈ 0.77), as below. Note that this happens in the opposite direction if the prior is flipped to Beta(0.3, 1.0): then B_1 deviates below the prior mean of ≈0.23. And if the prior is uniform at Beta(1.0, 1.0), then B_1 stays around the expected 0.5…

A:  [0.8]
B:  [0.89, 0.89]
C:  [0.86, 0.84]

Any idea why? The real model is much larger and more complex, and it would be much better not to have to filter out instances that may be missing all of their descendant observations, and instead just be able to assume any such predictions are drawn from the prior.

MWE

Note: nan observations are handled as fixed in my previous post: mask out the nan entries with poutine.mask and replace the nans with a dummy value so the loss doesn’t go nan.

import pyro
import torch
import pyro.distributions as dist
import torch.distributions.constraints as constraints
from typing import List

num_BC = 2
num_D = 100


def model(obs: torch.Tensor, nan_mask: torch.Tensor):

    A = pyro.sample('A', dist.Beta(1, .3))

    with pyro.plate('BC_plate', size=num_BC, dim=-1):
        B = pyro.sample('B', dist.Beta(1, .3))

        parent_values = torch.stack((B, A.expand(B.shape)))
        combined_parent_values = torch.mean(parent_values, dim=0)
        C = pyro.sample('C', dist.Beta(10 * combined_parent_values, 10 * (1 - combined_parent_values)))

        with pyro.plate('D_plate', size=num_D, dim=-2):
            C_expanded = C.expand(num_D, num_BC)
            with pyro.poutine.mask(mask=nan_mask):
                D = pyro.sample('D', dist.Beta(10 * C_expanded, 10 * (1 - C_expanded)), obs=obs)


def guide(obs: torch.Tensor, nan_mask: torch.Tensor):

    A_concentrations = get_concentration_params('A', 1)
    pyro.sample('A', dist.Beta(*A_concentrations))

    with pyro.plate('BC_plate', size=num_BC, dim=-1):
        B_concentrations = get_concentration_params('B', num_BC)
        pyro.sample('B', dist.Beta(*B_concentrations))

        C_concentrations = get_concentration_params('C', num_BC)
        pyro.sample('C', dist.Beta(*C_concentrations))


def get_concentration_params(site_name: str, size: int) -> List[torch.Tensor]:
    # Beta distribution concentration1 and concentration0 parameters
    return [
            pyro.param(f'{site_name}_concentration{i}', torch.ones(size), constraint=constraints.greater_than(0.1))
            for i in reversed(range(2))
        ]


def main():

    obs = torch.full((num_D, num_BC), 0.9)
    obs[:, -1] = float('nan')

    nan_mask = ~torch.isnan(obs)
    obs[torch.isnan(obs)] = 0.4242  # to avoid nan loss

    svi = pyro.infer.SVI(
        model,
        guide,
        pyro.optim.Adam({"lr": 0.1, "betas": (0.95, 0.999)}),
        loss=pyro.infer.Trace_ELBO(),
    )

    for _ in range(10000):
        svi.step(obs, nan_mask)

    param_store = pyro.get_param_store()

    for site_id in ['A', 'B', 'C']:
        concentration1 = param_store[f'{site_id}_concentration1']
        concentration0 = param_store[f'{site_id}_concentration0']
        beta_mean = concentration1 / (concentration1 + concentration0)
        print(f'{site_id}: ', [round(float(v), 2) for v in beta_mean])


if __name__ == '__main__':
    main()
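
For reference, the A-removed variant I used to confirm B_1 stays at the prior looks something like this (a sketch; the guide is the same minus the 'A' site):

def model_no_A(obs: torch.Tensor, nan_mask: torch.Tensor):
    # Same as above, but without A: C depends on B alone, so once the D_1j
    # are masked out, B_1 has no path at all to the non-nan observations.
    with pyro.plate('BC_plate', size=num_BC, dim=-1):
        B = pyro.sample('B', dist.Beta(1, .3))
        C = pyro.sample('C', dist.Beta(10 * B, 10 * (1 - B)))

        with pyro.plate('D_plate', size=num_D, dim=-2):
            C_expanded = C.expand(num_D, num_BC)
            with pyro.poutine.mask(mask=nan_mask):
                pyro.sample('D', dist.Beta(10 * C_expanded, 10 * (1 - C_expanded)), obs=obs)

With this variant the learned mean for B_1 sits at the prior, as expected.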

I thought maybe Pyro was doing something weird with the descendant of the v-structure being partially observed as a single sample site, but I just confirmed the same thing happens with for loops and separate sample sites instead of plates. So it still seems to be something with the nan mask.

MWE with no plates

import pyro
import torch
import pyro.distributions as dist
import torch.distributions.constraints as constraints
from typing import List

num_BC = 2
num_D = 100


def model(obs: torch.Tensor, nan_mask: torch.Tensor):

    A = pyro.sample('A', dist.Beta(1, .3))

    for i in range(num_BC):
        B = pyro.sample(f'B_{i}', dist.Beta(1, .3))

        combined_parent_values = (A + B) / 2
        C = pyro.sample(f'C_{i}', dist.Beta(10 * combined_parent_values, 10 * (1 - combined_parent_values)))

        for j in range(num_D):
            with pyro.poutine.mask(mask=nan_mask[j, i]):
                D = pyro.sample(f'D_{i}_{j}', dist.Beta(10 * C, 10 * (1 - C)), obs=obs[j, i])


def guide(obs: torch.Tensor, nan_mask: torch.Tensor):

    A_concentrations = get_concentration_params('A_0', 1)
    pyro.sample('A', dist.Beta(*A_concentrations))

    for i in range(num_BC):
        B_concentrations = get_concentration_params(f'B_{i}', 1)
        pyro.sample(f'B_{i}', dist.Beta(*B_concentrations))

        C_concentrations = get_concentration_params(f'C_{i}', 1)
        pyro.sample(f'C_{i}', dist.Beta(*C_concentrations))


def get_concentration_params(site_name: str, size: int) -> List[torch.Tensor]:
    # Beta distribution concentration1 and concentration0 parameters
    return [
            pyro.param(f'{site_name}_concentration{i}', torch.ones(size), constraint=constraints.greater_than(0.1))
            for i in reversed(range(2))
        ]


def main():

    obs = torch.full((num_D, num_BC), 0.9)
    obs[:, -1] = float('nan')

    nan_mask = ~torch.isnan(obs)
    obs[torch.isnan(obs)] = 0.4242  # to avoid nan loss

    svi = pyro.infer.SVI(
        model,
        guide,
        pyro.optim.Adam({"lr": 0.1, "betas": (0.95, 0.999)}),
        loss=pyro.infer.Trace_ELBO(),
    )

    import tqdm
    for _ in tqdm.tqdm(range(10000)):
        svi.step(obs, nan_mask)

    param_store = pyro.get_param_store()

    for site_id, num in {'A': 1, 'B': num_BC, 'C': num_BC}.items():
        beta_means = []
        for i in range(num):
            concentration1 = param_store[f'{site_id}_{i}_concentration1']
            concentration0 = param_store[f'{site_id}_{i}_concentration0']
            beta_means.append(concentration1 / (concentration1 + concentration0))
        print(f'{site_id}: ', [round(float(v), 2) for v in beta_means])


if __name__ == '__main__':
    main()

Have you tried initializing everything at the prior? Are you sure this isn’t an optimization problem (as opposed to an incorrect ELBO / ELBO gradient)?

Aha, good suggestion. Interestingly, it turned out that setting the init to the prior didn’t matter; the issue was the init magnitude being set too small.

E.g. in the MWE above the init is Beta(1.0, 1.0), but if I set it to Beta(10.0, 10.0) then B_1 is drawn from the prior regardless of the prior setting.
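
Concretely, the only change is the init value in get_concentration_params, something like this (a sketch; initializing at the prior’s own concentrations, e.g. (1.0, 0.3), did not help, so it really is the magnitude that matters):

def get_concentration_params(site_name: str, size: int) -> List[torch.Tensor]:
    # Beta concentration1 and concentration0 parameters, initialized at 10.0
    # instead of 1.0: Beta(10, 10) has the same mean (0.5) as Beta(1, 1) but
    # is far more concentrated around it.
    return [
            pyro.param(f'{site_name}_concentration{i}', torch.full((size,), 10.0),
                       constraint=constraints.greater_than(0.1))
            for i in reversed(range(2))
        ]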

Not sure why the init magnitude affects the optimization, but so be it. Thanks!