How to create sequential independence

I’m working on an adapted version of the Sparse Gamma DEF model from the tutorials, and I have a question about the independence of sample sites in the guide.

I have a sequence of sample sites that should all be independent. If I understand correctly, however, with the current implementation they will not be treated as independent. How can I declare that all sample sites in the guide are independent without nesting plate inside plate inside plate, etc., as is suggested here?

See the guide code below for reference:

def guide(self, x):
    x_size = x.size(0)

    # Get last parameter values
    if use_last_params:
        with open("parameters/last_params.json") as json_file:
            last_params = json.load(json_file)

    # helper for initializing variational parameters
    def rand_tensor(shape, mean, sigma):
        return mean * torch.ones(shape) + sigma * torch.randn(shape)

    # define a helper function to sample z's for a single layer
    def sample_zs(name, width, last_value=None):
        # Sample parameters or use last value
        if use_last_params and last_value is not None:
            p_z_q = pyro.param("p_z_q_%s" % name, last_value)
        else:
            p_z_q = pyro.param("p_z_q_%s" % name,
                               lambda: rand_tensor((x_size, width), self.z_mean_init, self.z_sigma_init))
        p_z_q = torch.sigmoid(p_z_q)

        # Sample Z's
        pyro.sample("z_%s" % name, Bernoulli(p_z_q).to_event(1),
                    infer=dict(baseline={'use_decaying_avg_baseline': True}))

    # define a helper function to sample w's for a single layer
    def sample_ws(name, width, mean, last_value=None):
        # Sample parameters
        if use_last_params and last_value is not None:
            mean_w_q = pyro.param("mean_w_q_%s" % name, last_value)
        else:
            mean_w_q = pyro.param("mean_w_q_%s" % name,
                                  lambda: rand_tensor(width, mean, self.w_sigma_init))
        sigma_w_q = pyro.param("sigma_w_q_%s" % name,
                               lambda: rand_tensor(width, self.w_mean_init, self.w_sigma_init))
        sigma_w_q = self.softplus(sigma_w_q)

        # Sample weights
        pyro.sample("w_%s" % name, Normal(mean_w_q, sigma_w_q))

    # define a helper function to sample c's for a single layer
    def sample_cs(name, width, mean, last_value=None):
        # Sample parameters
        if use_last_params and last_value is not None:
            mean_c_q = pyro.param("mean_c_q_%s" % name, last_value)
        else:
            mean_c_q = pyro.param("mean_c_q_%s" % name,
                                  lambda: rand_tensor(width, mean, self.c_sigma_init))
        sigma_c_q = pyro.param("sigma_c_q_%s" % name,
                               lambda: rand_tensor(width, self.c_mean_init, self.c_sigma_init))
        sigma_c_q = self.softplus(sigma_c_q)

        # Sample biases
        pyro.sample("c_%s" % name, Normal(mean_c_q, sigma_c_q))

    # sample the global weights and the bias terms
    with pyro.plate("w_top_plate", self.top_width * self.bottom_width):
        if use_last_params:
            sample_ws("top", self.top_width * self.bottom_width,
                      mean=self.w_mean_init,
                      last_value=torch.tensor(last_params['w_top']))
        else:
            sample_ws("top", self.top_width * self.bottom_width,
                      mean=self.w_mean_init)
    with pyro.plate("w_bottom_plate", self.bottom_width * self.data_size):
        if use_last_params:
            sample_ws("bottom", self.bottom_width * self.data_size,
                      mean=self.w_mean_init,
                      last_value=torch.tensor(last_params['w_bottom']))
        else:
            sample_ws("bottom", self.bottom_width * self.data_size,
                      mean=self.w_mean_init)

    with pyro.plate("c_bottom_plate", self.bottom_width):
        if use_last_params:
            sample_cs("bottom", self.bottom_width,
                      mean=self.c_mean_init,
                      last_value=torch.tensor(last_params['c_bottom']))
        else:
            sample_cs("bottom", self.bottom_width,
                      mean=self.c_mean_init)
    with pyro.plate("c_x_plate", self.data_size):
        if use_last_params:
            sample_cs("x", self.data_size,
                      mean=self.c_mean_init,
                      last_value=torch.tensor(last_params['c_x']))
        else:
            sample_cs("x", self.data_size,
                      mean=self.c_mean_init)

    # sample the local latent random variables
    with pyro.plate("data", x_size):
        if use_last_params:
            sample_zs("top", self.top_width,
                      last_value=torch.tensor(last_params['p_z_top']))
            sample_zs("bottom", self.bottom_width,
                      last_value=torch.tensor(last_params['p_z_bottom']))
        else:
            sample_zs("top", self.top_width)
            sample_zs("bottom", self.bottom_width)

You could use the suggestion in the topic you linked to for the final two calls to sample_zs. Note, however, that additional independence information is inferred automatically from the temporal order of the sample sites and when constructing reparameterized gradient estimators, so apart from those two sample sites I don’t see any independence left unexploited in your guide.

Thanks for the help @eb8680_2. I’ll add the structure to the sample_zs calls.

For my understanding: sequential plates are thus automatically inferred to be independent in Pyro. That is, in the following:

with pyro.plate("a", x):
    pyro.sample("var1", dist1)
with pyro.plate("b", y):
    pyro.sample("var2", dist2)

Are var1 and var2 regarded as independent in this case?

No, that’s incorrect: Pyro can’t tell from plates alone whether var2 depends on var1. Plates are intentionally conservative in the independence they assert, since we’d rather be correct than give an inference algorithm incorrect independence information. Most of our inference algorithms can infer additional independence information automatically, so it’s not worth trying too hard to encode it manually with plate.