This question continues one I started in another thread, "SVI vectorize particles, model parallelization".
I have a dynamical state-space model of the form:
State: $x_t = F \cdot x_{t-1} + w_t, \quad w_t \sim \mathcal{N}(0, Q)$
Measurement: $y_t = h(x_t) + v_t, \quad v_t \sim \mathcal{N}(0, R_t)$
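To make the two equations concrete, here is a minimal simulation sketch using plain `torch.distributions`. The function `simulate`, the constant-velocity `F`, the choice `h = torch.tanh`, and the time-invariant `R` are hypothetical stand-ins for illustration only; the actual model uses a time-varying $R_t$ and `relative_position` as $h$.

```python
import torch
from torch.distributions import MultivariateNormal


def simulate(F, Q, R, h, x0, T, seed=0):
    """Draw one trajectory from x_t = F x_{t-1} + w_t and y_t = h(x_t) + v_t."""
    torch.manual_seed(seed)
    d = x0.shape[0]
    x_prev = x0
    xs, ys = [], []
    for _ in range(T):
        # Process noise w_t ~ N(0, Q)
        w = MultivariateNormal(torch.zeros(d), covariance_matrix=Q).sample()
        x = F @ x_prev + w
        y_mean = h(x)
        # Measurement noise v_t ~ N(0, R); R is constant here (assumption)
        v = MultivariateNormal(torch.zeros(y_mean.shape[0]), covariance_matrix=R).sample()
        xs.append(x)
        ys.append(y_mean + v)
        x_prev = x
    return torch.stack(xs), torch.stack(ys)


# Hypothetical 2-D example: [position, velocity] with a constant-velocity transition
F = torch.tensor([[1.0, 0.1], [0.0, 1.0]])
Q = 0.01 * torch.eye(2)
R = 0.1 * torch.eye(2)
h = torch.tanh  # placeholder nonlinearity, not relative_position
xs, ys = simulate(F, Q, R, h, x0=torch.zeros(2), T=10)
```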
I modeled it as follows:
$x_t \sim \mathcal{N}(F \cdot x_{t-1}, Q)$ (state transition model)
$y_t \sim \mathcal{N}(h(x_t), R_t)$ (observation model)
$\text{prior} \sim \mathcal{N}(x_0, \Sigma_0) \in \mathbb{R}^{|x_t|}$
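Together these define the joint density $p(\text{prior}) \prod_t p(x_t \mid x_{t-1})\, p(y_t \mid x_t)$. A minimal sketch that evaluates its logarithm for a given trajectory is below. It treats the prior draw as the state preceding the first transition (an assumption; the Pyro code below uses the prior draw directly as the mean of the first state), and `log_joint` and the constant `R` are hypothetical.

```python
import torch
from torch.distributions import MultivariateNormal


def log_joint(prior_x, xs, ys, x0, Sigma0, F, Q, R, h):
    """log p(prior) + sum_t [log p(x_t | x_{t-1}) + log p(y_t | x_t)]."""
    lp = MultivariateNormal(x0, covariance_matrix=Sigma0).log_prob(prior_x)
    x_prev = prior_x
    for x, y in zip(xs, ys):
        # Transition term: x_t ~ N(F x_{t-1}, Q)
        lp = lp + MultivariateNormal(F @ x_prev, covariance_matrix=Q).log_prob(x)
        # Observation term: y_t ~ N(h(x_t), R); R constant here (assumption)
        lp = lp + MultivariateNormal(h(x), covariance_matrix=R).log_prob(y)
        x_prev = x
    return lp


# 1-D sanity check: all three terms are log N(0; 0, 1) = -0.5 * log(2 * pi)
lp = log_joint(prior_x=torch.zeros(1), xs=torch.zeros(1, 1), ys=torch.zeros(1, 1),
               x0=torch.zeros(1), Sigma0=torch.eye(1), F=torch.eye(1),
               Q=torch.eye(1), R=torch.eye(1), h=lambda x: x)
```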
Here is my implementation of the model and guide.
Model
import torch
import pyro
import pyro.distributions as dist

# Defined elsewhere (not shown): dist_cam_to_origin, occ_start, occ_dur, relative_position()

def model(prior, F, Q, r, data):
    bound = dist_cam_to_origin - 0.5  # Boundary box limit
    P0 = torch.eye(prior.shape[0])
    prev_state = pyro.sample("prior", dist.MultivariateNormal(prior, P0))
    for t in pyro.markov(range(len(data))):
        # Transition model: x_t ~ N(F x_{t-1}, Q)
        state = pyro.sample("state_{}".format(t), dist.MultivariateNormal(prev_state, Q))
        state = torch.clamp(state, -bound, bound)  # Boundary box camera limit
        # Nonlinearity h(x_t)
        c_state = relative_position(state)
        # Occlusion: inflate the measurement variance of the occluded components
        R_t = r * torch.eye(4)
        occ_cam = torch.tensor([0, 1, 2])
        if occ_start <= t <= occ_start + occ_dur:
            R_t[occ_cam, occ_cam] = 1e10
        # Observation model: y_t ~ N(h(x_t), R_t)
        pyro.sample("measurement_{}".format(t), dist.MultivariateNormal(c_state, R_t), obs=data[t])
        prev_state = F @ state  # state is 1-D, so no transpose is needed
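A small check of what the occlusion step does: using the same index tensor for rows and columns selects the entries (0,0), (1,1) and (2,2), so only the diagonal variances of the occluded components are inflated. Because `R_t` is diagonal here, a variance of 1e10 makes those components of the likelihood nearly flat. The value `r = 0.5` is arbitrary and only for illustration.

```python
import torch

r = 0.5  # illustrative measurement variance (assumption)
R_t = r * torch.eye(4)
occ_cam = torch.tensor([0, 1, 2])
# Paired index tensors address (0,0), (1,1), (2,2): diagonal entries only
R_t[occ_cam, occ_cam] = 1e10
```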
Guide
def guide(prior, F, Q, r, data):
    bound = dist_cam_to_origin - 0.5  # Boundary box limit
    P0 = torch.eye(prior.shape[0])
    prev_state = pyro.sample("prior", dist.MultivariateNormal(prior, P0))
    for t in pyro.markov(range(len(data))):
        # Mean and Cholesky-factor parameters (initial values are used only on the first call)
        x_value = pyro.param("x_value_{}".format(t), prev_state.detach().clone())
        var_value = pyro.param("var_value_{}".format(t), torch.linalg.cholesky(Q),
                               constraint=dist.constraints.lower_cholesky)
        # q(x_t) = N(x_value, var_value @ var_value.T)
        state = pyro.sample("state_{}".format(t), dist.MultivariateNormal(x_value, scale_tril=var_value))
        state = torch.clamp(state, -bound, bound)
        c_state = relative_position(state)
        prev_state = F @ state
The guide I wrote seems to behave like a mean-field guide: it does not take the dependence on the previous time step into account. The update $x_t = F x_{t-1}$ in the guide only affects the initialization of the parameters and has no effect on subsequent SVI steps, because `pyro.param` ignores its initial value once the parameter already exists in the param store.
How can I modify my guide so that it accounts for the dependence of each state on the previous time step? Is this possible in Pyro?