I met an exception:
RuntimeError: Expected p_in >= 0 && p_in <= 1 to be true, but got false.
The exception occurs in:
~/anaconda3/lib/python3.6/site-packages/torch/distributions/bernoulli.py in sample(self, sample_shape)
88 with torch.no_grad():
---> 89 return torch.bernoulli(self.probs.expand(shape))
90
I am using the SVI to inference the probability of Bernoulli. But I have used constraints.interval(0,1) with the pyro.param(). Why is it still out of the range?
My model:
def model_fn(observations):
def model_pcs(probs):
# probs is a sequence of probabilities. E.g. [0.6, 0.1, 0.1, 0.1, 0.1]
obs = torch.from_numpy(observations)
abnormal_pcs = []
for i in range(len(probs)):
# 0 means not abnormal, 1 means abnormal
abnormal_pcs.append(pyro.sample("abnormal_pc{}".format(i),dist.Bernoulli(probs[i])))
# E,g, abnormal_pcs[0,1,0,0,1] means pc_0, pc_2, pc_3 is normal but pc_1, pc_4 is abnormal
normal_probs = torch.from_numpy(probs*(1.0-np.array(abnormal_pcs)))
abnormal_probs = torch.from_numpy(probs*np.array(abnormal_pcs))
a = 0.5
is_abnormal = pyro.sample("is_abnormal",dist.Bernoulli(a))
# 50% to be normal, sample a pc with categorical distribution from normal pcs
# 50% to be abnormal, sample a pc with categorical distribution from abnormal pcs
for i in range(len(observations)):
if is_abnormal!=1.0:
# normal
pc = pyro.sample('pc_{}'.format(i),dist.Categorical(normal_probs),obs=obs[i])
else:
# abnormal
pc = pyro.sample('pc_{}'.format(i), dist.Categorical(abnormal_probs),obs=obs[i])
return model_pcs
The guide function is:
def param_guide(probs):
inferenced_abnormal_pc = []
for i in range(len(probs)):
inferenced_abnormal_pc.append(pyro.param("es_prob_pc{}",format(i),torch.tensor(prob[i]),constraint=constraints.interval(.0,1.)))
abnormal_pcs = []
for i in range(len(probs)):
# 0 means not abnormal, 1 means abnormal
abnormal_pcs.append(pyro.sample("abnormal_pc{}".format(i),dist.Bernoulli(inferenced_abnormal_pc[i])))
normal_probs = torch.from_numpy(probs*(1.0-np.array(abnormal_pcs)))
abnormal_probs = torch.from_numpy(probs*np.array(abnormal_pcs))
anomaly_score = pyro.param("anomaly_score",torch.tensor(.5),constraint=constraints.interval(.0,1.))
is_abnormal = pyro.sample("is_abnormal",dist.Bernoulli(anomaly_score))