Custom guide outperformed by automatic guide in mixture model

I’ve built a small mixture of betas that I have (approximately) working with both a custom and AutoNormal guides. The custom guide does much worse than the auto guide (higher variance in the clustering parameters, though the mean is pretty similar; worse overall loss).

I assume that either my guide is doing something quite silly, the autoguide has some sort of transformations happening beneath the hood that make the geometry of the space nicer or both. I’ve looked through the source code of the autoguides bit, but I’m a bit stumped on how to use the transforms myself in my guide. I’ve also gone through the SVI tutorials, but I’m clearly missing something - if anyone has input on what AutoNormal is doing so clever and how to leverage it for myself or how I’m going so astray in my own guide, that’d be much appreciated.

Model

@config_enumerate
def betaMixtureModel(data, K):
    N, P = data.shape

    # GLOBAL VARIABLES
    # Variable for cluster weights
    # Concentration parameter to component weights
    a, b = [3.0, 0.333]
    concentration = numpyro.sample('concentration', dist.Gamma(a, b))

    # Component weights
    weights = numpyro.sample(
        "weights", 
        dist.Dirichlet(jnp.ones(K) * concentration / K)
    )

    # The hyperparameters for the cluster/component parameters
    c, d = [1.0, 0.25]
    alpha_shape = numpyro.sample('alpha_shape', dist.Gamma(c, d))

    e, f = [1.0, 0.25]
    alpha_rate = numpyro.sample('alpha_rate', dist.Gamma(e, f))
    
    g, h = [1.0, 0.25]
    beta_shape = numpyro.sample('beta_shape', dist.Gamma(g, h))

    i, j = [1.0, 0.25]
    beta_rate = numpyro.sample('beta_rate', dist.Gamma(i, j))

    with numpyro.plate("components", K, dim=-2) as k:
        with numpyro.plate("measurements", P, dim=-1) as p:
            alpha = numpyro.sample('alpha', dist.Gamma(alpha_shape, alpha_rate))
            beta = numpyro.sample('beta', dist.Gamma(beta_shape, beta_rate))

    with numpyro.plate("data", N, dim=-2) as n:
        # Local variables.
        assignment = numpyro.sample("assignment", dist.Categorical(weights))
        with numpyro.plate("measurements", P, dim=-1) as p:
            numpyro.sample(
                "obs", 
                dist.Beta(
                    Vindex(alpha)[..., assignment, p], 
                    Vindex(beta)[..., assignment, p]
                ), 
                obs=data
            )

Custom guide

def betaMixtureGuide(data, K):
    N, P = data.shape

    # Variational parameters
    ## For concentration
    aq = numpyro.param("aq", jnp.array(100.0), constraint=constraints.positive)
    bq = numpyro.param("bq", jnp.array(50.0), constraint=constraints.positive)

    ## For beta hyperparameters
    cq = numpyro.param("cq", jnp.array(100.0), constraint=constraints.positive)
    dq = numpyro.param("dq", jnp.array(20.0), constraint=constraints.positive)

    eq = numpyro.param("eq", jnp.array(50.0), constraint=constraints.positive)
    fq = numpyro.param("fq", jnp.array(50.0), constraint=constraints.positive)

    gq = numpyro.param("gq", jnp.array(60.0), constraint=constraints.positive)
    hq = numpyro.param("hq", jnp.array(20.0), constraint=constraints.positive)

    iq = numpyro.param("iq", jnp.array(60.0), constraint=constraints.positive)
    jq = numpyro.param("jq", jnp.array(20.0), constraint=constraints.positive)

    # Sample global parameters
    concentration = numpyro.sample("concentration", dist.Gamma(aq, bq))
    weights = numpyro.sample(
        "weights", 
        dist.Dirichlet(jnp.ones(K) * concentration / K)
    )

    # The hyperparameters for the cluster/component parameters
    alpha_shape = numpyro.sample(
        'alpha_shape', 
        dist.Gamma(cq, dq)
    )
    alpha_rate = numpyro.sample(
        'alpha_rate', 
        dist.Gamma(eq, fq)
    )
    
    beta_shape = numpyro.sample(
        'beta_shape', 
        dist.Gamma(gq, hq)
    )
    beta_rate = numpyro.sample(
        'beta_rate', 
        dist.Gamma(iq, jq)
    )

    with numpyro.plate("components", K, dim=-2) as k:
        with numpyro.plate("measurements", P, dim=-1) as p:
            alpha = numpyro.sample(
                'alpha', 
                dist.Gamma(alpha_shape, alpha_rate)
            )
            beta = numpyro.sample(
                'beta', 
                dist.Gamma(beta_shape, beta_rate)
            )

Autoguide call


def betaMixtureGuide(data, K, seed=1024):
    N, P = data.shape    
    global_model = numpyro.handlers.block(
        numpyro.handlers.seed(betaMixtureModel, PRNGKey(seed)),
        hide_fn=lambda site: site["name"]
        not in ["concentration", 
                "weights", 
                "alpha_shape",
                "alpha_rate",
                "beta_shape",
                "beta_rate",
                "components",
                "measurements",
                "alpha",
                "beta"]
    )
    init_vals = {
        "concentration": 1.0,
        "weights": jnp.ones(K) / K,
        "alpha_shape": 15.0,
        "alpha_rate": 5.0,
        "beta_shape": 15.0,
        "beta_rate": 5.0,
        "alpha": dist.Gamma(
            jnp.ones((K, P))*5.0, 
            jnp.ones((K, P))*0.5
        ).sample(PRNGKey(seed)),
        "beta": dist.Gamma(
            jnp.ones((K, P))*5.0, 
            jnp.ones((K, P))*0.5
        ).sample(PRNGKey(seed + 1))
    }

    guide = AutoNormal(
        global_model,
        init_loc_fn=init_to_value(values=init_vals)
    )
    return guide