MLE failing with loss increasing

I was trying to estimate the parameters of a Skellam distribution, which is available in TensorFlow Probability. At first I tried MCMC, but after that failed I wanted to try MLE as a simpler approach.

When I run the code below, the optimiser just decreases the parameters with seemingly no regard for the increasing loss. I tried the same thing with a simple Poisson distribution (also imported from TFP; roughly as sketched after the code below), and everything worked fine. I wonder if I’m just missing something silly here, or if there’s a problem with the Skellam distribution itself (since Poisson didn’t face such difficulties)?

import numpy as np
import numpyro
numpyro.set_host_device_count(4)
from numpyro import distributions as dists
from numpyro.infer import SVI, Trace_ELBO
from tensorflow_probability.substrates.jax import distributions as tfpd
from jax import random


def generate_skellam_data(rate1, rate2, size):
    # A Skellam variate is the difference of two independent Poisson variates.
    return (
        np.random.poisson(rate1, size=size) - np.random.poisson(rate2, size=size)
    )


def skellam_model(y):
    # Point estimates of the two rates (no constraint given, so they are
    # treated as unconstrained values by the optimiser)
    rate1 = numpyro.param("rate1", 0.5)
    rate2 = numpyro.param("rate2", 0.5)
    with numpyro.plate("obs_plate", len(y)):
        numpyro.sample(
            "obs",
            tfpd.Skellam(rate1=rate1, rate2=rate2),
            obs=y
        )


def mle_guide(*args, **kwargs):
    # Empty guide: with Trace_ELBO the SVI loss is just the negative
    # log-likelihood of the params, i.e. plain MLE.
    pass


skellam_y = generate_skellam_data(2.6, 2.3, 1000)
adam = numpyro.optim.Adam(step_size=0.001)
svi = SVI(skellam_model, mle_guide, adam, loss=Trace_ELBO())
rng_key = random.PRNGKey(0)
svi_state = svi.init(rng_key, skellam_y)
for step in range(1050):
    svi_state, loss = svi.update(svi_state, skellam_y)
    if step % 10 == 0:
        print('[iter {}] loss: {:.4f}'.format(step, loss))
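
For reference, the Poisson version I mentioned above was essentially the same setup with a single rate, roughly:

def poisson_model(y):
    rate = numpyro.param("rate", 0.5)
    with numpyro.plate("obs_plate", len(y)):
        numpyro.sample("obs", tfpd.Poisson(rate=rate), obs=y)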

Presumably you need to put positivity constraints (or something similar) on the rate* parameters.
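
Something along these lines (untested sketch, reusing your dists import); only the model changes:

def skellam_model(y):
    # Keep both rates strictly positive; SVI optimises an unconstrained
    # value and transforms it back before it reaches the distribution.
    rate1 = numpyro.param("rate1", 0.5, constraint=dists.constraints.positive)
    rate2 = numpyro.param("rate2", 0.5, constraint=dists.constraints.positive)
    with numpyro.plate("obs_plate", len(y)):
        numpyro.sample("obs", tfpd.Skellam(rate1=rate1, rate2=rate2), obs=y)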

The parameter values are not close to the edges of the actual parameter space, and the gradient should guide them away from the edges, so I thought the constraints wouldn’t enter into it?

Here’s a sample run with the print statement modified slightly to show how the parameters evolve. (Interestingly, yesterday I was sure both parameters were consistently headed towards zero right from the start, whereas now one of them increases. The effect is still roughly the same: after a few iterations the loss starts to increase.)
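
The modification is roughly this, reading the current values back out of the SVI state:

params = svi.get_params(svi_state)
print('[iter {:3d}] loss: {:.2f}. rate1: {:.2f}, rate2: {:.2f}'.format(
    step, loss, params['rate1'], params['rate2']))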

[iter   0] loss: 2924.69. rate1: 0.50, rate2: 0.50
[iter  10] loss: 2919.43. rate1: 0.51, rate2: 0.49
[iter  20] loss: 2915.22. rate1: 0.52, rate2: 0.48
[iter  30] loss: 2912.29. rate1: 0.53, rate2: 0.47
[iter  40] loss: 2910.74. rate1: 0.54, rate2: 0.46
[iter  50] loss: 2910.63. rate1: 0.55, rate2: 0.45
[iter  60] loss: 2911.96. rate1: 0.55, rate2: 0.44
[iter  70] loss: 2914.67. rate1: 0.56, rate2: 0.43
[iter  80] loss: 2918.71. rate1: 0.56, rate2: 0.42
[iter  90] loss: 2923.99. rate1: 0.56, rate2: 0.41
[iter 100] loss: 2930.44. rate1: 0.57, rate2: 0.40
[iter 110] loss: 2937.95. rate1: 0.57, rate2: 0.40
[iter 120] loss: 2946.42. rate1: 0.57, rate2: 0.39
[iter 130] loss: 2955.75. rate1: 0.57, rate2: 0.38
[iter 140] loss: 2965.82. rate1: 0.57, rate2: 0.37
[iter 150] loss: 2976.53. rate1: 0.57, rate2: 0.36
[iter 160] loss: 2987.79. rate1: 0.56, rate2: 0.36
[iter 170] loss: 2999.47. rate1: 0.56, rate2: 0.35
[iter 180] loss: 3011.50. rate1: 0.56, rate2: 0.34
[iter 190] loss: 3023.77. rate1: 0.56, rate2: 0.33
[iter 200] loss: 3036.20. rate1: 0.55, rate2: 0.33

I did try it with positivity constraints imposed on the parameters, but got similar results:

[iter   0] loss: 2893.92. rate1: 0.50, rate2: 0.50
[iter  10] loss: 2890.73. rate1: 0.51, rate2: 0.49
[iter  20] loss: 2887.67. rate1: 0.51, rate2: 0.49
[iter  30] loss: 2884.79. rate1: 0.52, rate2: 0.48
[iter  40] loss: 2882.10. rate1: 0.52, rate2: 0.48
[iter  50] loss: 2879.63. rate1: 0.53, rate2: 0.48
[iter  60] loss: 2877.36. rate1: 0.53, rate2: 0.47
[iter  70] loss: 2875.32. rate1: 0.53, rate2: 0.47
[iter  80] loss: 2873.50. rate1: 0.54, rate2: 0.46
[iter  90] loss: 2871.90. rate1: 0.54, rate2: 0.46
[iter 100] loss: 2870.52. rate1: 0.55, rate2: 0.45
[iter 110] loss: 2869.37. rate1: 0.55, rate2: 0.45
[iter 120] loss: 2868.44. rate1: 0.56, rate2: 0.45
[iter 130] loss: 2867.73. rate1: 0.56, rate2: 0.44
[iter 140] loss: 2867.25. rate1: 0.56, rate2: 0.44
[iter 150] loss: 2866.98. rate1: 0.57, rate2: 0.43
[iter 160] loss: 2866.93. rate1: 0.57, rate2: 0.43
[iter 170] loss: 2867.09. rate1: 0.57, rate2: 0.43
[iter 180] loss: 2867.46. rate1: 0.58, rate2: 0.42
[iter 190] loss: 2868.04. rate1: 0.58, rate2: 0.42
[iter 200] loss: 2868.82. rate1: 0.58, rate2: 0.42
[iter 210] loss: 2869.80. rate1: 0.58, rate2: 0.41
[iter 220] loss: 2870.97. rate1: 0.58, rate2: 0.41
[iter 230] loss: 2872.32. rate1: 0.59, rate2: 0.41
[iter 240] loss: 2873.85. rate1: 0.59, rate2: 0.40
[iter 250] loss: 2875.55. rate1: 0.59, rate2: 0.40
[iter 260] loss: 2877.41. rate1: 0.59, rate2: 0.40
[iter 270] loss: 2879.43. rate1: 0.59, rate2: 0.39
[iter 280] loss: 2881.59. rate1: 0.59, rate2: 0.39
[iter 290] loss: 2883.89. rate1: 0.59, rate2: 0.39
[iter 300] loss: 2886.31. rate1: 0.59, rate2: 0.39
[iter 310] loss: 2888.86. rate1: 0.59, rate2: 0.38
[iter 320] loss: 2891.52. rate1: 0.59, rate2: 0.38
[iter 330] loss: 2894.28. rate1: 0.60, rate2: 0.38
[iter 340] loss: 2897.13. rate1: 0.60, rate2: 0.37
[iter 350] loss: 2900.06. rate1: 0.60, rate2: 0.37
[iter 360] loss: 2903.08. rate1: 0.59, rate2: 0.37
[iter 370] loss: 2906.16. rate1: 0.59, rate2: 0.37
[iter 380] loss: 2909.30. rate1: 0.59, rate2: 0.36
[iter 390] loss: 2912.49. rate1: 0.59, rate2: 0.36
[iter 400] loss: 2915.73. rate1: 0.59, rate2: 0.36
[iter 410] loss: 2919.01. rate1: 0.59, rate2: 0.36

I’m now wondering whether the Skellam distribution’s likelihood function is just very difficult to optimise, although in that case I would rather expect the optimisation to converge to a local minimum, not end up completely off the mark.
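
One thing I might try next is evaluating the negative log-likelihood directly on a grid of rates, to see whether the surface actually has its minimum somewhere near the true values or whether the log_prob itself looks off. Roughly (untested):

import jax.numpy as jnp

y_float = jnp.asarray(skellam_y, dtype=jnp.float32)
for r1 in (0.5, 1.0, 2.0, 2.6, 3.0):
    # Negative log-likelihood with rate2 fixed at its true value
    nll = -jnp.sum(tfpd.Skellam(rate1=r1, rate2=2.3).log_prob(y_float))
    print(r1, float(nll))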