Dynamical model, structured guide

This question continues one I started in another thread, "SVI vectorize particles, model parallelization".

I have a dynamical state-space model of the form:
State: $x_t = F x_{t-1} + w_t$, $w_t \sim \mathcal{N}(0, Q)$
Measurement: $z_t = h(x_t) + u_t$, $u_t \sim \mathcal{N}(0, R_t)$
which I modeled in the following way:

$x_t \sim \mathcal{N}(F x_{t-1}, Q)$   # state transition model
$z_t \sim \mathcal{N}(h(x_t), R_t)$   # observation model
$\text{prior} \sim \mathcal{N}(x_0, \Sigma_0) \in \mathbb{R}^{|x_t|}$
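To make the generative process concrete, here is a minimal NumPy sketch that draws states and measurements from the equations above. It assumes a constant measurement covariance R instead of the time-varying R_t, and `h` is a hypothetical stand-in for the nonlinear `relative_position` used below; the constant-velocity F in the usage lines is likewise only an illustration.

```python
import numpy as np


def simulate(F, Q, R, h, x0, Sigma0, T, rng):
    """Draw T states and measurements from x_t = F x_{t-1} + w_t, z_t = h(x_t) + u_t."""
    n = x0.shape[0]
    x = rng.multivariate_normal(x0, Sigma0)  # initial draw from the prior N(x0, Sigma0)
    states, measurements = [], []
    for _ in range(T):
        x = F @ x + rng.multivariate_normal(np.zeros(n), Q)  # transition, w_t ~ N(0, Q)
        z_mean = h(x)
        z = z_mean + rng.multivariate_normal(np.zeros(z_mean.shape[0]), R)  # u_t ~ N(0, R)
        states.append(x)
        measurements.append(z)
    return np.array(states), np.array(measurements)


rng = np.random.default_rng(0)
F = np.array([[1.0, 1.0], [0.0, 1.0]])  # hypothetical constant-velocity transition
xs, zs = simulate(F, 0.01 * np.eye(2), 0.1 * np.eye(1), lambda x: x[:1],
                  np.zeros(2), np.eye(2), T=5, rng=rng)
print(xs.shape, zs.shape)  # (5, 2) (5, 1)
```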

This is my implementation of the guide and model.

Model

import torch
import pyro
import pyro.distributions as dist

# dist_cam_to_origin, relative_position, occ_start and occ_dur are globals defined elsewhere (not shown)

def model(prior, F, Q, r, data):
    bound = dist_cam_to_origin - 0.5  # Boundary box limit
    P0 = torch.eye(prior.shape[0])
    prev_state = pyro.sample("prior", dist.MultivariateNormal(prior, P0))

    for t in pyro.markov(range(len(data))):
        # Transition model: x_t ~ N(F x_{t-1}, Q)
        state = pyro.sample("state_{}".format(t), dist.MultivariateNormal(prev_state, Q))
        state = torch.clamp(state, -bound, bound)  # Boundary box camera limit

        # Nonlinear measurement function h(x_t)
        c_state = relative_position(state)

        # Occlusion: inflate the measurement noise of the occluded cameras
        R_t = r * torch.eye(4)
        occ_cam = torch.tensor([0, 1, 2])
        if occ_start <= t <= occ_start + occ_dur:
            R_t[occ_cam, occ_cam] = 1e10

        # Observation model: z_t ~ N(h(x_t), R_t)
        pyro.sample("measurement_{}".format(t), dist.MultivariateNormal(c_state, R_t), obs=data[t])
        prev_state = F @ state

Guide

def guide(prior, F, Q, r, data):
    bound = dist_cam_to_origin - 0.5  # Boundary box limit
    P0 = torch.eye(prior.shape[0])
    prev_state = pyro.sample("prior", dist.MultivariateNormal(prior, P0))

    for t in pyro.markov(range(len(data))):
        # Mean and Cholesky-factor parameters (the init values are only used on the first call)
        x_value = pyro.param("x_value_{}".format(t), prev_state.detach().clone())
        var_value = pyro.param("var_value_{}".format(t), torch.linalg.cholesky(Q),
                               constraint=dist.constraints.lower_cholesky)

        # q(x_t) = N(x_value, var_value var_value^T)
        state = pyro.sample("state_{}".format(t), dist.MultivariateNormal(x_value, scale_tril=var_value))
        state = torch.clamp(state, -bound, bound)

        c_state = relative_position(state)  # not used in the guide
        prev_state = F @ state

The guide I wrote seems to behave like a mean-field guide: it does not take the dependence on the previous time step into account, because the update x_t = F x_{t-1} in the guide only affects the initialization of the parameters and has no effect in later SVI steps.

How can I modify my guide so that it takes into account this time dependency of the previous time step? Is this possible in Pyro?

Hi, you might be interested in the "guide" section of the deep Markov model tutorial, which covers the construction of a structured, data-dependent guide for a similar class of models in detail.

At a high level, you are free to make your guide distributions' parameters depend on each other however you want. You could make your state depend on the previous time step by introducing a learnable function g (e.g. a neural network or linear map) to compute the parameters of q(x_t | x_{t-1}):

theta = pyro.param("theta")
for t in range(T):
    ...
    x_value, var_value = g(state, theta)
    state = pyro.sample("state_{}".format(t), MultivariateNormal(x_value, scale_tril=var_value))

@eb8680_2, thank you for your response.
I know the specific dependence of the mean parameter on the previous time step, namely x_t = F x_{t-1}, so I could implement this directly as a linear map.
However, I would rather not specify the dependence of the covariance parameter explicitly, and was hoping the inference procedure could learn it automatically.
For example, in the related Kalman filter the covariance is propagated by P_t = F P_{t-1} F^T + Q, and there is also an update step that depends on the observations. Is it possible to say that the covariance depends on the covariance of the previous time step without stating the specifics?
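For reference, the Kalman filter recursion mentioned here can be written as a short NumPy sketch. It assumes a linear measurement matrix H in place of the nonlinear h (an extended Kalman filter would instead linearize h around the predicted state); all names are illustrative.

```python
import numpy as np


def kalman_step(x, P, z, F, Q, H, R):
    """One linear Kalman filter step (linear measurement matrix H assumed)."""
    # Predict: propagate mean and covariance through the transition model
    x_pred = F @ x
    P_pred = F @ P @ F.T + Q
    # Update: correct the prediction with the measurement z
    S = H @ P_pred @ H.T + R              # innovation covariance
    K = P_pred @ H.T @ np.linalg.inv(S)   # Kalman gain
    x_new = x_pred + K @ (z - H @ x_pred)
    P_new = (np.eye(len(x)) - K @ H) @ P_pred
    return x_new, P_new


# Scalar example: P_pred = 1 + 1 = 2, S = 2 + 2 = 4, K = 0.5
one = np.eye(1)
x, P = kalman_step(np.zeros(1), one, np.array([2.0]), one, one, one, 2 * one)
print(x, P)  # [1.] [[1.]]
```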
This is how my guide looks right now; I think this version should be able to take the dependence of the mean into account.

def guide(prior, F, Q, r, data):
    bound = dist_cam_to_origin - 0.5  # Boundary box limit
    P0 = torch.eye(prior.shape[0])
    prev_state = pyro.sample("prior", dist.MultivariateNormal(prior, P0))

    for t in pyro.markov(range(len(data))):
        # Mean and covariance parameter initialization
        prev_state = pyro.param("x_value_{}".format(t), prev_state.detach().clone())
        P = pyro.param("var_value_{}".format(t), torch.linalg.cholesky(Q),
                       constraint=dist.constraints.lower_cholesky)

        # neural network or linear map describing the dependency
        x_value = F @ prev_state
        var_value = P

        # x_t ~ N(x_value, var_value var_value^T)
        state = pyro.sample("state_{}".format(t), dist.MultivariateNormal(x_value, scale_tril=var_value))
        state = torch.clamp(state, -bound, bound)

        prev_state = state

If you want the covariance parameter to depend on the previous state in your guide, you must express the dependence in some form, even if it is not exact. For example, you could have a neural network compute the covariance parameter from the previous state and the observations, similar to the DMM tutorial I linked to above.
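One way to do this, sketched in plain PyTorch under assumptions: a small hypothetical network (called `Combiner` here, loosely in the spirit of the DMM tutorial's combiner but not its code) maps the previous state and the current observation to a mean and a lower-triangular scale matrix. The softplus on the diagonal keeps the scale matrix a valid Cholesky factor; the layer sizes and the state/observation dimensions are arbitrary.

```python
import torch
from torch import nn


class Combiner(nn.Module):
    """Hypothetical network: (x_{t-1}, z_t) -> (loc, scale_tril) of q(x_t | x_{t-1}, z_t)."""

    def __init__(self, state_dim, obs_dim, hidden_dim=32):
        super().__init__()
        self.state_dim = state_dim
        self.body = nn.Sequential(nn.Linear(state_dim + obs_dim, hidden_dim), nn.Tanh())
        self.loc_head = nn.Linear(hidden_dim, state_dim)
        self.tril_head = nn.Linear(hidden_dim, state_dim * state_dim)

    def forward(self, prev_state, z):
        hidden = self.body(torch.cat([prev_state, z], dim=-1))
        loc = self.loc_head(hidden)
        raw = self.tril_head(hidden).reshape(*hidden.shape[:-1], self.state_dim, self.state_dim)
        # Strictly lower triangle is unconstrained; the diagonal is made positive with softplus
        strict_lower = torch.tril(raw, diagonal=-1)
        diag = nn.functional.softplus(torch.diagonal(raw, dim1=-2, dim2=-1)) + 1e-4
        scale_tril = strict_lower + torch.diag_embed(diag)
        return loc, scale_tril


torch.manual_seed(0)
combiner = Combiner(state_dim=3, obs_dim=4)
loc, scale_tril = combiner(torch.zeros(3), torch.zeros(4))
print(loc.shape, scale_tril.shape)  # torch.Size([3]) torch.Size([3, 3])
```

Inside a guide, the network's weights would be registered with `pyro.module("combiner", combiner)`, and its outputs passed to `dist.MultivariateNormal(loc, scale_tril=scale_tril)` at each time step.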

You also seem a bit confused about the use of pyro.param. In particular, the value of prev_state computed in your guide is only going to be used for initialization, as you observed in your original post, and will not introduce a direct dependency between the previous state and the current state. Your intended behavior would require F rather than prev_state to be a learnable pyro.param. You can find more background in the SVI tutorials.

In the end, I would like a variational distribution that consists of a normal distribution over the state x_t at every time step, each with its own mean and covariance. If the inference procedure gives me these estimated means and covariances, I can use the covariance matrices to judge how accurate my mean state estimates are. In other words, I want an estimate of every state as well as the uncertainty of that state, which is why I thought the parameters of the variational distribution should be the means and covariances of these normal distributions.
Furthermore, I consider the transition matrix F known and would rather keep it fixed.
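Note that a structured guide with a fixed F still yields a mean and covariance for every time step. Assuming q(x_0) = N(m_0, S_0) and linear-Gaussian conditionals q(x_t | x_{t-1}) = N(F x_{t-1} + b_t, L_t L_t^T) with learnable offsets b_t and Cholesky factors L_t (a parametrization chosen for illustration, not the one in the guides above), the marginals follow m_t = F m_{t-1} + b_t and S_t = F S_{t-1} F^T + L_t L_t^T, the same form as the Kalman prediction step. A NumPy sketch of that recursion; with the clamp or a nonlinear guide, the marginals would instead have to be estimated from samples.

```python
import numpy as np


def guide_marginals(m0, S0, F, offsets, scale_trils):
    """Per-step marginal mean/covariance of q(x_t | x_{t-1}) = N(F x_{t-1} + b_t, L_t L_t^T)."""
    means, covs = [], []
    m, S = m0, S0
    for b_t, L_t in zip(offsets, scale_trils):
        m = F @ m + b_t                   # marginal mean
        S = F @ S @ F.T + L_t @ L_t.T     # marginal covariance
        means.append(m)
        covs.append(S)
    return means, covs


# Scalar example with F = 1, b_t = 0, L_t = 1, S_0 = 1
one = np.eye(1)
means, covs = guide_marginals(np.zeros(1), one, one, [np.zeros(1)] * 2, [one] * 2)
print([c.item() for c in covs])  # [2.0, 3.0]
```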