Briefly, I'm working on a program that detects spots in single-molecule fluorescence microscopy images and determines binding/dissociation rates using a hidden Markov model (HMM). This is the model written using `pyro.markov`:
@config_enumerate
def viterbi_model(self, data, prefix="d"):
    """Spot model unrolled frame by frame with ``pyro.markov``.

    The discrete state ``theta`` is sampled explicitly at every frame (and the
    per-spot ``m`` inside ``K_plate``), so that ``infer_discrete`` can recover
    the per-frame states.

    Args:
        data: Dataset object. This method reads ``data.N``, ``data.F`` and
            ``data.D``, calls ``data.loc(...)`` and observes
            ``data[batch_idx, f]`` for every frame ``f``.
        prefix: Prefix of the ``"<prefix>/background_loc"`` parameter, matching
            the ``prefix`` argument of ``discretehmm_model``. Defaults to
            ``"d"``, the name that used to be hard-coded here, so existing
            callers see the same parameter.

    Returns:
        List with one ``theta`` value per frame (``data.F`` entries).
    """
    K_plate = pyro.plate("K_plate", self.K, dim=-2)
    N_plate = pyro.plate("N_plate", data.N, dim=-1)
    init = pi_theta_calc(param("pi"), self.K, self.S)  # self.S*self.K+1
    trans = theta_trans_calc(param("A"), self.K, self.S)  # self.S*self.K+1, self.S*self.K+1
    pi_m = pi_m_calc(param("lamda"), self.S)  # self.S+1, self.S+1
    # Length of the x/y range handed to ScaledBeta and used for rescaling.
    span = data.D + 1
    with N_plate as batch_idx:
        thetas = []
        # State *before* frame 0: the state of frame 0 is drawn by one
        # transition from it. NOTE(review): this appears intended to mirror how
        # Pyro's DiscreteHMM applies `initial_logits` in the fitting model —
        # confirm the two models agree on this convention.
        theta = pyro.sample("theta", dist.Categorical(init))
        for f in pyro.markov(range(data.F)):
            background = pyro.sample(
                f"background_{f}",
                dist.Gamma(
                    param(f"{prefix}/background_loc")[batch_idx, 0]
                    * param("background_beta"),
                    param("background_beta"),
                ),
            )
            theta = pyro.sample(
                f"theta_{f}", dist.Categorical(Vindex(trans)[theta, :])
            )
            theta_mask = Vindex(self.theta_matrix)[..., 0, theta]
            m_mask = Vindex(self.m_matrix)[..., 0, theta]
            with K_plate:
                m = pyro.sample(f"m_{f}", dist.Categorical(Vindex(pi_m)[m_mask]))
                height = pyro.sample(
                    f"height_{f}",
                    dist.Gamma(
                        param("height_loc")[m] * param("height_beta")[m],
                        param("height_beta")[m],
                    ),
                )
                width = pyro.sample(
                    f"width_{f}",
                    ScaledBeta(param("width_mode"), param("width_size"), 0.5, 2.5),
                )
                x = pyro.sample(
                    f"x_{f}",
                    ScaledBeta(0, self.size[theta_mask], -span / 2, span),
                )
                y = pyro.sample(
                    f"y_{f}",
                    ScaledBeta(0, self.size[theta_mask], -span / 2, span),
                )
            # Rescale with the same offset/length passed to ScaledBeta above;
            # presumably the raw samples lie on [0, 1] — confirm in ScaledBeta.
            width = width * 2.5 + 0.5
            x = x * span - span / 2
            y = y * span - span / 2
            # calculate the shape of the 2-D Gaussian spot based on sampled parameters
            locs = data.loc(height, width, x, y, background, batch_idx, None, f)
            pyro.sample(
                f"data_{f}",
                self.CameraUnit(locs, param("gain"), param("offset")).to_event(2),
                obs=data[batch_idx, f],
            )
            thetas.append(theta)
    return thetas
And here is the model written using `DiscreteHMM`:
def discretehmm_model(self, data, prefix):
    """Spot model in which the per-frame state is handled by ``dist.DiscreteHMM``.

    Rather than unrolling frames with ``pyro.markov``, the per-frame height and
    x/y position are drawn jointly as one ``"hxy"`` site from a DiscreteHMM of
    length ``data.F``. Its observation distribution stacks height, x and y.

    Args:
        data: Dataset object. This method reads ``data.N``, ``data.F`` and
            ``data.D``, calls ``data.loc(...)`` and observes ``data[batch_idx]``.
        prefix: Prefix of the ``"<prefix>/background_loc"`` parameter.
    """
    plate_k = pyro.plate("K_plate", self.K, dim=-2)
    plate_n = pyro.plate("N_plate", data.N, dim=-1)
    # Length of the x/y range handed to ScaledBeta and used for rescaling.
    span = data.D + 1

    def position_prior():
        # x and y share this prior; each call builds a fresh distribution.
        return ScaledBeta(0, self.size[self.theta_matrix], -span / 2, span)

    with plate_n as batch_idx:
        # One background value per frame, declared as an event dim of size F.
        background_prior = dist.Gamma(
            param(f"{prefix}/background_loc")[batch_idx] * param("background_beta"),
            param("background_beta"),
        )
        background = pyro.sample(
            "background",
            background_prior.expand([len(batch_idx), data.F]).to_event(1),
        )
        with plate_k:
            pi_m = pi_m_calc(param("lamda"), self.S)
            m_logits = Vindex(pi_m)[self.m_matrix].log()
            # EnumDistribution (project helper) pairs the Gamma height prior
            # with logits over m — presumably marginalizing m; confirm.
            height_prior = EnumDistribution(
                dist.Gamma(
                    param("height_loc") * param("height_beta"),
                    param("height_beta"),
                ),
                m_logits,
            )
            emission = StackDistributions(height_prior, position_prior(), position_prior())
            # DiscreteHMM takes logits, hence the .log() on both probability tables.
            init_logits = pi_theta_calc(param("pi"), self.K, self.S).log()  # state_dim
            trans_logits = theta_trans_calc(param("A"), self.K, self.S).log()  # state_dim, state_dim
            hxy = pyro.sample(
                "hxy",
                dist.DiscreteHMM(init_logits, trans_logits, emission, duration=data.F),
            )
            height, x, y = hxy.unbind(dim=-1)
            width = pyro.sample(
                "width",
                ScaledBeta(param("width_mode"), param("width_size"), 0.5, 2.5)
                .expand([data.F])
                .to_event(1),
            )
        # Rescale with the same offset/length passed to ScaledBeta above.
        width = width * 2.5 + 0.5
        x = x * span - span / 2
        y = y * span - span / 2
        # calculate the shape of the 2-D Gaussian spot based on sampled parameters
        locs = data.loc(height, width, x, y, background, batch_idx)
        pyro.sample(
            "data",
            self.CameraUnit(locs, param("gain"), param("offset")).to_event(3),
            obs=data[batch_idx],
        )
I fit the data using the second model (`DiscreteHMM`) and then use the first model (`pyro.markov`) to do inference:
# Run the guide once and record the values it draws for its sample sites.
guide_trace = poutine.trace(self.viterbi_guide).get_trace(self.data)
# Enumerate the Viterbi model's discrete sites and replay the guide's values
# at every site that also appears in guide_trace.
trained_model = poutine.replay(
poutine.enum(self.viterbi_model, first_available_dim=-4), trace=guide_trace)
# Per the infer_discrete documentation, temperature=0 requests a MAP
# (Viterbi-like) assignment of the enumerated discrete sites, not a sample.
# NOTE(review): infer_discrete applies its own enumeration using
# first_available_dim, so the explicit poutine.enum wrapper above may be
# redundant — confirm. The model's plates use dims -1 and -2; -4 leaves dim -3
# unused, which may be intentional — confirm.
thetas = infer_discrete(
trained_model, temperature=0, first_available_dim=-4)(data=self.data)
# viterbi_model returns a list with one theta per frame; stack them along a
# new trailing dimension.
thetas = torch.stack(thetas, dim=-1)
# z is True wherever the inferred state index is nonzero (presumably state 0
# means "no spot of interest" — confirm against pi_theta_calc).
self.predictions["z"] = (thetas > 0).cpu().data