So I’m trying to run a poisson regression model with 81 variables. I transformed the predictor array to a tensor of size torch.Size([657878, 81])
. I’m not sure if my syntax is wrong or something else.
Traceback (most recent call last):
File "<ipython-input-61-b895b742b8a3>", line 4, in <module>
elbo = svi.step(x_discrete, y_discrete)
File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\svi.py", line 128, in step
loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)
File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\trace_elbo.py", line 126, in loss_and_grads
for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):
File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\elbo.py", line 170, in _get_traces
yield self._get_trace(model, guide, args, kwargs)
File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\trace_elbo.py", line 53, in _get_trace
"flat", self.max_plate_nesting, model, guide, args, kwargs)
File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\enum.py", line 48, in get_importance_trace
graph_type=graph_type).get_trace(*args, **kwargs)
File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\poutine\trace_messenger.py", line 187, in get_trace
self(*args, **kwargs)
File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\poutine\trace_messenger.py", line 165, in __call__
ret = self.fn(*args, **kwargs)
File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
return fn(*args, **kwargs)
File "<ipython-input-58-1563a4508bc2>", line 161, in model
x_data['fpc_model_10']*b76+
IndexError: too many indices for tensor of dimension 2
Here is my model and guide:
def model(x_data, y_data):
a = pyro.sample('intercept', dist.Normal(0.,1.))
b0 = pyro.sample("vintile_model_1", dist.Normal(0.,1.))
b1 = pyro.sample("vintile_model_2", dist.Normal(0.,1.))
b2 = pyro.sample("vintile_model_4", dist.Normal(0.,1.))
b3 = pyro.sample("vintile_model_5", dist.Normal(0.,1.))
b4 = pyro.sample("vintile_model_6", dist.Normal(0.,1.))
b5 = pyro.sample("vintile_model_7", dist.Normal(0.,1.))
b6 = pyro.sample("vintile_model_8", dist.Normal(0.,1.))
b7 = pyro.sample("vintile_model_9", dist.Normal(0.,1.))
b8 = pyro.sample("vintile_model_10", dist.Normal(0.,1.))
b9 = pyro.sample("vintile_model_11", dist.Normal(0.,1.))
b10 = pyro.sample("vintile_model_12", dist.Normal(0.,1.))
b11 = pyro.sample("vintile_model_13", dist.Normal(0.,1.))
b12 = pyro.sample("vintile_model_14", dist.Normal(0.,1.))
b13 = pyro.sample("vintile_model_15", dist.Normal(0.,1.))
b14 = pyro.sample("vintile_model_16", dist.Normal(0.,1.))
b15 = pyro.sample("vintile_model_17", dist.Normal(0.,1.))
b16 = pyro.sample("vintile_model_18", dist.Normal(0.,1.))
b17 = pyro.sample("vintile_model_19", dist.Normal(0.,1.))
b18 = pyro.sample("vintile_model_20", dist.Normal(0.,1.))
b19 = pyro.sample("limit_01_model_discrete_100000", dist.Normal(0.,1.))
b20 = pyro.sample("limit_01_model_discrete_150000", dist.Normal(0.,1.))
b21 = pyro.sample("limit_01_model_discrete_200000", dist.Normal(0.,1.))
b22 = pyro.sample("limit_01_model_discrete_250000", dist.Normal(0.,1.))
b23 = pyro.sample("limit_01_model_discrete_300000", dist.Normal(0.,1.))
b24 = pyro.sample("limit_01_model_discrete_400000", dist.Normal(0.,1.))
b25 = pyro.sample("limit_01_model_discrete_500000", dist.Normal(0.,1.))
b26 = pyro.sample("deda_model_500", dist.Normal(0.,1.))
b27 = pyro.sample("deda_model_2500", dist.Normal(0.,1.))
b28 = pyro.sample("deda_model_4000", dist.Normal(0.,1.))
b29 = pyro.sample("deda_model_5000", dist.Normal(0.,1.))
b30 = pyro.sample("deda_model_10000", dist.Normal(0.,1.))
b31 = pyro.sample("aoh_model_discrete_20", dist.Normal(0.,1.))
b32 = pyro.sample("aoh_model_discrete_30", dist.Normal(0.,1.))
b33 = pyro.sample("aoh_model_discrete_40", dist.Normal(0.,1.))
b34 = pyro.sample("aoh_model_discrete_50", dist.Normal(0.,1.))
b35 = pyro.sample("aoh_model_discrete_60", dist.Normal(0.,1.))
b36 = pyro.sample("aoh_model_discrete_70", dist.Normal(0.,1.))
b37 = pyro.sample("aoh_model_discrete_80", dist.Normal(0.,1.))
b38 = pyro.sample("aoh_model_discrete_90", dist.Normal(0.,1.))
b39 = pyro.sample("aoh_model_discrete_100", dist.Normal(0.,1.))
b40 = pyro.sample("insured_age_model_discrete_40", dist.Normal(0.,1.))
b41 = pyro.sample("insured_age_model_discrete_50", dist.Normal(0.,1.))
b42 = pyro.sample("insured_age_model_discrete_60", dist.Normal(0.,1.))
b43 = pyro.sample("insured_age_model_discrete_70", dist.Normal(0.,1.))
b44 = pyro.sample("insured_age_model_discrete_80", dist.Normal(0.,1.))
b45 = pyro.sample("families_model_2", dist.Normal(0.,1.))
b46 = pyro.sample("credit_score_1", dist.Normal(0.,1.))
b47 = pyro.sample("credit_score_2", dist.Normal(0.,1.))
b48 = pyro.sample("credit_score_4", dist.Normal(0.,1.))
b49 = pyro.sample("credit_score_375", dist.Normal(0.,1.))
b50 = pyro.sample("credit_score_399", dist.Normal(0.,1.))
b51 = pyro.sample("credit_score_424", dist.Normal(0.,1.))
b52 = pyro.sample("credit_score_449", dist.Normal(0.,1.))
b53 = pyro.sample("credit_score_474", dist.Normal(0.,1.))
b54 = pyro.sample("credit_score_499", dist.Normal(0.,1.))
b55 = pyro.sample("credit_score_524", dist.Normal(0.,1.))
b56 = pyro.sample("credit_score_549", dist.Normal(0.,1.))
b57 = pyro.sample("credit_score_574", dist.Normal(0.,1.))
b58 = pyro.sample("credit_score_599", dist.Normal(0.,1.))
b59 = pyro.sample("credit_score_649", dist.Normal(0.,1.))
b60 = pyro.sample("credit_score_699", dist.Normal(0.,1.))
b61 = pyro.sample("credit_score_700", dist.Normal(0.,1.))
b62 = pyro.sample("credit_score_724", dist.Normal(0.,1.))
b63 = pyro.sample("credit_score_749", dist.Normal(0.,1.))
b64 = pyro.sample("credit_score_774", dist.Normal(0.,1.))
b65 = pyro.sample("credit_score_799", dist.Normal(0.,1.))
b66 = pyro.sample("credit_score_800", dist.Normal(0.,1.))
b67 = pyro.sample("channel_DIR", dist.Normal(0.,1.))
b68 = pyro.sample("fpc_model_2", dist.Normal(0.,1.))
b69 = pyro.sample("fpc_model_3", dist.Normal(0.,1.))
b70 = pyro.sample("fpc_model_4", dist.Normal(0.,1.))
b71 = pyro.sample("fpc_model_5", dist.Normal(0.,1.))
b72 = pyro.sample("fpc_model_6", dist.Normal(0.,1.))
b73 = pyro.sample("fpc_model_7", dist.Normal(0.,1.))
b74 = pyro.sample("fpc_model_8", dist.Normal(0.,1.))
b75 = pyro.sample("fpc_model_9", dist.Normal(0.,1.))
b76 = pyro.sample("fpc_model_10", dist.Normal(0.,1.))
b77 = pyro.sample("fpc_model_80", dist.Normal(0.,1.))
b78 = pyro.sample("prev_fire_cnt_1", dist.Normal(0.,1.))
b79 = pyro.sample("prev_waterplm_cnt_1", dist.Normal(0.,1.))
b80 = pyro.sample("prev_Other_cnt_1", dist.Normal(0.,1.))
mean = (a + x_data['vintile_model_1']*b0+
x_data['vintile_model_2']*b1+
x_data['vintile_model_4']*b2+
x_data['vintile_model_5']*b3+
x_data['vintile_model_6']*b4+
x_data['vintile_model_7']*b5+
x_data['vintile_model_8']*b6+
x_data['vintile_model_9']*b7+
x_data['vintile_model_10']*b8+
x_data['vintile_model_11']*b9+
x_data['vintile_model_12']*b10+
x_data['vintile_model_13']*b11+
x_data['vintile_model_14']*b12+
x_data['vintile_model_15']*b13+
x_data['vintile_model_16']*b14+
x_data['vintile_model_17']*b15+
x_data['vintile_model_18']*b16+
x_data['vintile_model_19']*b17+
x_data['vintile_model_20']*b18+
x_data['limit_01_model_discrete_100000']*b19+
x_data['limit_01_model_discrete_150000']*b20+
x_data['limit_01_model_discrete_200000']*b21+
x_data['limit_01_model_discrete_250000']*b22+
x_data['limit_01_model_discrete_300000']*b23+
x_data['limit_01_model_discrete_400000']*b24+
x_data['limit_01_model_discrete_500000']*b25+
x_data['deda_model_500']*b26+
x_data['deda_model_2500']*b27+
x_data['deda_model_4000']*b28+
x_data['deda_model_5000']*b29+
x_data['deda_model_10000']*b30+
x_data['aoh_model_discrete_20']*b31+
x_data['aoh_model_discrete_30']*b32+
x_data['aoh_model_discrete_40']*b33+
x_data['aoh_model_discrete_50']*b34+
x_data['aoh_model_discrete_60']*b35+
x_data['aoh_model_discrete_70']*b36+
x_data['aoh_model_discrete_80']*b37+
x_data['aoh_model_discrete_90']*b38+
x_data['aoh_model_discrete_100']*b39+
x_data['insured_age_model_discrete_40']*b40+
x_data['insured_age_model_discrete_50']*b41+
x_data['insured_age_model_discrete_60']*b42+
x_data['insured_age_model_discrete_70']*b43+
x_data['insured_age_model_discrete_80']*b44+
x_data['families_model_2']*b45+
x_data['credit_score_1']*b46+
x_data['credit_score_2']*b47+
x_data['credit_score_4']*b48+
x_data['credit_score_375']*b49+
x_data['credit_score_399']*b50+
x_data['credit_score_424']*b51+
x_data['credit_score_449']*b52+
x_data['credit_score_474']*b53+
x_data['credit_score_499']*b54+
x_data['credit_score_524']*b55+
x_data['credit_score_549']*b56+
x_data['credit_score_574']*b57+
x_data['credit_score_599']*b58+
x_data['credit_score_649']*b59+
x_data['credit_score_699']*b60+
x_data['credit_score_700']*b61+
x_data['credit_score_724']*b62+
x_data['credit_score_749']*b63+
x_data['credit_score_774']*b64+
x_data['credit_score_799']*b65+
x_data['credit_score_800']*b66+
x_data['channel_DIR']*b67+
x_data['fpc_model_2']*b68+
x_data['fpc_model_3']*b69+
x_data['fpc_model_4']*b70+
x_data['fpc_model_5']*b71+
x_data['fpc_model_6']*b72+
x_data['fpc_model_7']*b73+
x_data['fpc_model_8']*b74+
x_data['fpc_model_9']*b75+
x_data['fpc_model_10']*b76+
x_data['fpc_model_80']*b77+
x_data['prev_fire_cnt_1']*b78+
x_data['prev_waterplm_cnt_1']*b79+
x_data['prev_Other_cnt_1']*b80)
rate = mean.exp()
with pyro.plate("data", len(x_data), dim = -2):
pyro.sample("obs", dist.Poisson(rate), obs=y_data)
# n = 0
# for x in x_discrete.columns:
# print('b'+str(n)+'= pyro.sample("'+x+'", dist.Normal(weights_loc['+str(n)+'], weights_scale['+str(n)+']))')
# n += 1
def guide(x_data, y_data):
a_loc = pyro.param('a_loc', torch.tensor(0.))
a_scale = pyro.param('a_scale', torch.tensor(1.),
constraint=constraints.positive)
sigma_loc = pyro.param('sigma_loc', torch.tensor(1.),
constraint=constraints.positive)
weights_loc = pyro.param('weights_loc', torch.randn(81))
weights_scale = pyro.param('weights_scale', torch.ones(81),
constraint=constraints.positive)
a = pyro.sample("intercept",
dist.Normal(a_loc, a_scale))
b0= pyro.sample("vintile_model_1",
dist.Normal(weights_loc[0], weights_scale[0]))
b1= pyro.sample("vintile_model_2",
dist.Normal(weights_loc[1], weights_scale[1]))
b2= pyro.sample("vintile_model_4",
dist.Normal(weights_loc[2], weights_scale[2]))
b3= pyro.sample("vintile_model_5",
dist.Normal(weights_loc[3], weights_scale[3]))
b4= pyro.sample("vintile_model_6",
dist.Normal(weights_loc[4], weights_scale[4]))
b5= pyro.sample("vintile_model_7",
dist.Normal(weights_loc[5], weights_scale[5]))
b6= pyro.sample("vintile_model_8",
dist.Normal(weights_loc[6], weights_scale[6]))
b7= pyro.sample("vintile_model_9",
dist.Normal(weights_loc[7], weights_scale[7]))
b8= pyro.sample("vintile_model_10",
dist.Normal(weights_loc[8], weights_scale[8]))
b9= pyro.sample("vintile_model_11",
dist.Normal(weights_loc[9], weights_scale[9]))
b10= pyro.sample("vintile_model_12",
dist.Normal(weights_loc[10], weights_scale[10]))
b11= pyro.sample("vintile_model_13",
dist.Normal(weights_loc[11], weights_scale[11]))
b12= pyro.sample("vintile_model_14",
dist.Normal(weights_loc[12], weights_scale[12]))
b13= pyro.sample("vintile_model_15",
dist.Normal(weights_loc[13], weights_scale[13]))
b14= pyro.sample("vintile_model_16",
dist.Normal(weights_loc[14], weights_scale[14]))
b15= pyro.sample("vintile_model_17",
dist.Normal(weights_loc[15], weights_scale[15]))
b16= pyro.sample("vintile_model_18",
dist.Normal(weights_loc[16], weights_scale[16]))
b17= pyro.sample("vintile_model_19",
dist.Normal(weights_loc[17], weights_scale[17]))
b18= pyro.sample("vintile_model_20",
dist.Normal(weights_loc[18], weights_scale[18]))
b19= pyro.sample("limit_01_model_discrete_100000",
dist.Normal(weights_loc[19], weights_scale[19]))
b20= pyro.sample("limit_01_model_discrete_150000",
dist.Normal(weights_loc[20], weights_scale[20]))
b21= pyro.sample("limit_01_model_discrete_200000",
dist.Normal(weights_loc[21], weights_scale[21]))
b22= pyro.sample("limit_01_model_discrete_250000",
dist.Normal(weights_loc[22], weights_scale[22]))
b23= pyro.sample("limit_01_model_discrete_300000",
dist.Normal(weights_loc[23], weights_scale[23]))
b24= pyro.sample("limit_01_model_discrete_400000",
dist.Normal(weights_loc[24], weights_scale[24]))
b25= pyro.sample("limit_01_model_discrete_500000",
dist.Normal(weights_loc[25], weights_scale[25]))
b26= pyro.sample("deda_model_500",
dist.Normal(weights_loc[26], weights_scale[26]))
b27= pyro.sample("deda_model_2500",
dist.Normal(weights_loc[27], weights_scale[27]))
b28= pyro.sample("deda_model_4000",
dist.Normal(weights_loc[28], weights_scale[28]))
b29= pyro.sample("deda_model_5000",
dist.Normal(weights_loc[29], weights_scale[29]))
b30= pyro.sample("deda_model_10000",
dist.Normal(weights_loc[30], weights_scale[30]))
b31= pyro.sample("aoh_model_discrete_20",
dist.Normal(weights_loc[31], weights_scale[31]))
b32= pyro.sample("aoh_model_discrete_30",
dist.Normal(weights_loc[32], weights_scale[32]))
b33= pyro.sample("aoh_model_discrete_40",
dist.Normal(weights_loc[33], weights_scale[33]))
b34= pyro.sample("aoh_model_discrete_50",
dist.Normal(weights_loc[34], weights_scale[34]))
b35= pyro.sample("aoh_model_discrete_60",
dist.Normal(weights_loc[35], weights_scale[35]))
b36= pyro.sample("aoh_model_discrete_70",
dist.Normal(weights_loc[36], weights_scale[36]))
b37= pyro.sample("aoh_model_discrete_80",
dist.Normal(weights_loc[37], weights_scale[37]))
b38= pyro.sample("aoh_model_discrete_90",
dist.Normal(weights_loc[38], weights_scale[38]))
b39= pyro.sample("aoh_model_discrete_100",
dist.Normal(weights_loc[39], weights_scale[39]))
b40= pyro.sample("insured_age_model_discrete_40",
dist.Normal(weights_loc[40], weights_scale[40]))
b41= pyro.sample("insured_age_model_discrete_50",
dist.Normal(weights_loc[41], weights_scale[41]))
b42= pyro.sample("insured_age_model_discrete_60",
dist.Normal(weights_loc[42], weights_scale[42]))
b43= pyro.sample("insured_age_model_discrete_70",
dist.Normal(weights_loc[43], weights_scale[43]))
b44= pyro.sample("insured_age_model_discrete_80",
dist.Normal(weights_loc[44], weights_scale[44]))
b45= pyro.sample("families_model_2",
dist.Normal(weights_loc[45], weights_scale[45]))
b46= pyro.sample("credit_score_1",
dist.Normal(weights_loc[46], weights_scale[46]))
b47= pyro.sample("credit_score_2",
dist.Normal(weights_loc[47], weights_scale[47]))
b48= pyro.sample("credit_score_4",
dist.Normal(weights_loc[48], weights_scale[48]))
b49= pyro.sample("credit_score_375",
dist.Normal(weights_loc[49], weights_scale[49]))
b50= pyro.sample("credit_score_399",
dist.Normal(weights_loc[50], weights_scale[50]))
b51= pyro.sample("credit_score_424",
dist.Normal(weights_loc[51], weights_scale[51]))
b52= pyro.sample("credit_score_449",
dist.Normal(weights_loc[52], weights_scale[52]))
b53= pyro.sample("credit_score_474",
dist.Normal(weights_loc[53], weights_scale[53]))
b54= pyro.sample("credit_score_499",
dist.Normal(weights_loc[54], weights_scale[54]))
b55= pyro.sample("credit_score_524",
dist.Normal(weights_loc[55], weights_scale[55]))
b56= pyro.sample("credit_score_549",
dist.Normal(weights_loc[56], weights_scale[56]))
b57= pyro.sample("credit_score_574",
dist.Normal(weights_loc[57], weights_scale[57]))
b58= pyro.sample("credit_score_599",
dist.Normal(weights_loc[58], weights_scale[58]))
b59= pyro.sample("credit_score_649",
dist.Normal(weights_loc[59], weights_scale[59]))
b60= pyro.sample("credit_score_699",
dist.Normal(weights_loc[60], weights_scale[60]))
b61= pyro.sample("credit_score_700",
dist.Normal(weights_loc[61], weights_scale[61]))
b62= pyro.sample("credit_score_724",
dist.Normal(weights_loc[62], weights_scale[62]))
b63= pyro.sample("credit_score_749",
dist.Normal(weights_loc[63], weights_scale[63]))
b64= pyro.sample("credit_score_774",
dist.Normal(weights_loc[64], weights_scale[64]))
b65= pyro.sample("credit_score_799",
dist.Normal(weights_loc[65], weights_scale[65]))
b66= pyro.sample("credit_score_800",
dist.Normal(weights_loc[66], weights_scale[66]))
b67= pyro.sample("channel_DIR",
dist.Normal(weights_loc[67], weights_scale[67]))
b68= pyro.sample("fpc_model_2",
dist.Normal(weights_loc[68], weights_scale[68]))
b69= pyro.sample("fpc_model_3",
dist.Normal(weights_loc[69], weights_scale[69]))
b70= pyro.sample("fpc_model_4",
dist.Normal(weights_loc[70], weights_scale[70]))
b71= pyro.sample("fpc_model_5",
dist.Normal(weights_loc[71], weights_scale[71]))
b72= pyro.sample("fpc_model_6",
dist.Normal(weights_loc[72], weights_scale[72]))
b73= pyro.sample("fpc_model_7",
dist.Normal(weights_loc[73], weights_scale[73]))
b74= pyro.sample("fpc_model_8",
dist.Normal(weights_loc[74], weights_scale[74]))
b75= pyro.sample("fpc_model_9",
dist.Normal(weights_loc[75], weights_scale[75]))
b76= pyro.sample("fpc_model_10",
dist.Normal(weights_loc[76], weights_scale[76]))
b77= pyro.sample("fpc_model_80",
dist.Normal(weights_loc[77], weights_scale[77]))
b78= pyro.sample("prev_fire_cnt_1",
dist.Normal(weights_loc[78], weights_scale[78]))
b79= pyro.sample("prev_waterplm_cnt_1",
dist.Normal(weights_loc[79], weights_scale[79]))
b80= pyro.sample("prev_Other_cnt_1",
dist.Normal(weights_loc[80], weights_scale[80]))
sigma = pyro.sample("sigma", dist.Uniform(sigma_loc, 5.))
def summary(samples):
site_stats = {}
for site_name, values in samples.items():
marginal_site = pd.DataFrame(values)
describe = marginal_site.describe(percentiles=[.05, 0.25, 0.5, 0.75, 0.95]).transpose()
site_stats[site_name] = describe[["mean", "std", "5%", "25%", "50%", "75%", "95%"]]
return site_stats
svi = SVI(model,
guide,
optim.Adam({"lr": .05}),
loss=Trace_ELBO())
pyro.clear_param_store()
num_iters = 5000
for i in range(num_iters):
elbo = svi.step(x_discrete, y_discrete)
if i % 500 == 0:
logging.info("Elbo loss: {}".format(elbo))
Could it be how I set up my formula? Maybe since x_data is now a tensor from a dataframe, it doesn’t recognize column names…since they aren’t there? I tried x_data[:, 0] but that didn’t work either.