Loss under custom guide vs auto guide

Hello, I’m trying to fit a linear random-effects model using NumPyro and have been comparing an auto guide against my custom guide to assess inference stability. I keep running into issues where the loss oscillates (rising, then falling) under my custom guide, whereas the auto guide does not exhibit this behavior.

My guide should roughly resemble the auto guide (mostly mean-field, with the exception of a few separate parameters), so I’m a bit confused as to how this can happen.

Here are the losses at various iterations under my guide vs the auto guide:

My guide:
iter 0 - loss = 7015422.0
iter 25 - loss = 59737.24609375
iter 50 - loss = 36599.234375
iter 75 - loss = 25873.19921875
iter 100 - loss = 24866.625
iter 125 - loss = 7331.56005859375
iter 150 - loss = 25220.345703125

Auto guide:
iter 0 - loss = 4951479.5
iter 25 - loss = 193485.234375
iter 50 - loss = 112910.3828125
iter 75 - loss = 85999.2421875
iter 100 - loss = 62530.828125
iter 125 - loss = 52520.109375
iter 150 - loss = 42750.94140625

Here are the model and guide definitions.

    import jax.numpy as jnp
    import numpyro
    import numpyro.distributions as dist
    from numpyro.distributions import constraints

    def model(X_1: jnp.ndarray, W_1: jnp.ndarray, y_1: jnp.ndarray,
              X_2: jnp.ndarray, W_2: jnp.ndarray, y_2: jnp.ndarray = None) -> None:

        n_1, p = X_1.shape
        n_2, p = X_2.shape

        # coupling parameter
        s_0 = 1
        s = numpyro.sample("s", dist.MultivariateNormal(0., s_0 * jnp.eye(2)))

        # prior standard deviation for effect sizes
        sigma_b = numpyro.param("sigma_b", jnp.sqrt(0.3 / p), constraint=constraints.positive)

        # effect sizes
        with numpyro.plate("beta_i", p):
            beta_1 = numpyro.sample("beta_1", dist.Normal(0., W_1 ** (s[0] / 2.) * sigma_b))
            beta_2 = numpyro.sample("beta_2", dist.Normal(0., W_2 ** (s[1] / 2.) * sigma_b))

        # environmental (residual) standard deviations
        sigma_e1 = numpyro.param("sigma_e1", 1., constraint=constraints.positive)
        sigma_e2 = numpyro.param("sigma_e2", 1., constraint=constraints.positive)

        # likelihood / data generation; each dataset gets its own plate
        mu_1 = jnp.dot(X_1, beta_1)
        mu_2 = jnp.dot(X_2, beta_2)
        with numpyro.plate("data_1", n_1):
            numpyro.sample("y_1", dist.Normal(mu_1, sigma_e1), obs=y_1)
        with numpyro.plate("data_2", n_2):
            numpyro.sample("y_2", dist.Normal(mu_2, sigma_e2), obs=y_2)

        return

    def guide(X_1: jnp.ndarray, W_1: jnp.ndarray, y_1: jnp.ndarray,
              X_2: jnp.ndarray, W_2: jnp.ndarray, y_2: jnp.ndarray = None) -> None:

        n_1, p = X_1.shape
        n_2, p = X_2.shape

        # mean / covariance for parameter s
        s_loc = numpyro.param("s_loc", jnp.zeros(2))
        s_scale = numpyro.param("s_scale", jnp.eye(2), constraint=constraints.positive_definite)

        # approximate multivariate-normal posterior for s
        s = numpyro.sample("s", dist.MultivariateNormal(s_loc, s_scale))

        # posterior means for betas
        beta_1_loc = numpyro.param("beta_1_loc", jnp.zeros(p))
        beta_2_loc = numpyro.param("beta_2_loc", jnp.zeros(p))

        # posterior sds for betas: exponentiate unconstrained params to keep them positive
        beta_1_scale = jnp.exp(numpyro.param("beta_1_scale", jnp.ones(p) / n_1))
        beta_2_scale = jnp.exp(numpyro.param("beta_2_scale", jnp.ones(p) / n_2))

        # mean-field approximation: betas are independent Normals in the posterior
        with numpyro.plate("beta_i", p):
            numpyro.sample("beta_1", dist.Normal(beta_1_loc, beta_1_scale))
            numpyro.sample("beta_2", dist.Normal(beta_2_loc, beta_2_scale))

        return
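
For reference, the losses above come from a training loop along these lines (a minimal sketch; the Adam step size and printing cadence are placeholders, since I haven't shown my exact settings, and the data arrays are assumed to be bound to these names):

    from jax import random
    from numpyro.infer import SVI, Trace_ELBO
    from numpyro.optim import Adam

    # sketch of the SVI loop; step size is a placeholder
    svi = SVI(model, guide, Adam(step_size=1e-2), Trace_ELBO())
    svi_state = svi.init(random.PRNGKey(0), X_1, W_1, y_1, X_2, W_2, y_2)
    for i in range(1000):
        svi_state, loss = svi.update(svi_state, X_1, W_1, y_1, X_2, W_2, y_2)
        if i % 25 == 0:
            print(f"iter {i} - loss = {loss}")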

What could be going on here?

Hi @nmancuso, which auto guide did you use? It seems to me that your guide (which mixes a MultivariateNormal with Normals) reaches smaller losses than the auto guide (which is either all-MultivariateNormal or all-Normal). Perhaps the difference is due to the initial values of the parameters? What happens when you run SVI for more iterations?

Hey @fehiepsi! I’m using the AutoNormal guide for now. I’ve found that it consistently beats the custom guide if I let it run long enough, with no evidence of instability during descent.
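
Concretely, it's just the stock guide with default settings dropped into the same SVI loop as above (a sketch):

    from numpyro.infer.autoguide import AutoNormal

    auto_guide = AutoNormal(model)
    svi = SVI(model, auto_guide, Adam(step_size=1e-2), Trace_ELBO())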

Here are the losses for my custom guide over additional iterations:

iter 0 - loss = 7015422.0
iter 25 - loss = 59734.81640625
iter 50 - loss = 36651.51171875
iter 75 - loss = 25873.265625
iter 100 - loss = 24868.421875
iter 125 - loss = 7331.23828125
iter 150 - loss = 25219.52734375
iter 175 - loss = 6238.70263671875
iter 200 - loss = 25542.681640625
iter 225 - loss = 24036.82421875
iter 250 - loss = 3790.46728515625
iter 275 - loss = 8852.462890625
iter 300 - loss = 10470.646484375
iter 325 - loss = 3920.295654296875
iter 350 - loss = 138452.65625
iter 375 - loss = 7504.5576171875
iter 400 - loss = 4880.4033203125
iter 425 - loss = 4663.68115234375
iter 450 - loss = 14527.2314453125
iter 475 - loss = 11435.541015625
iter 500 - loss = 5050.41455078125
iter 525 - loss = 11271.375
iter 550 - loss = 22102.892578125
iter 575 - loss = 10646.486328125
iter 600 - loss = 3702.282470703125
iter 625 - loss = 7120.96923828125
iter 650 - loss = 8955.662109375
iter 675 - loss = 9598.91796875
iter 700 - loss = 3306.529052734375
iter 725 - loss = 3833.687255859375
iter 750 - loss = 4029.088134765625
iter 775 - loss = 15140.5224609375
iter 800 - loss = 5323.39990234375
iter 825 - loss = 22200.90234375
iter 850 - loss = 5734.90576171875
iter 875 - loss = 3677.29345703125
iter 900 - loss = 13262.2646484375
iter 925 - loss = 26661.77734375
iter 950 - loss = 4308.6220703125
iter 975 - loss = 5413.45703125
iter 999 - loss = 6676.58251953125

Compared to the AutoNormal guide:

iter 0 - loss = 4951479.5
iter 25 - loss = 193382.640625
iter 50 - loss = 112860.578125
iter 75 - loss = 85964.078125
iter 100 - loss = 62505.140625
iter 125 - loss = 52494.4453125
iter 150 - loss = 42733.53125
iter 175 - loss = 38474.56640625
iter 200 - loss = 31414.3671875
iter 225 - loss = 29698.0859375
iter 250 - loss = 24536.91015625
iter 275 - loss = 22804.712890625
iter 300 - loss = 19097.376953125
iter 325 - loss = 17620.29296875
iter 350 - loss = 16325.677734375
iter 375 - loss = 15006.259765625
iter 400 - loss = 13883.3876953125
iter 425 - loss = 13145.9462890625
iter 450 - loss = 11558.70703125
iter 475 - loss = 10784.4140625
iter 500 - loss = 10432.048828125
iter 525 - loss = 9308.5087890625
iter 550 - loss = 8606.6162109375
iter 575 - loss = 8530.6337890625
iter 600 - loss = 7931.83056640625
iter 625 - loss = 8172.78369140625
iter 650 - loss = 7117.6748046875
iter 675 - loss = 6904.50146484375
iter 700 - loss = 6212.75390625
iter 725 - loss = 6336.6611328125
iter 750 - loss = 5833.28369140625
iter 775 - loss = 6054.15185546875
iter 800 - loss = 5215.68896484375
iter 825 - loss = 5145.27294921875
iter 850 - loss = 5037.85009765625
iter 875 - loss = 4889.51513671875
iter 900 - loss = 4580.75927734375
iter 925 - loss = 4551.75146484375
iter 950 - loss = 4577.2822265625
iter 975 - loss = 4281.0517578125
iter 1000 - loss = 3967.41064453125
iter 1025 - loss = 4061.67822265625
iter 1050 - loss = 3798.4130859375
iter 1075 - loss = 3724.45068359375
iter 1100 - loss = 3573.056640625
iter 1125 - loss = 3504.5361328125
iter 1150 - loss = 3302.78564453125
iter 1175 - loss = 3309.3583984375
iter 1200 - loss = 3250.874267578125
iter 1225 - loss = 3246.2080078125
iter 1250 - loss = 3033.527587890625
iter 1275 - loss = 2967.08251953125
iter 1300 - loss = 2927.34130859375
iter 1325 - loss = 2877.117431640625
iter 1350 - loss = 2921.117431640625
iter 1375 - loss = 2755.047607421875
iter 1400 - loss = 2706.14501953125
iter 1425 - loss = 2647.163818359375
iter 1450 - loss = 2632.062744140625
iter 1475 - loss = 2618.9951171875
iter 1500 - loss = 2588.60302734375
iter 1525 - loss = 2578.894287109375
iter 1550 - loss = 2460.376953125
iter 1575 - loss = 2413.8603515625
iter 1600 - loss = 2340.7275390625
iter 1625 - loss = 2356.1015625
iter 1650 - loss = 2337.22216796875
iter 1675 - loss = 2267.146728515625
iter 1700 - loss = 2327.480712890625
iter 1725 - loss = 2257.56591796875
iter 1750 - loss = 2199.79345703125
iter 1775 - loss = 2177.412109375
iter 1800 - loss = 2138.787841796875
iter 1825 - loss = 2153.142578125
iter 1850 - loss = 2071.3427734375
iter 1875 - loss = 2066.14404296875
iter 1900 - loss = 2064.563720703125
iter 1925 - loss = 2015.911865234375
iter 1950 - loss = 1959.7615966796875
iter 1975 - loss = 1971.7305908203125
iter 2000 - loss = 1949.1373291015625
iter 2025 - loss = 1950.587158203125
iter 2050 - loss = 1933.5701904296875
iter 2075 - loss = 1923.6873779296875
iter 2100 - loss = 1887.3551025390625
iter 2125 - loss = 1887.7850341796875
iter 2150 - loss = 1851.0819091796875
iter 2175 - loss = 1878.77294921875
iter 2200 - loss = 1826.6134033203125
iter 2225 - loss = 1790.91748046875
iter 2250 - loss = 1810.6322021484375
iter 2275 - loss = 1803.600341796875
iter 2300 - loss = 1776.4381103515625
iter 2325 - loss = 1779.8831787109375
iter 2350 - loss = 1745.7138671875
iter 2375 - loss = 1728.948486328125
iter 2400 - loss = 1723.595947265625
iter 2425 - loss = 1719.404052734375
iter 2450 - loss = 1704.81494140625
iter 2475 - loss = 1718.2303466796875
iter 2500 - loss = 1691.90625
iter 2525 - loss = 1679.7371826171875
iter 2550 - loss = 1683.634033203125
iter 2575 - loss = 1677.84912109375
iter 2600 - loss = 1652.3157958984375
iter 2625 - loss = 1665.14892578125
iter 2650 - loss = 1637.4512939453125
iter 2675 - loss = 1634.08935546875
iter 2700 - loss = 1626.68505859375
iter 2725 - loss = 1621.2652587890625
iter 2750 - loss = 1619.2952880859375
iter 2775 - loss = 1602.82373046875
iter 2800 - loss = 1593.0880126953125
iter 2825 - loss = 1596.17431640625
iter 2850 - loss = 1585.215576171875
iter 2875 - loss = 1584.76025390625
iter 2900 - loss = 1582.0704345703125
iter 2925 - loss = 1570.028076171875
iter 2950 - loss = 1563.78564453125
iter 2975 - loss = 1555.1669921875
iter 3000 - loss = 1557.39501953125

Surely this can’t be due to initialization differences alone? I’m initializing my locations to 0 and my scales to be small, whereas the auto guide (as far as I can tell from the source) initializes locations to bounded uniform random values and scales to 0.1.
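
If initialization were the culprit, it should be testable by matching the two: AutoNormal accepts init_loc_fn and init_scale, so something like the following (a sketch; the site names are taken from my model, and p is the number of predictors):

    from numpyro.infer import init_to_value
    from numpyro.infer.autoguide import AutoNormal

    # initialize the auto guide's locations at zero, like my custom guide
    auto_guide = AutoNormal(
        model,
        init_loc_fn=init_to_value(values={"s": jnp.zeros(2),
                                          "beta_1": jnp.zeros(p),
                                          "beta_2": jnp.zeros(p)}),
        init_scale=0.1,
    )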

How about using Normal instead of MultivariateNormal in the guide? After that change, your custom guide is the same as the AutoNormal guide (modulo initial values), as far as I can see.
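
Something like this in your guide, in place of the MultivariateNormal block (a sketch; to_event(1) keeps the event shape consistent with the MultivariateNormal site in the model):

    # diagonal (mean-field) approximation for s
    s_loc = numpyro.param("s_loc", jnp.zeros(2))
    s_scale = numpyro.param("s_scale", 0.1 * jnp.ones(2), constraint=constraints.positive)
    s = numpyro.sample("s", dist.Normal(s_loc, s_scale).to_event(1))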