Tailoring the Deep Markov Model Code for Regression

Hi, commendable job on the code and the corresponding blog.
I am trying to adapt this code so that, instead of an array of length 88 at each time step, my time series consists of a single continuous variable, i.e. cost.

There are a few parts that I do not understand entirely:

  1. Since mine is a regression problem, I know Bernoulli wouldn’t work. Which distribution should I try instead? Will the negative log-likelihood calculation still work? Could you point out which modifications would be necessary?
  2. All the sequences vary in length; this is handled by padding them with zeros up to T_max. Does this mean we are telling the algorithm that the observed cost was zero from time t to T_max (in my setup)? Or is this taken care of by the mask?
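For a single continuous variable, the padding step might look like the minimal sketch below, assuming each sequence is a 1-D torch tensor of costs. The helper name pad_cost_sequences is hypothetical and not part of the DMM example code; it returns a (N, T_max, 1) batch plus a (N, T_max) mask that marks which steps are real observations.

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def pad_cost_sequences(sequences):
    """Zero-pad variable-length 1-D cost sequences into a (N, T_max, 1) batch.

    Returns the padded tensor, the sequence lengths, and a (N, T_max) float
    mask that is 1.0 for observed steps and 0.0 for padding.
    """
    lengths = torch.tensor([len(s) for s in sequences])
    # pad_sequence stacks along a new batch dimension and fills the tail with zeros
    padded = pad_sequence(sequences, batch_first=True, padding_value=0.0)
    padded = padded.unsqueeze(-1)  # one observed feature (cost) per time step
    t_max = padded.size(1)
    # step t is observed iff t < length of that sequence
    mask = (torch.arange(t_max).unsqueeze(0) < lengths.unsqueeze(1)).float()
    return padded, lengths, mask


costs = [torch.tensor([3.0, 4.5, 2.0]), torch.tensor([1.0]), torch.tensor([2.5, 0.5])]
x, lengths, mask = pad_cost_sequences(costs)
print(x.shape)  # torch.Size([3, 3, 1])
# mask rows: [1, 1, 1], [1, 0, 0], [1, 1, 0]
```

The zeros themselves carry no meaning; it is the mask that tells later code which entries to ignore.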
  1. For a regression problem, how about using a continuous distribution such as a Normal? I guess the only change needed is the last sigmoid layer of the Emitter, which you can replace with a layer suitable for your problem.

  2. For varying lengths, I think this is taken care of by mini_batch_mask.
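Pyro distributions provide a `.mask(...)` method that zeroes the log-probability of masked-out entries, so padded steps contribute nothing to the likelihood term. To see the effect in isolation, here is a plain-PyTorch sketch of a masked Normal log-likelihood (the function name masked_normal_log_lik is hypothetical): changing the padded values leaves the total unchanged.

```python
import torch
from torch.distributions import Normal


def masked_normal_log_lik(x, loc, scale, mask):
    """Sum of Normal log-densities over observed steps only.

    x, loc, scale: (N, T_max, 1); mask: (N, T_max) with 1.0 = observed.
    """
    log_p = Normal(loc, scale).log_prob(x).squeeze(-1)  # (N, T_max)
    # select rather than multiply, so a NaN/inf at a padded step cannot leak in
    return torch.where(mask.bool(), log_p, torch.zeros_like(log_p)).sum()


x = torch.tensor([[[3.0], [4.5], [0.0]],
                  [[1.0], [0.0], [0.0]]])
mask = torch.tensor([[1.0, 1.0, 0.0],
                     [1.0, 0.0, 0.0]])
loc = torch.zeros_like(x)
scale = torch.ones_like(x)

ll_zero_pad = masked_normal_log_lik(x, loc, scale, mask)
# overwrite only the padded positions with a different value
x_other_pad = torch.where(mask.bool().unsqueeze(-1), x, torch.full_like(x, 99.0))
ll_other_pad = masked_normal_log_lik(x_other_pad, loc, scale, mask)
print(torch.allclose(ll_zero_pad, ll_other_pad))  # True
```

`torch.where` is used instead of `log_p * mask` because `0 * NaN` is still NaN.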


Thanks for the reply.

I am trying a Normal distribution, and for the Emitter I used a simple torch.nn.Linear(). On running the script, everything seems to run, except that I get the warning "Encountered NAN log_prob_sum".

Could you paste the complete error trace? Also, could you try lowering the optimizer step size and see if that makes it more stable?

Output of the DMM script modified for regression:

N_train_data: 159 avg. training seq. length: 13.55 N_mini_batches: 8
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/torch/nn/modules/rnn.py:38: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=1 and num_layers=1
“num_layers={}”.format(dropout, num_layers))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_1’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_2’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_3’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_4’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_5’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_6’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_7’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_8’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_9’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_10’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_11’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_12’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_13’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_14’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_15’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_16’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_18’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_19’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_21’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_25’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_26’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/infer/trace_elbo.py:143: UserWarning: Encountered NAN loss
warnings.warn(‘Encountered NAN loss’)
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_17’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_20’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_22’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_23’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_24’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_27’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_28’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_29’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_30’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_31’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_32’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_33’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_34’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_35’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_36’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_37’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_38’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_39’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_40’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_41’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_42’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_43’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_44’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_45’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_46’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_47’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_48’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_49’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_50’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_51’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_52’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
/home/themountaindog/anaconda3/envs/pymcenv/lib/python3.6/site-packages/pyro/poutine/trace_struct.py:17: UserWarning: Encountered NAN log_prob_sum at site ‘z_53’
warnings.warn(“Encountered NAN log_prob_sum at site ‘{}’”.format(name))
[training epoch 0000] nan (dt = 2.851 sec)
[training epoch 0001] nan (dt = 2.850 sec)
[training epoch 0002] nan (dt = 2.534 sec)
[training epoch 0003] nan (dt = 3.024 sec)
[training epoch 0004] nan (dt = 3.212 sec)
[training epoch 0005] nan (dt = 2.975 sec)
[training epoch 0006] nan (dt = 3.074 sec)

and this goes on…
I tried reducing the learning rate of the optimizer, but I get the same output.

A few things to notice:
The warnings do not appear in order from z_1 to z_53. If you look closely, the first run of warnings skips z_17, z_20 and z_22 to z_24. Right after the z_26 warning, trace_elbo reports a NaN loss, and then the warnings resume with z_17.
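One likely culprit, assuming the torch.nn.Linear output is passed directly as the Normal's scale, is that an unconstrained layer can emit negative values. The plain-PyTorch sketch below (the helper name scale_is_valid is hypothetical) shows that a negative scale silently yields a NaN log-probability, and that enabling argument validation turns the same mistake into a ValueError; in Pyro, pyro.enable_validation(True) switches on the analogous checks. As an aside, Python's default warning filter prints each distinct warning message only once per call site, which is probably why every z_t site is reported just once and the order looks irregular.

```python
import torch
from torch.distributions import Normal

# A plain linear layer is unconstrained, so a "scale" it produces can be negative.
raw_scale = torch.tensor([0.5, -0.2])
x = torch.tensor([1.0, 1.0])

# Without argument validation, log(-0.2) inside log_prob quietly becomes NaN
log_p = Normal(torch.zeros(2), raw_scale, validate_args=False).log_prob(x)
print(torch.isnan(log_p))  # second entry is True


def scale_is_valid(scale):
    """Return True if Normal accepts `scale` when argument validation is on."""
    try:
        Normal(torch.zeros_like(scale), scale, validate_args=True)
        return True
    except ValueError:  # raised because scale violates the positivity constraint
        return False


print(scale_is_valid(raw_scale))                 # False
print(scale_is_valid(torch.tensor([0.5, 1.0])))  # True
```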

Hi, any leads on this?

You need to make sure you’re passing valid parameters to distributions; for example, the sigma (scale) of a Normal distribution must be positive. Also, adjusting the learning rate may not be enough. Stochastic optimization of complex objectives is not automatic: the hyperparameters really do matter, and knobs must be turned.
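Following that advice, a minimal sketch of an emitter for a single continuous output might look like the code below. The class name GaussianEmitter is hypothetical; its layout only roughly mirrors the tutorial's Emitter, with the final sigmoid replaced by two heads, one for the mean and one for the scale. The scale head goes through softplus so it is always positive.

```python
import torch
import torch.nn as nn


class GaussianEmitter(nn.Module):
    """Maps z_t to the parameters of p(x_t | z_t) = Normal(loc, scale).

    Hypothetical replacement for a Bernoulli emitter; input_dim is the
    observation size (1 for a single cost value).
    """

    def __init__(self, input_dim, z_dim, emission_dim):
        super().__init__()
        self.lin_z_to_hidden = nn.Linear(z_dim, emission_dim)
        self.lin_hidden_to_hidden = nn.Linear(emission_dim, emission_dim)
        self.lin_hidden_to_loc = nn.Linear(emission_dim, input_dim)
        self.lin_hidden_to_scale = nn.Linear(emission_dim, input_dim)
        self.relu = nn.ReLU()
        self.softplus = nn.Softplus()

    def forward(self, z_t):
        h1 = self.relu(self.lin_z_to_hidden(z_t))
        h2 = self.relu(self.lin_hidden_to_hidden(h1))
        loc = self.lin_hidden_to_loc(h2)  # unconstrained mean
        # softplus keeps the scale non-negative; the small floor prevents it
        # from underflowing to exactly zero
        scale = self.softplus(self.lin_hidden_to_scale(h2)) + 1e-4
        return loc, scale


emitter = GaussianEmitter(input_dim=1, z_dim=2, emission_dim=16)
loc, scale = emitter(torch.randn(5, 2))
print(loc.shape, scale.shape)  # torch.Size([5, 1]) torch.Size([5, 1])
```

In the model, the observation site would then sample from dist.Normal(loc, scale) instead of dist.Bernoulli(...), keeping the existing mask. The 1e-4 floor is an arbitrary choice and is itself one of the knobs worth tuning.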