Handling missing features

I know there are tutorials/examples on imputing entirely missing observations, but are there any examples or hints as to how to handle missing features? For example, if the data looks like:

  [[1, .8, 1],
   [1, .5, .8],
   [.8, .8, 1],
   [.2, nan, .2]]

and the model is something like:

import numpyro as ny
from numpyro.distributions import LKJCholesky, MultivariateNormal, Normal

def model(data):
    mus = ny.sample("mus", Normal(0, 1).expand((3,)))
    chol = ny.sample("chol", LKJCholesky(3, concentration=1))
    ny.sample("ys", MultivariateNormal(mus, scale_tril=chol), obs=data)

it would be nice to be able to use the features that are present in the fourth row to improve the estimates of mus for the first and third columns, and of the covariance between those two columns.

(Note: I don’t particularly care about what the missing values themselves are. So I don’t necessarily need to treat the missing values as latent variables and impute them as the examples/tutorials usually do. I just care about improving my estimates of mus and chol.)
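
To make the note above concrete: for a multivariate normal, integrating out the missing coordinates leaves a multivariate normal over the observed ones, with the matching sub-vector of the mean and sub-block of the covariance. Below is a minimal NumPy sketch of that per-row marginal log-density. The function name marginal_logpdf and the example mu and cov are made up for illustration. It works on the full covariance rather than on chol, because slicing rows and columns out of a Cholesky factor does not in general give the Cholesky factor of the sub-block.

```python
import numpy as np

def marginal_logpdf(row, mu, cov):
    """Log-density of the observed (non-NaN) entries of `row` under N(mu, cov)."""
    present = ~np.isnan(row)                  # boolean mask of observed features
    x = row[present] - mu[present]            # centered observed values
    cov_obs = cov[np.ix_(present, present)]   # observed-by-observed covariance block
    k = int(present.sum())                    # number of observed features
    _, logdet = np.linalg.slogdet(cov_obs)
    maha = x @ np.linalg.solve(cov_obs, x)    # squared Mahalanobis distance
    return -0.5 * (k * np.log(2 * np.pi) + logdet + maha)

# Illustrative values only (not taken from the data above).
mu = np.zeros(3)
cov = np.array([[1.0, 0.3, 0.5],
                [0.3, 1.0, 0.2],
                [0.5, 0.2, 1.0]])
row = np.array([0.2, np.nan, 0.2])
print(marginal_logpdf(row, mu, cov))
```

Only the first and third entries of mu and the corresponding 2x2 block of cov enter the result, which is the behavior I am after for the fourth row.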

One approach that seems sort of reasonable to me is to do something like:

      ys = ny.sample("ys", MultivariateNormal(mus, scale_tril=chol))
      present_idx = np.nonzero(~np.isnan(data))
      ny.deterministic("observed_ys", ys.at[present_idx].get(), obs=data.at[present_idx].get())

But deterministic doesn’t support obs. Is this actually a reasonable approach? Is there some way to emulate this approach in numpyro given the absence of obs support in deterministic?

Actually, I think I’ve figured out something that mostly works.

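import jax.numpy as jnp
import numpyro as ny
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan  # assuming this is the scan in use

def model3(data: "np.ndarray[float]") -> None:
    mus = ny.sample("mus", dist.Normal(0, 1).expand((2,)))
    def inner(_, row):
        # Indices of the non-nan features; size=2 pads with fill_value (0) when fewer are present.
        present_idx = jnp.nonzero(jnp.invert(jnp.isnan(row)), size=2)
        row_obs = row[present_idx]
        mus_obs = mus[present_idx]
        # Rows and columns of the identity scale matrix for the present features.
        scale_tril_obs = jnp.squeeze(jnp.eye(2)[:,present_idx][present_idx,:])
        ny.sample("xs", dist.MultivariateNormal(loc=mus_obs, scale_tril=scale_tril_obs), obs=row_obs)
        return None, None
    scan(inner, None, data)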

I know this isn’t quite right, because of the static size for nonzero and the fill_value behavior, but it’s the closest I’ve gotten, and it generally seems to have the desired behavior (i.e. it allows nan in some features while still learning from the non-nan features). It seems like there might be a cleaner way to do this?
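
One fixed-shape idea that would avoid nonzero (and with it size and fill_value) altogether, sketched in plain NumPy under my own assumptions rather than as an established NumPyro recipe: replace each missing value with its mean, and overwrite the missing rows and columns of the covariance with those of the identity matrix. The full-dimensional log-density then equals the marginal log-density of the observed features minus 0.5*log(2*pi) per missing feature, and that offset does not depend on mus or on the covariance. The name masked_full_logpdf and the example values are hypothetical.

```python
import numpy as np

def masked_full_logpdf(row, mu, cov):
    """Full-dimensional MVN log-density with the missing features neutralized."""
    present = ~np.isnan(row)
    # Missing values are replaced by their mean, so they add no Mahalanobis term.
    x = np.where(present, row, mu)
    # Keep cov where both features are present; use the identity everywhere else.
    both = np.outer(present, present)
    cov_masked = np.where(both, cov, np.eye(len(mu)))
    d = x - mu
    _, logdet = np.linalg.slogdet(cov_masked)
    maha = d @ np.linalg.solve(cov_masked, d)
    return -0.5 * (len(mu) * np.log(2 * np.pi) + logdet + maha)

# Illustrative values only.
mu = np.array([0.1, -0.2, 0.3])
cov = np.array([[1.0, 0.3, 0.5],
                [0.3, 1.0, 0.2],
                [0.5, 0.2, 1.0]])
row = np.array([0.2, np.nan, 0.2])

# Reference: marginal log-density over the two observed features (indices 0 and 2).
obs = [0, 2]
x_o = row[obs] - mu[obs]
cov_o = cov[np.ix_(obs, obs)]
marginal = -0.5 * (2 * np.log(2 * np.pi) + np.linalg.slogdet(cov_o)[1]
                   + x_o @ np.linalg.solve(cov_o, x_o))
offset = -0.5 * np.log(2 * np.pi)  # one missing feature in this row
print(np.isclose(masked_full_logpdf(row, mu, cov), marginal + offset))
```

Every array keeps its full shape regardless of which features are missing, so the same arithmetic should in principle carry over to jnp inside a model, but treat that as an assumption rather than something I have confirmed.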