Observation Masked

Hi everyone,
I have recently started working with Pyro again and I am running into some trouble with time series and observation masks.
Here's the code (which works fine) without masking any observations out:

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import MCMC, NUTS

# Fake data just to test the model
L = 15  # length of each time series
N = 20  # number of time series (batch size)

T_dist = dist.Normal(20 * torch.ones(L, N), 10)
T = T_dist.sample()  # time series input

# function that generates the time series output (R) from the time series input (T)
def generateR(T):
    R = torch.empty(T.shape[0], T.shape[1])
    T_max = 20
    R[0] = ((T[0] / T_max) + dist.Normal(torch.zeros(N), 0.1).sample()).clip(0, 1)

    for i in range(1, T.shape[0]):
        R[i] = R[i - 1] + ((T[i] / T_max) + dist.Normal(torch.zeros(N), 0.1).sample()).clip(0, 1)
    return R

R = generateR(T)  #time series output (observed values)

def model(T, R):
    t_max = pyro.sample('t_max', dist.Normal(15., 10.))
    sigma = pyro.sample('sigma', dist.Exponential(5.))

    prev_out = (T[0] / t_max).clip(0, 1)  # initialization for t = 0

    for t in range(1, T.shape[0]):
        current_loc = prev_out + (T[t] / t_max).clip(0, 1)
        out_t = pyro.sample("obs_x_%d" % t, dist.Normal(current_loc, sigma).to_event(1), obs=R[t])
        prev_out = out_t

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=300, warmup_steps=100)
mcmc.run(T,R)

Basically R is obtained as follows:
R[i] = R[i-1] + ((T[i]/T_max) + Normal(0, 0.1)).clip(0, 1)
where T_max is the parameter I want to learn and T is the time series input to the model.

Run this way, the MCMC runs smoothly and gives the correct results (correct estimates of sigma and t_max).
What I tried to do next is to simulate missing data in R (imagine I can't observe some values of R).
What I did is the following:

mask = torch.randint(2, (L, N), dtype=torch.bool)  # generate a random mask (True = observed)

And then the model became:

def model(T, R, mask):
    t_max = pyro.sample('t_max', dist.Normal(15., 10.))
    sigma = pyro.sample('sigma', dist.Exponential(5.))

    prev_out = (T[0] / t_max).clip(0, 1)

    # with pyro.plate("plate", T.shape[1]):
    for t in range(1, T.shape[0]):
        current_loc = prev_out + (T[t] / t_max).clip(0, 1)
        out_t = pyro.sample("obs_x_%d" % t, dist.Normal(current_loc, sigma).to_event(1),
                            obs=R[t], obs_mask=mask[t])
        prev_out = out_t

Basically I only changed the following line, to mask out some observations:

out_t = pyro.sample("obs_x_%d" % t, dist.Normal(current_loc, sigma).to_event(1), obs=R[t], obs_mask=mask[t])

Now the problem is that inference is incredibly slow (roughly 100x slower), even when I provide a mask full of True values (which should be identical to the previous setup).
Am I missing something?

Thanks a lot for anyone who will help.

If at each time step the values are either all observed or all unobserved, you could avoid masking entirely and pass obs=None for the unobserved steps, as in

out_t = pyro.sample(
    "obs_x_%d" % t,
    dist.Normal(current_loc, sigma).to_event(1),
    obs=R[t] if mask[t].all() else None,
)
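
For that pattern to apply, every row of the mask has to be homogeneous; a quick sanity check (just a sketch) could be:

# every time step must be either fully observed or fully unobserved
fully_observed = mask.all(dim=1)   # rows that are entirely True
fully_missing = ~mask.any(dim=1)   # rows that are entirely False
assert (fully_observed | fully_missing).all()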

Otherwise, yes, elementwise masking is slow (if I remember correctly, obs_mask adds an auxiliary "obs_x_%d_unobserved" latent site at each step, which NUTS then has to sample as well). To speed things up you could instead simulate a mask by inflating some of the sigmas:

sigma_t = sigma.expand(R[t].shape).clone()
sigma_t[~mask[t]] = 1e3  # or some other large value, so unobserved entries carry ~no likelihood
out_t = pyro.sample(
    "obs_x_%d" % t,
    dist.Normal(current_loc, sigma_t).to_event(1),
    obs=R[t],
)

and make sure the unobserved entries of R[t] are zeros or some other safe value (not NaN).
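
Putting the pieces together, here's a sketch of how I'd write the whole model with this trick (sigma_big and R_safe are just names I made up for this example):

# replace unobserved entries of R with a safe placeholder (here: zeros)
R_safe = torch.where(mask, R, torch.zeros_like(R))

def model(T, R, mask, sigma_big=1e3):
    t_max = pyro.sample('t_max', dist.Normal(15., 10.))
    sigma = pyro.sample('sigma', dist.Exponential(5.))

    prev_out = (T[0] / t_max).clip(0, 1)

    for t in range(1, T.shape[0]):
        current_loc = prev_out + (T[t] / t_max).clip(0, 1)
        # inflate sigma wherever the mask says "unobserved", so those
        # entries contribute (almost) nothing to the likelihood
        sigma_t = sigma.expand(R[t].shape).clone()
        sigma_t[~mask[t]] = sigma_big
        out_t = pyro.sample("obs_x_%d" % t,
                            dist.Normal(current_loc, sigma_t).to_event(1),
                            obs=R[t])
        prev_out = out_t

mcmc = MCMC(NUTS(model), num_samples=300, warmup_steps=100)
mcmc.run(T, R_safe, mask)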

Thanks a lot! I was actually thinking about interpolating the missing data to provide "safe" values (since it's a time series) and inflating the sigmas as you suggested. Thanks for your kind answer.
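
In case it's useful to others, this is roughly what I had in mind for the interpolation (just a sketch; fill_by_interpolation is a name I made up, and it uses np.interp column by column):

import numpy as np

def fill_by_interpolation(R, mask):
    # linearly interpolate unobserved entries along the time axis (dim 0)
    R_filled = R.clone()
    t = np.arange(R.shape[0])
    for n in range(R.shape[1]):  # one series at a time
        obs = mask[:, n]
        if obs.any():
            filled = np.interp(t, t[obs.numpy()], R[obs, n].numpy())
            R_filled[:, n] = torch.as_tensor(filled, dtype=R.dtype)
    return R_filled

R_safe = fill_by_interpolation(R, mask)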