Hi everyone,
I have recently started again with pyro and I am facing some troubles with timeseries and observation masks.
Here’s the code (which works fine) without masking observations out:
# Synthetic data for smoke-testing the model.
L = 15  # number of time steps per series
N = 20  # number of independent series (batch size)
# Input series: i.i.d. draws from Normal(20, 10), elementwise, shape (L, N).
T_dist = dist.Normal(20 * torch.ones(L, N), 10)
T = T_dist.sample()  # input of the time-series model, shape (L, N)
#function that generates time series output (R) from the time series input (T)
def generateR(T):
    """Generate the observed output series R from the input series T.

    Each step adds a noisy, clipped increment to a running total:

        R[i] = R[i-1] + clip(T[i] / T_max + eps_i, 0, 1),  eps_i ~ Normal(0, 0.1)

    with R[0] = clip(T[0] / T_max + eps_0, 0, 1), i.e. R is the cumulative
    sum over time of the clipped increments.

    Parameters
    ----------
    T : torch.Tensor of shape (steps, batch) — input time series.

    Returns
    -------
    torch.Tensor of the same shape — cumulative output series.
    """
    T_max = 20  # ground-truth scale the model is meant to recover
    # Noise shape follows T itself (the original hard-coded the module-level
    # global N here, silently breaking for inputs of any other batch width).
    noise = dist.Normal(torch.zeros(T.shape), 0.1).sample()
    increments = ((T / T_max) + noise).clip(0, 1)
    # Running sum over the time axis reproduces R[i] = R[i-1] + increments[i].
    return torch.cumsum(increments, dim=0)
R = generateR(T) #time series output (observed values)
def model(T, R):
    """Pyro model for the cumulative series R given input T.

    Samples the scale ``t_max`` and noise level ``sigma``, then conditions
    each step's running total on the observed R[t] (t = 1..L-1).
    """
    t_max = pyro.sample('t_max', dist.Normal(15, 10))
    sigma = pyro.sample('sigma', dist.Exponential(5))
    running = (T[0] / t_max).clip(0, 1)  # initialisation for t = 0
    n_steps = T.shape[0]
    for step in range(1, n_steps):
        loc = running + (T[step] / t_max).clip(0, 1)
        running = pyro.sample(
            "obs_x_%d" % step,
            dist.Normal(loc, sigma).to_event(1),
            obs=R[step],
        )
# Run NUTS/HMC: 100 warmup iterations, then 300 posterior samples.
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=300, warmup_steps=100)
mcmc.run(T,R)
Basically R is obtained as follows:
R[i] = R[i-1] + ((T[i]/T_max) + Normal(0, 0.1)).clip(0,1)
where T_max is the parameter I want to learn and T is the time series input to the model.
As long as I run it this way the mcmc runs smoothly and gives the correct results.
(correct estimation of sigma and t_max).
What I tried to do is to simulate missing data in R (imagine I can’t observe some values of R).
What I did is the following:
mask = torch.randint(2,(L,N),dtype = torch.bool) #generate a random mask
And then the model became:
def model(T,R,mask):
    """Same autoregressive model as above, but with partially observed R.

    mask[t] is True where R[t] was actually observed; ``obs_mask`` tells
    Pyro to treat the masked-out entries as latent (to be imputed).
    """
    t_max = pyro.sample('t_max',dist.Normal(15,10))
    sigma = pyro.sample('sigma',dist.Exponential(5))
    prev_out = (T[0]/t_max).clip(0,1)  # initialisation for t = 0
    #with pyro.plate("plate",T.shape[1]):
    for t in range(1,T.shape[0]):
        current_loc = prev_out + (T[t]/t_max).clip(0,1)
        # NOTE(review): ``obs_mask`` makes Pyro create an auxiliary latent
        # site ("obs_x_%d_unobserved") at every step, which NUTS must then
        # sample — a likely cause of the reported ~100x slowdown even when
        # the mask is all True; confirm against the Pyro docs for
        # pyro.sample / obs_mask.
        # NOTE(review): with .to_event(1) the site's batch shape is (), while
        # mask[t] has shape (N,) — verify that obs_mask broadcasting does
        # what you intend here.
        out_t = pyro.sample("obs_x_%d" % t, dist.Normal(current_loc, sigma).to_event(1), obs=R[t],obs_mask = mask[t])
        prev_out = out_t
Basically I only changed the following line to mask out some observation:
out_t = pyro.sample("obs_x_%d" % t, dist.Normal(current_loc, sigma).to_event(1), obs=R[t],obs_mask = mask[t])
Now the problem is that the training is incredibly slow (100x slower) even if I provide a mask full of True values (which should be identical to the previous setup).
Am I missing something?
Thanks a lot for anyone who will help.