I’m working on an adapted version of the Sparse Gamma DEF model from the tutorials, and I have a question about the independence of sampling sites in the guide.
I have a sequence of sampling sites that should all be independent. If I understand correctly, however, the current implementation does not declare them as independent to Pyro. I’m wondering how I can declare the independence of all sampling sites in the guide without having to nest plate inside plate inside plate etc., as is suggested here.
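For concreteness, here is a minimal toy sketch of the kind of thing I have in mind (made-up sizes and names, not my actual model): stacking vectorized plates with explicit dim arguments on a single with statement instead of literally nesting the blocks.

import torch
import pyro
import pyro.distributions as dist

def toy_guide():
    # two independence dimensions declared by stacking plate contexts on
    # one line; explicit dim arguments keep the batch axes from colliding
    rows = pyro.plate("rows", 10, dim=-2)
    cols = pyro.plate("cols", 5, dim=-1)
    with rows, cols:
        # all 10 * 5 entries are declared conditionally independent
        pyro.sample("z", dist.Normal(torch.zeros(10, 5), 1.0))

Is something along these lines the recommended pattern, or is there a cleaner way?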
For reference, here is my current guide code:
# snippet from a larger class: assumes json, torch, and pyro are imported,
# that Bernoulli and Normal come from pyro.distributions, and that
# use_last_params is a flag defined in the enclosing scope
def guide(self, x):
    x_size = x.size(0)

    # Get last parameter values
    if use_last_params:
        with open("parameters/last_params.json") as json_file:
            last_params = json.load(json_file)

    # helper for initializing variational parameters
    def rand_tensor(shape, mean, sigma):
        return mean * torch.ones(shape) + sigma * torch.randn(shape)

    # define a helper function to sample z's for a single layer
    def sample_zs(name, width, last_value=None):
        # Initialize parameters from the last values if available
        if use_last_params and (last_value is not None):
            p_z_q = pyro.param("p_z_q_%s" % name, last_value)
        else:
            p_z_q = pyro.param("p_z_q_%s" % name,
                               lambda: rand_tensor((x_size, width),
                                                   self.z_mean_init,
                                                   self.z_sigma_init))
        p_z_q = torch.sigmoid(p_z_q)
        # Sample z's. Note: to_event(1) folds the width dimension into the
        # event shape, so its entries are not registered as independent sites.
        pyro.sample("z_%s" % name, Bernoulli(p_z_q).to_event(1),
                    infer=dict(baseline={'use_decaying_avg_baseline': True}))

    # define a helper function to sample w's for a single layer
    def sample_ws(name, width, mean, last_value=None):
        # Initialize parameters from the last values if available
        if use_last_params and (last_value is not None):
            mean_w_q = pyro.param("mean_w_q_%s" % name, last_value)
        else:
            mean_w_q = pyro.param("mean_w_q_%s" % name,
                                  lambda: rand_tensor(width, mean, self.w_sigma_init))
        sigma_w_q = pyro.param("sigma_w_q_%s" % name,
                               lambda: rand_tensor(width, self.w_mean_init, self.w_sigma_init))
        sigma_w_q = self.softplus(sigma_w_q)
        # Sample weights
        pyro.sample("w_%s" % name, Normal(mean_w_q, sigma_w_q))

    # define a helper function to sample c's for a single layer
    def sample_cs(name, width, mean, last_value=None):
        # Initialize parameters from the last values if available
        if use_last_params and (last_value is not None):
            mean_c_q = pyro.param("mean_c_q_%s" % name, last_value)
        else:
            mean_c_q = pyro.param("mean_c_q_%s" % name,
                                  lambda: rand_tensor(width, mean, self.c_sigma_init))
        sigma_c_q = pyro.param("sigma_c_q_%s" % name,
                               lambda: rand_tensor(width, self.c_mean_init, self.c_sigma_init))
        sigma_c_q = self.softplus(sigma_c_q)
        # Sample bias terms
        pyro.sample("c_%s" % name, Normal(mean_c_q, sigma_c_q))

    # sample the global weights and the bias terms
    with pyro.plate("w_top_plate", self.top_width * self.bottom_width):
        if use_last_params:
            sample_ws("top", self.top_width * self.bottom_width,
                      mean=self.w_mean_init,
                      last_value=torch.tensor(last_params['w_top']))
        else:
            sample_ws("top", self.top_width * self.bottom_width,
                      mean=self.w_mean_init)
    with pyro.plate("w_bottom_plate", self.bottom_width * self.data_size):
        if use_last_params:
            sample_ws("bottom", self.bottom_width * self.data_size,
                      mean=self.w_mean_init,
                      last_value=torch.tensor(last_params['w_bottom']))
        else:
            sample_ws("bottom", self.bottom_width * self.data_size,
                      mean=self.w_mean_init)
    with pyro.plate("c_bottom_plate", self.bottom_width):
        if use_last_params:
            sample_cs("bottom", self.bottom_width,
                      mean=self.c_mean_init,
                      last_value=torch.tensor(last_params['c_bottom']))
        else:
            sample_cs("bottom", self.bottom_width,
                      mean=self.c_mean_init)
    with pyro.plate("c_x_plate", self.data_size):
        if use_last_params:
            sample_cs("x", self.data_size,
                      mean=self.c_mean_init,
                      last_value=torch.tensor(last_params['c_x']))
        else:
            sample_cs("x", self.data_size,
                      mean=self.c_mean_init)

    # sample the local latent random variables
    with pyro.plate("data", x_size):
        if use_last_params:
            sample_zs("top", self.top_width,
                      last_value=torch.tensor(last_params['p_z_top']))
            sample_zs("bottom", self.bottom_width,
                      last_value=torch.tensor(last_params['p_z_bottom']))
        else:
            sample_zs("top", self.top_width)
            sample_zs("bottom", self.bottom_width)