I'm trying to define my own distribution with its own `sample` and `log_prob` methods. The distribution is supposed to model the noise in an observed k-mer (a string of nucleotides A, C, G, T). After importing the class, I can use `pyro.sample("obs", NoiseModel(x))` to generate a noisy version of a k-mer `x` (e.g., if `x` is ACGT, I might sample ACGA). However, when I define my Pyro model and try to run inference with SVI, I get the error `AttributeError: 'function' object has no attribute 'log_prob'`. I think it might have to do with the fact that in the second case I am passing the optional argument `data` (as `obs=data`). However, I am not sure how this argument would interfere with my `log_prob` method. Thanks in advance!
###############################################################
# Class definition (mymodule.py)
###############################################################
import numpy as np
import torch
from pyro.distributions import Distribution
# The four DNA nucleotides a k-mer may contain.
NUCSET = {'A', 'C', 'G', 'T'}
# Noise model for k-mers
class NoiseModel(Distribution):
    """Independent per-position substitution noise on a single k-mer.

    Each nucleotide of the mean k-mer is kept with probability 0.8;
    otherwise (probability 0.2) it is replaced by one of the three other
    nucleotides in ``NUCSET``, chosen uniformly at random.

    NOTE(review): this class decodes exactly one k-mer, returns a Python
    string from ``sample`` and defines no ``batch_shape``/``event_shape``.
    Using it under ``pyro.plate`` with a batch of mean k-mers likely needs a
    vectorized ``TorchDistribution`` rewrite -- confirm.

    Args:
        bin_mean_kmer: Binary encoding of the mean k-mer, two entries per
            nucleotide, in the format decoded by ``dec_seq``.
    """

    def __init__(self, bin_mean_kmer):
        super().__init__()
        self.bin_mean_kmer = bin_mean_kmer

    def sample(self, sample_shape=torch.Size()):
        """Return one noisy copy of the mean k-mer as a nucleotide string.

        ``sample_shape`` is accepted for interface compatibility but is
        ignored: exactly one string is drawn per call.
        """
        noisy = []
        for nuc in dec_seq(self.bin_mean_kmer):
            if np.random.sample() <= 0.8:
                noisy.append(nuc)
            else:
                # Sort the candidates: the iteration order of a set of str
                # depends on hash randomization, so seeding NumPy alone would
                # not make this draw reproducible across interpreter runs.
                noisy.append(np.random.choice(sorted(NUCSET - {nuc})))
        return ''.join(noisy)

    def log_prob(self, value):
        """Return the log-probability of observing the k-mer ``value``.

        Positions are treated independently. A position that matches the
        mean k-mer contributes log(0.8). A mismatch contributes
        log(0.2 / 3), because ``sample`` spreads the 0.2 substitution mass
        uniformly over the three other nucleotides. With this split, the four
        possible outcomes at each position sum to probability 1.

        Args:
            value: Observed k-mer, indexed per position and compared against
                single-character nucleotides (e.g. a string from ``sample``).

        Returns:
            A 0-dim ``torch`` tensor in the default float dtype. Pyro's
            inference machinery calls tensor methods (such as ``.sum()``) on
            log-probabilities, which a plain float does not have.
        """
        log_match = np.log(0.8)
        log_mismatch = np.log(0.2 / 3)
        kmer = dec_seq(self.bin_mean_kmer)
        log_p = sum(
            (log_match if nuc == value[i] else log_mismatch
             for i, nuc in enumerate(kmer)),
            0.0,
        )
        return torch.tensor(float(log_p))
# Decode k-mer
def dec_seq(x):
    """Decode a flat binary encoding into a nucleotide string.

    Entries are read in consecutive pairs: (0, 0) -> 'A', (0, 1) -> 'C',
    (1, 0) -> 'G'; every other pair decodes to 'T'. For sequences of
    A/C/G/T this inverts ``enc_seq``.
    """
    bases = []
    for pos in range(0, len(x), 2):
        if x[pos] == 0:
            if x[pos + 1] == 0:
                bases.append('A')
            elif x[pos + 1] == 1:
                bases.append('C')
            else:
                bases.append('T')
        elif x[pos] == 1 and x[pos + 1] == 0:
            bases.append('G')
        else:
            bases.append('T')
    return ''.join(bases)
# Encode k-mer
def enc_seq(y):
    """Encode a nucleotide sequence as a flat list of bits.

    Each base becomes two entries: 'A' -> 0, 0; 'C' -> 0, 1; 'G' -> 1, 0;
    any other element (including 'T') -> 1, 1. The result is decoded back
    by ``dec_seq``.
    """
    def _bits(base):
        # Anything that is not A, C or G falls through to the 'T' code.
        if base == 'A':
            return (0, 0)
        if base == 'C':
            return (0, 1)
        if base == 'G':
            return (1, 0)
        return (1, 1)

    encoded = []
    for base in y:
        encoded.extend(_bits(base))
    return encoded
###############################################################
# Pyro model
###############################################################
from mymodule import NoiseModel,enc_seq,dec_seq
# Sample binary version of k-mer
def sample_bin_rep(k=None):
    """Draw a random binary k-mer representation.

    Returns a 1-D integer tensor of ``2 * k`` independent fair bits (two bits
    per nucleotide, the layout consumed by ``dec_seq``).

    Args:
        k: k-mer length. Defaults to the module-level ``K``, so existing
            zero-argument calls behave exactly as before.
    """
    if k is None:
        k = K
    # One scalar uniform draw per bit keeps the NumPy RNG stream unchanged.
    bits = [1 if np.random.uniform() > 0.5 else 0 for _ in range(2 * k)]
    return torch.tensor(bits)
# Stick-breaking function
def mix_weights(beta):
    """Turn stick-breaking fractions into mixture weights.

    For T-1 fractions along the last dimension of ``beta``, returns T weights:
    ``w_0 = beta_0``, ``w_k = beta_k * prod_{j<k} (1 - beta_j)``, and the final
    weight is the remaining stick ``prod_j (1 - beta_j)``.
    """
    stick_left = (1 - beta).cumprod(-1)
    # Append a fraction of 1 so the last component takes whatever remains.
    fractions = F.pad(beta, (0, 1), value=1)
    # Prepend 1 so component k is scaled by the stick remaining before it.
    left_before = F.pad(stick_left, (1, 0), value=1)
    return fractions * left_before
# Define pyro model
def model(data):
    """Truncated stick-breaking mixture over noisy k-mers.

    Draws T-1 stick-breaking fractions ``beta ~ Beta(1, alpha)`` and T binary
    "true" k-mers. Then, for each of the N observations, it draws a component
    ``z`` and scores ``data`` under ``NoiseModel`` of the selected k-mer.

    Relies on the module-level names ``T``, ``N``, ``K`` and ``alpha``.
    """
    # Local import: Bernoulli is the only new distribution this model needs.
    from pyro.distributions import Bernoulli

    with pyro.plate("beta_plate", T - 1):
        beta = pyro.sample("beta", Beta(1, alpha))
    with pyro.plate("true_flank_plate", T):
        # pyro.sample needs a Distribution: inference calls .log_prob on every
        # sample site, so passing the plain function sample_bin_rep raised
        # "'function' object has no attribute 'log_prob'". Like
        # sample_bin_rep, this draws 2*K independent fair bits per component.
        true_flank = pyro.sample(
            "true_flank", Bernoulli(torch.full((2 * K,), 0.5)).to_event(1)
        )
    with pyro.plate("data", N):
        z = pyro.sample("z", Categorical(mix_weights(beta)))
        # NOTE(review): with a batched z, true_flank[z] holds N encoded
        # k-mers, but NoiseModel decodes a single k-mer per instance. This
        # site likely needs a vectorized NoiseModel before SVI runs end to
        # end -- confirm against the format of `data`.
        pyro.sample("obs", NoiseModel(true_flank[z]), obs=data)