I’ve created the following Bayesian Regression code based on the Pyro tutorial.
import torch
import pyro
from pyro.distributions import Normal, Uniform
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

def model(x_data, y_data):
    """Model where I've conditioned using "obs"."""
    # weight, bias priors (regression_model is the nn.Module from the
    # tutorial, with a single `linear` layer)
    w_prior = Normal(torch.zeros(1, 134), torch.ones(1, 134)).to_event(1)
    b_prior = Normal(torch.tensor([[8.]]), torch.tensor([[1000.]])).to_event(1)
    priors = {'linear.weight': w_prior, 'linear.bias': b_prior}
    scale = pyro.sample("sigma", Uniform(0., 10.))  # noise scale prior
    lifted_module = pyro.random_module("module", regression_model, priors)
    lifted_reg_model = lifted_module()
    with pyro.plate("map", len(x_data)):
        prediction_mean = lifted_reg_model(x_data)
        pyro.sample("obs",
                    Normal(prediction_mean, scale),
                    obs=y_data)
        return prediction_mean
def model_c(x_data, y_data):
    """Model where I've omitted obs, so I condition with pyro.condition()."""
    # weight, bias priors
    w_prior = Normal(torch.zeros(1, 134), torch.ones(1, 134)).to_event(1)
    b_prior = Normal(torch.tensor([[8.]]), torch.tensor([[1000.]])).to_event(1)
    priors = {'linear.weight': w_prior, 'linear.bias': b_prior}
    scale = pyro.sample("sigma", Uniform(0., 10.))  # noise scale prior
    lifted_module = pyro.random_module("module", regression_model, priors)
    lifted_reg_model = lifted_module()
    with pyro.plate("map", len(x_data)):
        prediction_mean = lifted_reg_model(x_data)
        pyro.sample("obs",
                    Normal(prediction_mean, scale))
        return prediction_mean
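To convince myself that the two conditioning styles should agree, here is a minimal sketch on a toy model (the toy functions and data are made up for illustration, not my regression): with the same RNG seed, tracing an obs= model and a pyro.condition-ed model should give identical log-joints.

import torch
import pyro
import pyro.poutine as poutine
from pyro.distributions import Normal

def toy_obs(y):
    # condition by passing obs= at the sample site
    mu = pyro.sample("mu", Normal(0., 1.))
    return pyro.sample("obs", Normal(mu, 1.), obs=y)

def toy_free(y):
    # same model, but "obs" is left free; y is ignored here
    mu = pyro.sample("mu", Normal(0., 1.))
    return pyro.sample("obs", Normal(mu, 1.))

y = torch.tensor(2.5)
toy_cond = pyro.condition(toy_free, data={"obs": y})

pyro.set_rng_seed(0)
lp_obs = poutine.trace(toy_obs).get_trace(y).log_prob_sum()
pyro.set_rng_seed(0)
lp_cond = poutine.trace(toy_cond).get_trace(y).log_prob_sum()
print(lp_obs.item(), lp_cond.item())  # these should match exactly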
from pyro.infer.autoguide import AutoDiagonalNormal
# initialize the AutoDiagonalNormal with init_to_feasible instead of the default init_to_median
from pyro.infer.autoguide import init_to_feasible
I tested 2 things here:
- Calling the AutoDiagonalNormal guide and SVI on the function `model`, which uses `obs` to condition on `y`, and then checked the parameters.
optim = Adam({"lr": 0.03})
guide = AutoDiagonalNormal(model, init_loc_fn=init_to_feasible)
svi = SVI(model, guide, optim, loss=Trace_ELBO(), num_samples=10000)

pyro.set_rng_seed(101)
num_iterations = 1000

def train():
    pyro.clear_param_store()
    for j in range(num_iterations):
        loss = svi.step(X_and_z, Y_million.reshape(1, 3181))
        if j % 100 == 0:
            print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(X_and_z)))

train()
This then gave me the following losses and parameters:
[iteration 0001] loss: 29.7135
[iteration 0101] loss: 4.6161
[iteration 0201] loss: 4.7610
[iteration 0301] loss: 4.0946
[iteration 0401] loss: 3.8583
[iteration 0501] loss: 3.8010
[iteration 0601] loss: 3.7266
[iteration 0701] loss: 3.6230
[iteration 0801] loss: 3.4929
[iteration 0901] loss: 3.3711
# With obs model
for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name), pyro.param(name).shape)
auto_loc tensor([ 1.5752, 1.9036, 1.7293, 1.2149, 1.3549, 1.4660, 1.5495, 1.5199,
1.3670, 1.7639, 2.0447, 1.5349, 1.7269, 1.6461, 1.3144, 1.7847,
1.5538, 2.0627, 1.3742, 1.6399, 1.2461, 1.1668, 2.4442, 1.8599,
1.3793, 0.9716, 1.4281, 1.3905, 1.1589, 1.5255, 1.2318, 1.1674,
1.2114, 1.4748, 1.8203, 1.2250, 0.9152, 1.5038, 1.1281, 1.6029,
1.4207, 2.0973, 1.5329, 1.4418, 1.5366, 1.3006, 1.3467, 1.4620,
2.1395, 1.5185, 1.5207, 1.2547, 1.5547, 1.5540, 1.4916, 1.3020,
1.5153, 1.1697, 1.4367, 0.9811, 1.1819, 1.2552, 1.5145, 1.8372,
1.2713, 1.4987, 1.5078, 1.2076, 1.5469, 1.1251, 2.3320, 1.1417,
1.1975, 1.7049, 1.1603, 1.4760, 1.2557, 1.3845, 1.5729, 1.1285,
1.2072, 1.1412, 1.7014, 1.7424, 1.0425, 1.2191, 1.7736, 2.1868,
1.6765, 1.5919, 1.3682, 1.4483, 2.5514, 2.1047, 1.4144, 1.6414,
1.5273, 1.1809, 1.8491, 1.2392, 1.4084, 1.1238, 2.6613, 1.4572,
1.2879, 1.2808, 1.7650, 1.5197, 1.3765, 2.4904, 1.4313, 1.3829,
1.6127, 1.5570, 1.3939, 1.4727, 1.5342, 1.3148, 1.5934, 1.3072,
1.7571, 2.2243, 2.2008, 1.3384, 1.2911, 1.2438, 1.4588, 1.8506,
1.4949, 1.3469, -0.0884, 0.0638, -0.0652, 0.1049, -0.0253, 11.4191],
requires_grad=True) torch.Size([136])
auto_scale tensor([0.3781, 0.8812, 0.9367, 0.8434, 0.8525, 0.8364, 0.8990, 0.8881, 0.8960,
0.8428, 0.8487, 0.9105, 0.8389, 0.8692, 0.9308, 0.9092, 0.7921, 0.8101,
0.7852, 0.8016, 0.8839, 0.8495, 0.8818, 0.8740, 0.7862, 0.9151, 0.8602,
0.9156, 0.9298, 0.7928, 0.9807, 0.9286, 0.8600, 0.8443, 0.8855, 0.8889,
0.8752, 0.8385, 0.8879, 0.8787, 0.8877, 0.8246, 0.8388, 0.7809, 0.8906,
0.8970, 0.8267, 0.8196, 0.8011, 0.8834, 0.8472, 0.9160, 0.8691, 0.8302,
0.8097, 0.8748, 0.9058, 0.7759, 0.8418, 0.8937, 0.9757, 0.8007, 0.9057,
0.8539, 0.7408, 0.8163, 0.8763, 0.8826, 0.8412, 0.8188, 0.7885, 0.8916,
0.7838, 0.9019, 0.8537, 0.8709, 0.8879, 0.8449, 0.8115, 0.8510, 0.9078,
0.8710, 0.8935, 0.8671, 0.8822, 0.8705, 0.8498, 0.8314, 0.9034, 0.9495,
0.8732, 0.9047, 0.7394, 0.7512, 0.9212, 0.8179, 0.7292, 0.8706, 0.9672,
0.8535, 0.8718, 0.8957, 0.7607, 0.8774, 0.8542, 0.8854, 0.8569, 0.7995,
0.8654, 0.7449, 0.8614, 0.8662, 0.9169, 0.9063, 0.9495, 0.8056, 0.9239,
0.8958, 0.8614, 0.8758, 0.8066, 0.8269, 0.8363, 0.8656, 0.8406, 0.8347,
0.8266, 0.8977, 0.8421, 0.8378, 0.2042, 0.1784, 0.1919, 0.1640, 0.1997,
0.3202], grad_fn=<AddBackward0>) torch.Size([136])
- I then used `model_c` to condition on the observed `y` using `pyro.condition()` and checked the same:
optim = Adam({"lr": 0.03})
cond_model = pyro.condition(model_c, data={"obs": Y_million.reshape(3181)})
guide = AutoDiagonalNormal(cond_model, init_loc_fn=init_to_feasible)
svi = SVI(cond_model, guide, optim, loss=Trace_ELBO(), num_samples=10000)

pyro.set_rng_seed(101)
num_iterations = 1000

def train():
    pyro.clear_param_store()
    for j in range(num_iterations):
        loss = svi.step(X_and_z, Y_million.reshape(1, 3181))
        if j % 100 == 0:
            print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(X_and_z)))

train()
[iteration 0001] loss: 29.7135
[iteration 0101] loss: 4.6161
[iteration 0201] loss: 4.7610
[iteration 0301] loss: 4.0946
[iteration 0401] loss: 3.8583
[iteration 0501] loss: 3.8010
[iteration 0601] loss: 3.7266
[iteration 0701] loss: 3.6230
[iteration 0801] loss: 3.4929
[iteration 0901] loss: 3.3711
# With conditioned model
for name, value in pyro.get_param_store().items():
    print(name, pyro.param(name), pyro.param(name).shape)
auto_loc tensor([ 1.6132, 2.7714, 2.5792, 2.0241, 2.1809, 2.3087, 2.2511, 2.3428,
2.2490, 2.7082, 3.0203, 2.4826, 2.7354, 2.4863, 2.0219, 2.6712,
2.4538, 2.9045, 2.2683, 2.5128, 2.2497, 2.0157, 3.3892, 2.6955,
2.2234, 1.6766, 2.3512, 2.1573, 1.9666, 2.4318, 2.2161, 1.9102,
2.0722, 2.4220, 2.7766, 1.8294, 1.6031, 2.2478, 1.9566, 2.3849,
2.2140, 2.9028, 2.5078, 2.3861, 2.3860, 2.0456, 2.1457, 2.4038,
3.0510, 2.2109, 2.3494, 1.9805, 2.3524, 2.3598, 2.3478, 2.2067,
2.1379, 1.9105, 2.2777, 1.8151, 2.0005, 2.2392, 2.3093, 2.6808,
2.0162, 2.4783, 2.4935, 2.1366, 2.4325, 1.9198, 3.2639, 1.9199,
2.0147, 2.5201, 1.9299, 2.3060, 2.0703, 2.1851, 2.4940, 1.9562,
2.0531, 2.0722, 2.6879, 2.6579, 1.6239, 1.9451, 2.6679, 3.1080,
2.5605, 2.4904, 2.1577, 2.3007, 3.4252, 3.1737, 2.2489, 2.5733,
2.3683, 1.9356, 2.7642, 2.2878, 2.1554, 2.0033, 3.4997, 2.2529,
2.2691, 2.1176, 2.6994, 2.4119, 2.2028, 3.3839, 2.3140, 2.1525,
2.5463, 2.2762, 2.3282, 2.2763, 2.5435, 2.0169, 2.4871, 2.1349,
2.7200, 3.1535, 3.1368, 2.0324, 2.0856, 2.0134, 2.4037, 2.7132,
2.3139, 2.1486, -0.2466, 0.0903, -0.1047, 0.2029, -0.0666, 7.8326],
requires_grad=True) torch.Size([136])
auto_scale tensor([0.5493, 0.8645, 0.7676, 0.8458, 0.9359, 0.9723, 0.8601, 0.8278, 0.8688,
0.8065, 0.8890, 0.8186, 0.8317, 0.8664, 0.9159, 0.8246, 0.8372, 0.7962,
0.8844, 0.7761, 0.8720, 0.7950, 0.7659, 0.8668, 0.8509, 0.9081, 0.8105,
0.8173, 0.8717, 0.8438, 0.8572, 0.9669, 0.8819, 0.8530, 0.7883, 0.8392,
0.8440, 0.8996, 0.8765, 0.8916, 0.8976, 0.8247, 0.8810, 0.8846, 0.8333,
0.8479, 0.7473, 0.8615, 0.8197, 0.8579, 0.8787, 0.8492, 0.8801, 0.8171,
0.8782, 0.8881, 0.9359, 0.7999, 0.9454, 0.8929, 0.7871, 0.8273, 0.8825,
0.9139, 0.8244, 0.8505, 0.8168, 0.7927, 0.8230, 0.8653, 0.7898, 0.9207,
0.8871, 0.7965, 0.8993, 0.8590, 0.9090, 0.8746, 0.8706, 0.8929, 0.8608,
0.8351, 0.8591, 0.8334, 0.8578, 0.8923, 0.8580, 0.7879, 0.7564, 0.8824,
0.8359, 0.9050, 0.8133, 0.8108, 0.9384, 0.7770, 0.8682, 0.9074, 0.8448,
0.7789, 0.8541, 0.9117, 0.7664, 0.8085, 0.9002, 0.8371, 0.8881, 0.8344,
0.8924, 0.7709, 0.7925, 0.8309, 0.8507, 0.8539, 0.8713, 0.8328, 0.8393,
0.9083, 0.8275, 0.9502, 0.7257, 0.8732, 0.8322, 0.8320, 0.8606, 0.8345,
0.8317, 0.8608, 0.8084, 0.8753, 0.1912, 0.1554, 0.2356, 0.1696, 0.1852,
0.6428], grad_fn=<AddBackward0>) torch.Size([136])
The loss trajectories come out identical, but the inferred parameters are very different. Theoretically, though, aren't `obs` and `pyro.condition` just two syntactically different ways of doing the same thing, i.e., conditioning on observed data? So my 2 questions are:
- Where am I going wrong: is my implementation wrong somewhere, or have I misunderstood `obs` and `pyro.condition`?
- My `x_data` has shape `3181 x 134`, but when I check my inferred parameters it says `torch.Size([136])`. I can understand one extra value for the intercept to make it 135, but where is the 136th value coming from? (A sketch of how I'd enumerate the latent sites is below.)
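For reference, here's a sketch of how I'd list the latent sample sites and their shapes to see what the autoguide packs into those 136 dimensions (the site names are whatever pyro.random_module generates, so I just print them all):

import pyro.poutine as poutine

# Trace the model once and list every latent (non-observed) sample site;
# the sizes of these sites should add up to the 136 entries of
# auto_loc / auto_scale.
tr = poutine.trace(model).get_trace(X_and_z, Y_million.reshape(1, 3181))
for name, site in tr.iter_stochastic_nodes():
    print(name, site["value"].shape)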