Not sure why I'm getting this index error

So I’m trying to run a poisson regression model with 81 variables. I transformed the predictor array to a tensor of size torch.Size([657878, 81]). I’m not sure if my syntax is wrong or something else.

Traceback (most recent call last):

  File "<ipython-input-61-b895b742b8a3>", line 4, in <module>
    elbo = svi.step(x_discrete, y_discrete)

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\svi.py", line 128, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\trace_elbo.py", line 126, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\elbo.py", line 170, in _get_traces
    yield self._get_trace(model, guide, args, kwargs)

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\trace_elbo.py", line 53, in _get_trace
    "flat", self.max_plate_nesting, model, guide, args, kwargs)

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\enum.py", line 48, in get_importance_trace
    graph_type=graph_type).get_trace(*args, **kwargs)

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\poutine\trace_messenger.py", line 187, in get_trace
    self(*args, **kwargs)

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\poutine\trace_messenger.py", line 165, in __call__
    ret = self.fn(*args, **kwargs)

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)

  File "<ipython-input-58-1563a4508bc2>", line 161, in model
    x_data['fpc_model_10']*b76+

IndexError: too many indices for tensor of dimension 2

Here is my model and guide:



def model(x_data, y_data):
    a = pyro.sample('intercept', dist.Normal(0.,1.))
    b0 = pyro.sample("vintile_model_1", dist.Normal(0.,1.))
    b1 = pyro.sample("vintile_model_2", dist.Normal(0.,1.))
    b2 = pyro.sample("vintile_model_4", dist.Normal(0.,1.))
    b3 = pyro.sample("vintile_model_5", dist.Normal(0.,1.))
    b4 = pyro.sample("vintile_model_6", dist.Normal(0.,1.))
    b5 = pyro.sample("vintile_model_7", dist.Normal(0.,1.))
    b6 = pyro.sample("vintile_model_8", dist.Normal(0.,1.))
    b7 = pyro.sample("vintile_model_9", dist.Normal(0.,1.))
    b8 = pyro.sample("vintile_model_10", dist.Normal(0.,1.))
    b9 = pyro.sample("vintile_model_11", dist.Normal(0.,1.))
    b10 = pyro.sample("vintile_model_12", dist.Normal(0.,1.))
    b11 = pyro.sample("vintile_model_13", dist.Normal(0.,1.))
    b12 = pyro.sample("vintile_model_14", dist.Normal(0.,1.))
    b13 = pyro.sample("vintile_model_15", dist.Normal(0.,1.))
    b14 = pyro.sample("vintile_model_16", dist.Normal(0.,1.))
    b15 = pyro.sample("vintile_model_17", dist.Normal(0.,1.))
    b16 = pyro.sample("vintile_model_18", dist.Normal(0.,1.))
    b17 = pyro.sample("vintile_model_19", dist.Normal(0.,1.))
    b18 = pyro.sample("vintile_model_20", dist.Normal(0.,1.))
    b19 = pyro.sample("limit_01_model_discrete_100000", dist.Normal(0.,1.))
    b20 = pyro.sample("limit_01_model_discrete_150000", dist.Normal(0.,1.))
    b21 = pyro.sample("limit_01_model_discrete_200000", dist.Normal(0.,1.))
    b22 = pyro.sample("limit_01_model_discrete_250000", dist.Normal(0.,1.))
    b23 = pyro.sample("limit_01_model_discrete_300000", dist.Normal(0.,1.))
    b24 = pyro.sample("limit_01_model_discrete_400000", dist.Normal(0.,1.))
    b25 = pyro.sample("limit_01_model_discrete_500000", dist.Normal(0.,1.))
    b26 = pyro.sample("deda_model_500", dist.Normal(0.,1.))
    b27 = pyro.sample("deda_model_2500", dist.Normal(0.,1.))
    b28 = pyro.sample("deda_model_4000", dist.Normal(0.,1.))
    b29 = pyro.sample("deda_model_5000", dist.Normal(0.,1.))
    b30 = pyro.sample("deda_model_10000", dist.Normal(0.,1.))
    b31 = pyro.sample("aoh_model_discrete_20", dist.Normal(0.,1.))
    b32 = pyro.sample("aoh_model_discrete_30", dist.Normal(0.,1.))
    b33 = pyro.sample("aoh_model_discrete_40", dist.Normal(0.,1.))
    b34 = pyro.sample("aoh_model_discrete_50", dist.Normal(0.,1.))
    b35 = pyro.sample("aoh_model_discrete_60", dist.Normal(0.,1.))
    b36 = pyro.sample("aoh_model_discrete_70", dist.Normal(0.,1.))
    b37 = pyro.sample("aoh_model_discrete_80", dist.Normal(0.,1.))
    b38 = pyro.sample("aoh_model_discrete_90", dist.Normal(0.,1.))
    b39 = pyro.sample("aoh_model_discrete_100", dist.Normal(0.,1.))
    b40 = pyro.sample("insured_age_model_discrete_40", dist.Normal(0.,1.))
    b41 = pyro.sample("insured_age_model_discrete_50", dist.Normal(0.,1.))
    b42 = pyro.sample("insured_age_model_discrete_60", dist.Normal(0.,1.))
    b43 = pyro.sample("insured_age_model_discrete_70", dist.Normal(0.,1.))
    b44 = pyro.sample("insured_age_model_discrete_80", dist.Normal(0.,1.))
    b45 = pyro.sample("families_model_2", dist.Normal(0.,1.))
    b46 = pyro.sample("credit_score_1", dist.Normal(0.,1.))
    b47 = pyro.sample("credit_score_2", dist.Normal(0.,1.))
    b48 = pyro.sample("credit_score_4", dist.Normal(0.,1.))
    b49 = pyro.sample("credit_score_375", dist.Normal(0.,1.))
    b50 = pyro.sample("credit_score_399", dist.Normal(0.,1.))
    b51 = pyro.sample("credit_score_424", dist.Normal(0.,1.))
    b52 = pyro.sample("credit_score_449", dist.Normal(0.,1.))
    b53 = pyro.sample("credit_score_474", dist.Normal(0.,1.))
    b54 = pyro.sample("credit_score_499", dist.Normal(0.,1.))
    b55 = pyro.sample("credit_score_524", dist.Normal(0.,1.))
    b56 = pyro.sample("credit_score_549", dist.Normal(0.,1.))
    b57 = pyro.sample("credit_score_574", dist.Normal(0.,1.))
    b58 = pyro.sample("credit_score_599", dist.Normal(0.,1.))
    b59 = pyro.sample("credit_score_649", dist.Normal(0.,1.))
    b60 = pyro.sample("credit_score_699", dist.Normal(0.,1.))
    b61 = pyro.sample("credit_score_700", dist.Normal(0.,1.))
    b62 = pyro.sample("credit_score_724", dist.Normal(0.,1.))
    b63 = pyro.sample("credit_score_749", dist.Normal(0.,1.))
    b64 = pyro.sample("credit_score_774", dist.Normal(0.,1.))
    b65 = pyro.sample("credit_score_799", dist.Normal(0.,1.))
    b66 = pyro.sample("credit_score_800", dist.Normal(0.,1.))
    b67 = pyro.sample("channel_DIR", dist.Normal(0.,1.))
    b68 = pyro.sample("fpc_model_2", dist.Normal(0.,1.))
    b69 = pyro.sample("fpc_model_3", dist.Normal(0.,1.))
    b70 = pyro.sample("fpc_model_4", dist.Normal(0.,1.))
    b71 = pyro.sample("fpc_model_5", dist.Normal(0.,1.))
    b72 = pyro.sample("fpc_model_6", dist.Normal(0.,1.))
    b73 = pyro.sample("fpc_model_7", dist.Normal(0.,1.))
    b74 = pyro.sample("fpc_model_8", dist.Normal(0.,1.))
    b75 = pyro.sample("fpc_model_9", dist.Normal(0.,1.))
    b76 = pyro.sample("fpc_model_10", dist.Normal(0.,1.))
    b77 = pyro.sample("fpc_model_80", dist.Normal(0.,1.))
    b78 = pyro.sample("prev_fire_cnt_1", dist.Normal(0.,1.))
    b79 = pyro.sample("prev_waterplm_cnt_1", dist.Normal(0.,1.))
    b80 = pyro.sample("prev_Other_cnt_1", dist.Normal(0.,1.))
    
    mean = (a + x_data['vintile_model_1']*b0+
    x_data['vintile_model_2']*b1+
    x_data['vintile_model_4']*b2+
    x_data['vintile_model_5']*b3+
    x_data['vintile_model_6']*b4+
    x_data['vintile_model_7']*b5+
    x_data['vintile_model_8']*b6+
    x_data['vintile_model_9']*b7+
    x_data['vintile_model_10']*b8+
    x_data['vintile_model_11']*b9+
    x_data['vintile_model_12']*b10+
    x_data['vintile_model_13']*b11+
    x_data['vintile_model_14']*b12+
    x_data['vintile_model_15']*b13+
    x_data['vintile_model_16']*b14+
    x_data['vintile_model_17']*b15+
    x_data['vintile_model_18']*b16+
    x_data['vintile_model_19']*b17+
    x_data['vintile_model_20']*b18+
    x_data['limit_01_model_discrete_100000']*b19+
    x_data['limit_01_model_discrete_150000']*b20+
    x_data['limit_01_model_discrete_200000']*b21+
    x_data['limit_01_model_discrete_250000']*b22+
    x_data['limit_01_model_discrete_300000']*b23+
    x_data['limit_01_model_discrete_400000']*b24+
    x_data['limit_01_model_discrete_500000']*b25+
    x_data['deda_model_500']*b26+
    x_data['deda_model_2500']*b27+
    x_data['deda_model_4000']*b28+
    x_data['deda_model_5000']*b29+
    x_data['deda_model_10000']*b30+
    x_data['aoh_model_discrete_20']*b31+
    x_data['aoh_model_discrete_30']*b32+
    x_data['aoh_model_discrete_40']*b33+
    x_data['aoh_model_discrete_50']*b34+
    x_data['aoh_model_discrete_60']*b35+
    x_data['aoh_model_discrete_70']*b36+
    x_data['aoh_model_discrete_80']*b37+
    x_data['aoh_model_discrete_90']*b38+
    x_data['aoh_model_discrete_100']*b39+
    x_data['insured_age_model_discrete_40']*b40+
    x_data['insured_age_model_discrete_50']*b41+
    x_data['insured_age_model_discrete_60']*b42+
    x_data['insured_age_model_discrete_70']*b43+
    x_data['insured_age_model_discrete_80']*b44+
    x_data['families_model_2']*b45+
    x_data['credit_score_1']*b46+
    x_data['credit_score_2']*b47+
    x_data['credit_score_4']*b48+
    x_data['credit_score_375']*b49+
    x_data['credit_score_399']*b50+
    x_data['credit_score_424']*b51+
    x_data['credit_score_449']*b52+
    x_data['credit_score_474']*b53+
    x_data['credit_score_499']*b54+
    x_data['credit_score_524']*b55+
    x_data['credit_score_549']*b56+
    x_data['credit_score_574']*b57+
    x_data['credit_score_599']*b58+
    x_data['credit_score_649']*b59+
    x_data['credit_score_699']*b60+
    x_data['credit_score_700']*b61+
    x_data['credit_score_724']*b62+
    x_data['credit_score_749']*b63+
    x_data['credit_score_774']*b64+
    x_data['credit_score_799']*b65+
    x_data['credit_score_800']*b66+
    x_data['channel_DIR']*b67+
    x_data['fpc_model_2']*b68+
    x_data['fpc_model_3']*b69+
    x_data['fpc_model_4']*b70+
    x_data['fpc_model_5']*b71+
    x_data['fpc_model_6']*b72+
    x_data['fpc_model_7']*b73+
    x_data['fpc_model_8']*b74+
    x_data['fpc_model_9']*b75+
    x_data['fpc_model_10']*b76+
    x_data['fpc_model_80']*b77+
    x_data['prev_fire_cnt_1']*b78+
    x_data['prev_waterplm_cnt_1']*b79+
    x_data['prev_Other_cnt_1']*b80)
    rate = mean.exp()
    with pyro.plate("data", len(x_data), dim = -2):
        pyro.sample("obs", dist.Poisson(rate), obs=y_data)  
        
# n = 0
# for x in x_discrete.columns:
#     print('b'+str(n)+'= pyro.sample("'+x+'", dist.Normal(weights_loc['+str(n)+'], weights_scale['+str(n)+']))')
#     n += 1
        
def guide(x_data, y_data):
    a_loc = pyro.param('a_loc', torch.tensor(0.))
    a_scale = pyro.param('a_scale', torch.tensor(1.), 
                         constraint=constraints.positive)
    sigma_loc = pyro.param('sigma_loc', torch.tensor(1.),
                           constraint=constraints.positive)
    weights_loc = pyro.param('weights_loc', torch.randn(81))
    weights_scale = pyro.param('weights_scale', torch.ones(81),
                               constraint=constraints.positive)    
    a = pyro.sample("intercept",
                    dist.Normal(a_loc, a_scale))
    b0= pyro.sample("vintile_model_1",
                    dist.Normal(weights_loc[0], weights_scale[0]))
    b1= pyro.sample("vintile_model_2", 
                    dist.Normal(weights_loc[1], weights_scale[1]))
    b2= pyro.sample("vintile_model_4", 
                    dist.Normal(weights_loc[2], weights_scale[2]))
    b3= pyro.sample("vintile_model_5", 
                    dist.Normal(weights_loc[3], weights_scale[3]))
    b4= pyro.sample("vintile_model_6", 
                    dist.Normal(weights_loc[4], weights_scale[4]))
    b5= pyro.sample("vintile_model_7", 
                    dist.Normal(weights_loc[5], weights_scale[5]))
    b6= pyro.sample("vintile_model_8", 
                    dist.Normal(weights_loc[6], weights_scale[6]))
    b7= pyro.sample("vintile_model_9", 
                    dist.Normal(weights_loc[7], weights_scale[7]))
    b8= pyro.sample("vintile_model_10",
                    dist.Normal(weights_loc[8], weights_scale[8]))
    b9= pyro.sample("vintile_model_11",
                    dist.Normal(weights_loc[9], weights_scale[9]))
    b10= pyro.sample("vintile_model_12",
                     dist.Normal(weights_loc[10], weights_scale[10]))
    b11= pyro.sample("vintile_model_13",
                     dist.Normal(weights_loc[11], weights_scale[11]))
    b12= pyro.sample("vintile_model_14",
                     dist.Normal(weights_loc[12], weights_scale[12]))
    b13= pyro.sample("vintile_model_15",
                     dist.Normal(weights_loc[13], weights_scale[13]))
    b14= pyro.sample("vintile_model_16",
                     dist.Normal(weights_loc[14], weights_scale[14]))
    b15= pyro.sample("vintile_model_17",
                     dist.Normal(weights_loc[15], weights_scale[15]))
    b16= pyro.sample("vintile_model_18",
                     dist.Normal(weights_loc[16], weights_scale[16]))
    b17= pyro.sample("vintile_model_19",
                     dist.Normal(weights_loc[17], weights_scale[17]))
    b18= pyro.sample("vintile_model_20",
                     dist.Normal(weights_loc[18], weights_scale[18]))
    b19= pyro.sample("limit_01_model_discrete_100000",
                     dist.Normal(weights_loc[19], weights_scale[19]))
    b20= pyro.sample("limit_01_model_discrete_150000",
                     dist.Normal(weights_loc[20], weights_scale[20]))
    b21= pyro.sample("limit_01_model_discrete_200000",
                     dist.Normal(weights_loc[21], weights_scale[21]))
    b22= pyro.sample("limit_01_model_discrete_250000",
                     dist.Normal(weights_loc[22], weights_scale[22]))
    b23= pyro.sample("limit_01_model_discrete_300000",
                     dist.Normal(weights_loc[23], weights_scale[23]))
    b24= pyro.sample("limit_01_model_discrete_400000",
                     dist.Normal(weights_loc[24], weights_scale[24]))
    b25= pyro.sample("limit_01_model_discrete_500000",
                     dist.Normal(weights_loc[25], weights_scale[25]))
    b26= pyro.sample("deda_model_500", 
                     dist.Normal(weights_loc[26], weights_scale[26]))
    b27= pyro.sample("deda_model_2500", 
                     dist.Normal(weights_loc[27], weights_scale[27]))
    b28= pyro.sample("deda_model_4000", 
                     dist.Normal(weights_loc[28], weights_scale[28]))
    b29= pyro.sample("deda_model_5000", 
                     dist.Normal(weights_loc[29], weights_scale[29]))
    b30= pyro.sample("deda_model_10000", 
                     dist.Normal(weights_loc[30], weights_scale[30]))
    b31= pyro.sample("aoh_model_discrete_20", 
                     dist.Normal(weights_loc[31], weights_scale[31]))
    b32= pyro.sample("aoh_model_discrete_30", 
                     dist.Normal(weights_loc[32], weights_scale[32]))
    b33= pyro.sample("aoh_model_discrete_40", 
                     dist.Normal(weights_loc[33], weights_scale[33]))
    b34= pyro.sample("aoh_model_discrete_50", 
                     dist.Normal(weights_loc[34], weights_scale[34]))
    b35= pyro.sample("aoh_model_discrete_60", 
                     dist.Normal(weights_loc[35], weights_scale[35]))
    b36= pyro.sample("aoh_model_discrete_70", 
                     dist.Normal(weights_loc[36], weights_scale[36]))
    b37= pyro.sample("aoh_model_discrete_80", 
                     dist.Normal(weights_loc[37], weights_scale[37]))
    b38= pyro.sample("aoh_model_discrete_90", 
                     dist.Normal(weights_loc[38], weights_scale[38]))
    b39= pyro.sample("aoh_model_discrete_100", 
                     dist.Normal(weights_loc[39], weights_scale[39]))
    b40= pyro.sample("insured_age_model_discrete_40",
                     dist.Normal(weights_loc[40], weights_scale[40]))
    b41= pyro.sample("insured_age_model_discrete_50",
                     dist.Normal(weights_loc[41], weights_scale[41]))
    b42= pyro.sample("insured_age_model_discrete_60",
                     dist.Normal(weights_loc[42], weights_scale[42]))
    b43= pyro.sample("insured_age_model_discrete_70",
                     dist.Normal(weights_loc[43], weights_scale[43]))
    b44= pyro.sample("insured_age_model_discrete_80",
                     dist.Normal(weights_loc[44], weights_scale[44]))
    b45= pyro.sample("families_model_2",
                     dist.Normal(weights_loc[45], weights_scale[45]))
    b46= pyro.sample("credit_score_1",
                     dist.Normal(weights_loc[46], weights_scale[46]))
    b47= pyro.sample("credit_score_2",
                     dist.Normal(weights_loc[47], weights_scale[47]))
    b48= pyro.sample("credit_score_4",
                     dist.Normal(weights_loc[48], weights_scale[48]))
    b49= pyro.sample("credit_score_375",
                     dist.Normal(weights_loc[49], weights_scale[49]))
    b50= pyro.sample("credit_score_399",
                     dist.Normal(weights_loc[50], weights_scale[50]))
    b51= pyro.sample("credit_score_424", 
                     dist.Normal(weights_loc[51], weights_scale[51]))
    b52= pyro.sample("credit_score_449",
                     dist.Normal(weights_loc[52], weights_scale[52]))
    b53= pyro.sample("credit_score_474",
                     dist.Normal(weights_loc[53], weights_scale[53]))
    b54= pyro.sample("credit_score_499",
                     dist.Normal(weights_loc[54], weights_scale[54]))
    b55= pyro.sample("credit_score_524",
                     dist.Normal(weights_loc[55], weights_scale[55]))
    b56= pyro.sample("credit_score_549",
                     dist.Normal(weights_loc[56], weights_scale[56]))
    b57= pyro.sample("credit_score_574",
                     dist.Normal(weights_loc[57], weights_scale[57]))
    b58= pyro.sample("credit_score_599", 
                     dist.Normal(weights_loc[58], weights_scale[58]))
    b59= pyro.sample("credit_score_649",
                     dist.Normal(weights_loc[59], weights_scale[59]))
    b60= pyro.sample("credit_score_699",
                     dist.Normal(weights_loc[60], weights_scale[60]))
    b61= pyro.sample("credit_score_700",
                     dist.Normal(weights_loc[61], weights_scale[61]))
    b62= pyro.sample("credit_score_724", 
                     dist.Normal(weights_loc[62], weights_scale[62]))
    b63= pyro.sample("credit_score_749",
                     dist.Normal(weights_loc[63], weights_scale[63]))
    b64= pyro.sample("credit_score_774", 
                     dist.Normal(weights_loc[64], weights_scale[64]))
    b65= pyro.sample("credit_score_799",
                     dist.Normal(weights_loc[65], weights_scale[65]))
    b66= pyro.sample("credit_score_800",
                     dist.Normal(weights_loc[66], weights_scale[66]))
    b67= pyro.sample("channel_DIR",
                     dist.Normal(weights_loc[67], weights_scale[67]))
    b68= pyro.sample("fpc_model_2",
                     dist.Normal(weights_loc[68], weights_scale[68]))
    b69= pyro.sample("fpc_model_3",
                     dist.Normal(weights_loc[69], weights_scale[69]))
    b70= pyro.sample("fpc_model_4",
                     dist.Normal(weights_loc[70], weights_scale[70]))
    b71= pyro.sample("fpc_model_5",
                     dist.Normal(weights_loc[71], weights_scale[71]))
    b72= pyro.sample("fpc_model_6", 
                     dist.Normal(weights_loc[72], weights_scale[72]))
    b73= pyro.sample("fpc_model_7",
                     dist.Normal(weights_loc[73], weights_scale[73]))
    b74= pyro.sample("fpc_model_8",
                     dist.Normal(weights_loc[74], weights_scale[74]))
    b75= pyro.sample("fpc_model_9",
                     dist.Normal(weights_loc[75], weights_scale[75]))
    b76= pyro.sample("fpc_model_10",
                     dist.Normal(weights_loc[76], weights_scale[76]))
    b77= pyro.sample("fpc_model_80",
                     dist.Normal(weights_loc[77], weights_scale[77]))
    b78= pyro.sample("prev_fire_cnt_1",
                      dist.Normal(weights_loc[78], weights_scale[78]))
    b79= pyro.sample("prev_waterplm_cnt_1", 
                      dist.Normal(weights_loc[79], weights_scale[79]))
    b80= pyro.sample("prev_Other_cnt_1",
                      dist.Normal(weights_loc[80], weights_scale[80]))
    sigma = pyro.sample("sigma", dist.Uniform(sigma_loc, 5.))
    

def summary(samples):
    site_stats = {}
    for site_name, values in samples.items():
        marginal_site = pd.DataFrame(values)
        describe = marginal_site.describe(percentiles=[.05, 0.25, 0.5, 0.75, 0.95]).transpose()
        site_stats[site_name] = describe[["mean", "std", "5%", "25%", "50%", "75%", "95%"]]
    return site_stats

svi = SVI(model,
          guide,
          optim.Adam({"lr": .05}),
          loss=Trace_ELBO())


pyro.clear_param_store()
num_iters = 5000 
for i in range(num_iters):
    elbo = svi.step(x_discrete, y_discrete)
    if i % 500 == 0:
        logging.info("Elbo loss: {}".format(elbo))

Could it be how I set up my formula? Maybe since x_data is now a tensor from a dataframe, it doesn’t recognize column names…since they aren’t there? I tried x_data[:, 0] but that didn’t work either.

Update:

I changed this model to only have on variable and still get the same error:

def model(x_data, y_data):
    a = pyro.sample('intercept', dist.Normal(0.,1.))
    b0 = pyro.sample("vintile_model_1", dist.Normal(0.,1.))
    mean = a +  x_data['vintile_model_1'].values*b0
    rate = mean.exp()
    with pyro.plate("data", len(x_data), dim = -2):
        pyro.sample("obs", dist.Poisson(rate), obs=y_data)  
        
       
def guide(x_data, y_data):
    a_loc = pyro.param('a_loc', torch.tensor(0.))
    a_scale = pyro.param('a_scale', torch.tensor(1.),
                         constraint=constraints.positive)
    sigma_loc = pyro.param('sigma_loc', torch.tensor(1.),
                             constraint=constraints.positive)
    weights_loc = pyro.param('weights_loc', torch.randn(1))
    weights_scale = pyro.param('weights_scale', torch.ones(1),
                               constraint=constraints.positive)
    a = pyro.sample("a", dist.Normal(a_loc, a_scale))
    b0 = pyro.sample("b0", dist.Normal(weights_loc[0], weights_scale[0]))

    sigma = pyro.sample("sigma", dist.Uniform(sigma_loc, 5.))
    

def summary(samples):
    site_stats = {}
    for site_name, values in samples.items():
        marginal_site = pd.DataFrame(values)
        describe = marginal_site.describe(percentiles=[.05, 0.25, 0.5, 0.75, 0.95]).transpose()
        site_stats[site_name] = describe[["mean", "std", "5%", "25%", "50%", "75%", "95%"]]
    return site_stats

svi = SVI(model,
          guide,
          optim.Adam({"lr": .05}),
          loss=Trace_ELBO())
Traceback (most recent call last):

  File "<ipython-input-97-b895b742b8a3>", line 4, in <module>
    elbo = svi.step(x_discrete, y_discrete)

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\svi.py", line 128, in step
    loss = self.loss_and_grads(self.model, self.guide, *args, **kwargs)

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\trace_elbo.py", line 126, in loss_and_grads
    for model_trace, guide_trace in self._get_traces(model, guide, args, kwargs):

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\elbo.py", line 170, in _get_traces
    yield self._get_trace(model, guide, args, kwargs)

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\trace_elbo.py", line 53, in _get_trace
    "flat", self.max_plate_nesting, model, guide, args, kwargs)

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\infer\enum.py", line 48, in get_importance_trace
    graph_type=graph_type).get_trace(*args, **kwargs)

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\poutine\trace_messenger.py", line 187, in get_trace
    self(*args, **kwargs)

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\poutine\trace_messenger.py", line 165, in __call__
    ret = self.fn(*args, **kwargs)

  File "C:\Users\JORDAN.HOWELL.GITDIR\AppData\Local\Continuum\anaconda3\envs\torch_env\lib\site-packages\pyro\poutine\messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)

  File "<ipython-input-93-582f399d5cec>", line 84, in model
    mean = a +  x_data['vintile_model_1'].values*b0

IndexError: too many indices for tensor of dimension 2