Need help with batching when using variational inference

Hello, I am trying to use variational inference in NumPyro to estimate the parameters of a neural network. I run into problems when I apply batching (subsampling) to my data while using AutoContinuous guides. Below is part of my code (the function that defines the model), which shows how I apply the subsampling with numpyro.plate:

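```python
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist


def integrate(self, feature_arrays, theta=None, labels=None, rng_key=None, batch_percentage=40):
    # `rng_key or ...` would call bool() on a JAX key array and fail, so test for None.
    rng_key = self.rng_key if rng_key is None else rng_key
    labels = {} if labels is None else labels
    _theta = self.sample_theta() if theta is None else theta
    full_batch = feature_arrays["c0"].shape[0]
    # Subsample size; clamped so that batch_percentage=100 cannot exceed full_batch.
    batch_size = (
        full_batch
        if batch_percentage is None
        else min(int(batch_percentage / 100 * full_batch) + 1, full_batch)
    )

    # Subsample along the leading axis of the features (dim = -ndim).
    with numpyro.plate("batch", full_batch, subsample_size=batch_size, dim=-jnp.ndim(feature_arrays["c0"])) as ind:
        batch_features = {k: v[ind] for k, v in feature_arrays.items()}
        batch_labels = {k: v[ind] for k, v in labels.items()}

        # Solve the differential equations for the current batch only.
        output = self._integrate(_theta, batch_features)
        x = output["state"]
        y_obs, y_std = self.get_observations_and_noise(x, batch_labels, _theta)

        # obs_mask marks which entries of y_obs are actually observed.
        y = numpyro.sample(
            "y", dist.Normal(x, y_std), obs=y_obs, obs_mask=batch_labels.get("svi_obs_mask"), rng_key=rng_key
        )

    numpyro.deterministic("time", output["time"])
    output["y"] = y
    return output
```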
  1. The batching/subsampling should be done along the first dimension of y_obs, which is set via the ‘dim’ argument of the plate statement.
  2. ‘ind’ is the array of indices that I use to select the current batch from my features and labels.
  3. I then call the _integrate function (which solves the differential equations) to get the model output.
  4. Finally, I set the observed state with the numpyro.sample statement inside the with block.

The problem is that if I set the subsample size equal to the full batch size (i.e. set batch_percentage in the code above to None), everything works fine.

However, if subsample_size is smaller than the full batch size (i.e. batch_percentage is not None), I get an error when using AutoContinuous guides, which states:

“Autocontinuous guides do not work with local latent variables”

The error comes from an assertion that compares the entries in the “args” of the ordered dictionary that the trace handler builds for the “batch” site with those of “y_unobserved”. I'm not sure I have explained this well; sorry if it is completely unclear.

In that case, could you at least guide me on how to use subsampling/batching for observed data of shape (a, b, c), where the entries are independent along dimension a, b is the time dimension, and c is the component dimension?

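When you use obs_mask, NumPyro creates a latent variable (y_unobserved) for the masked-out entries. Because it is sampled inside the subsampled plate, it is a local latent variable, and AutoContinuous guides do not work with local latent variables, as the error message says.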

You probably want to use mask (for example the numpyro.handlers.mask handler) rather than obs_mask. If you do want to use obs_mask, AutoNormal will probably work with your model.
