Unlike in the time series example, `condition` truncates the forecast

In the time series example, `condition` is used to condition on the data at points where we have observations and to forecast at points where no data is available. However, if I use a batched model to do “time series” forecasting (i.e. there is no actual temporal dependence, but I have exogenous regressors that vary over time; see this post), the batch length is truncated to match the length of the data passed to `condition`. I had expected (or hoped) to get forecasts for any batch entries beyond the end of the conditioning data. More concretely:

```python
import numpy as np
import numpyro as ny
import numpyro.distributions as dist

# Run each model with a seed, e.g. ny.handlers.seed(model1, rng_seed=0)()

def model1():
  normal = ny.sample("normal", dist.Normal(loc=np.array([-1, 1]), scale=1))
  print("normal", normal.shape)
  forecast = ny.deterministic("forecast", normal[-1, np.newaxis])
  print("forecast", forecast.shape)

def model2():
  normal_dist = dist.Normal(loc=np.array([-1, 1]), scale=1)
  with ny.handlers.condition(data={"normal": np.array([3])}):
    normal = ny.sample("normal", normal_dist)
  print("normal", normal.shape)

def model3():
  normal_dist = dist.Normal(loc=np.array([-1, 1]), scale=1)
  with ny.handlers.condition(data={"normal": np.array([3])}):
    normal = ny.sample("normal", normal_dist)
  print("normal", normal.shape)
  # Manually "reconstruct" a forecast for the last batch entry.
  forecast = ny.sample("forecast", dist.Normal(loc=normal_dist.loc[-1], scale=normal_dist.scale))
  print("forecast", forecast.shape)
```

`model1` prints `normal (2,)` and `forecast (1,)`, as expected. But adding `condition` in `model2` truncates `normal`, which now prints `normal (1,)`.

Is the best solution to manually “reconstruct” a forecast as in model3?

Full notebook example can be found here.

This is not supported. For forecasting, we recommend splitting the observed `sample` statement and the forecast `sample` statement apart, e.g. with some if/else logic or by slicing the parameters. For example:

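```python
import numpy as np
import numpyro as ny
import numpyro.distributions as dist

def model4():
  normal_dist = dist.Normal(loc=np.array([-1, 1]), scale=1)
  # Observed site: only the first entry, matching the single data point.
  with ny.handlers.condition(data={"normal": np.array([3])}):
    normal = ny.sample("normal", dist.Normal(normal_dist.loc[:1], normal_dist.scale))
  print("normal", normal.shape)
  # Forecast site: a separate sample statement for the last entry.
  forecast = ny.sample("forecast", dist.Normal(loc=normal_dist.loc[-1], scale=normal_dist.scale))
  print("forecast", forecast.shape)
```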

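`model3` will give incorrect inference results: when the log likelihood is evaluated, the observed data (shape `(1,)`) is broadcast to `normal_dist.shape()` (i.e. `(2,)`), so the single observation is scored against both locations instead of only the first one.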
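The broadcasting problem can be checked with a bit of arithmetic. This stdlib-only sketch (the helper `normal_logpdf` is hypothetical, written out from the Normal density) computes the log-likelihood contributed by the observation `3` in each model: in `model3` it is scored against both locations `-1` and `1`, while in `model4` it is scored only against the first.

```python
import math


def normal_logpdf(x, loc, scale=1.0):
    # Log density of Normal(loc, scale) evaluated at x.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


locs = [-1.0, 1.0]  # batch of two locations, as in normal_dist
y = 3.0             # the single observation passed to condition

# model3: the (1,)-shaped data is broadcast to shape (2,), so the one
# observation contributes a log-likelihood term for *both* locations.
ll_model3 = sum(normal_logpdf(y, loc) for loc in locs)

# model4: the observed site only covers loc[:1], so y is scored once.
ll_model4 = normal_logpdf(y, locs[0])

print(ll_model3, ll_model4)
```

The difference between the two is exactly the spurious term for the second location, which is what biases inference in `model3`.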