Hi all, I’ve got a big model, but this is a small subset in which I can reproduce an issue I’m running into. The sub-model has a single child node which is Beta-distributed to be correlated with the min of n Beta-distributed parents:
```
parents_i ~ Beta(a, b)
concentration1 = 10 * min(parents)
concentration0 = 10 * (1 - min(parents))
child ~ Beta(concentration1, concentration0)
```

That is, the child’s mean tracks min(parents) while its total concentration is fixed at 10.
Running for 10k SVI steps with (a, b) = (10, 1), I’m able to recover that prior from the learned parameters of the parents, but the child’s parameters converge elsewhere, even though I “cheated” and initialized all the parameters to (a, b).
```
con1   con2   mean   node
11.61  1.10   0.91   parent_0
10.24  1.14   0.90   parent_1
10.34  0.99   0.91   parent_2
11.65  1.03   0.92   parent_3
11.35  1.04   0.92   parent_4
10.96  1.14   0.91   parent_5
 8.33  1.93   0.81   child
```
If the prior has a concentration less than one, like (a, b) = (1, .33), the parents’ learned parameters still imply means pretty close to the prior mean of 0.75, though the actual concentrations are off, and the learned child parameters get even worse.
```
con1   con2   mean   node
1.48   0.43   0.77   parent_0
1.27   0.45   0.74   parent_1
1.39   0.44   0.76   parent_2
1.67   0.41   0.80   parent_3
1.59   0.41   0.79   parent_4
1.45   0.46   0.76   parent_5
4.11   5.93   0.41   child
```
Is there something screwy about the torch.min gradient? I’ve seen discussions about torch.min’s gradient being deterministic, but I’m unclear whether that applies to anything here. And if I try a logsumexp-based smooth min (sketched below), it has the same issue when I parameterize it to be a very close approximation of min (alpha = -50), but it’s fine with the smoothing ramped up at alpha = -1. It also works fine if I just use torch.mean().
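For concreteness, this is roughly the smooth min I mean, plus a quick autograd check of how gradient flows through a hard min; the helper name smooth_min and the toy values are just for illustration, not part of the actual model:

```python
import torch

# Smooth min via logsumexp: for alpha < 0, logsumexp(alpha * x) / alpha
# approaches min(x) as alpha -> -inf; alpha = -1 is a much softer approximation.
def smooth_min(x, alpha=-50.0):
    return torch.logsumexp(alpha * x, dim=0) / alpha

# With a hard min and no ties, only the arg-min element receives gradient.
x = torch.tensor([0.9, 0.7, 0.8], requires_grad=True)
x.min().backward()
print(x.grad)  # gradient only on the 0.7 entry

# With the soft min, gradient is spread across all entries (softmax-weighted
# toward the smallest); the spread is much more pronounced at alpha = -1.
y = torch.tensor([0.9, 0.7, 0.8], requires_grad=True)
smooth_min(y, alpha=-1.0).backward()
print(y.grad)
```

In the MWE below, the smooth-min runs just swap parent_values.min() for smooth_min(parent_values.flatten(), alpha).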
I would very much like to be able to use a min-based combination function, as the model represents a real-world “loser-takes-all” scenario.
MWE
```python
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist
from pyro.distributions import constraints
import torch
from tqdm import tqdm

num_parents = 6
beta_prior = [10., 1.]
# beta_prior = [1., .33]


def model():
    # Each parent gets the same Beta(a, b) prior.
    parent_values = torch.stack([
        pyro.sample(f'parent_{i}', dist.Beta(torch.tensor([beta_prior[0]]), torch.tensor([beta_prior[1]])))
        for i in range(num_parents)
    ])
    # "Loser-takes-all" combination: the child's mean is the smallest parent value.
    combined_values = parent_values.min()
    pyro.sample('child', dist.Beta(10 * combined_values, 10 * (1 - combined_values)))


def guide():
    # Mean-field guide: an independent Beta for every node, initialized at the prior's (a, b).
    for i in range(num_parents):
        concentration1 = pyro.param(f'concentration1_parent_{i}', torch.tensor([beta_prior[0]]), constraint=constraints.positive)
        concentration0 = pyro.param(f'concentration0_parent_{i}', torch.tensor([beta_prior[1]]), constraint=constraints.positive)
        pyro.sample(f'parent_{i}', dist.Beta(concentration1, concentration0))
    concentration1 = pyro.param('concentration1_child', torch.tensor([beta_prior[0]]), constraint=constraints.positive)
    concentration0 = pyro.param('concentration0_child', torch.tensor([beta_prior[1]]), constraint=constraints.positive)
    pyro.sample('child', dist.Beta(concentration1, concentration0))


def main():
    svi = pyro.infer.SVI(
        model,
        guide,
        pyro.optim.Adam({"lr": 0.005, "betas": (0.95, 0.999)}),
        loss=pyro.infer.Trace_ELBO(),
    )
    for _ in tqdm(range(10000)):
        svi.step()

    # Report each node's learned concentrations and the implied mean.
    param_store = pyro.get_param_store()
    node_names = [f'parent_{i}' for i in range(num_parents)] + ['child']
    print('con1\tcon2\tmean\tnode')
    for node_name in node_names:
        concentration1 = float(param_store[f'concentration1_{node_name}'])
        concentration0 = float(param_store[f'concentration0_{node_name}'])
        print(f'{concentration1:.2f}\t{concentration0:.2f}\t{concentration1/(concentration1+concentration0):.2f}\t{node_name}')


if __name__ == '__main__':
    main()
```