Set unique constraints for each value in a plate


I am working with a large dataset consisting of sales data. For simplicity, the data I am focused on includes a store ID, the state, and a cluster ID. My guide looks like this:

import numpy as np
import numpyro
import numpyro.distributions as dist

def guide(storeID, clusterID, state, volume_sales_obs=None):
    ## parameters for the global store intercept distribution
    μ_α = numpyro.sample("μ_α", dist.Normal(loc=numpyro.param("loc_μ_α", 0.),
                                            scale=numpyro.param("scale_μ_α", 1., constraint=dist.constraints.positive)))
    # HalfNormal scales must stay positive during optimization, so constrain them too
    σ_α = numpyro.sample("σ_α", dist.HalfNormal(numpyro.param("scale_σ_α", 1., constraint=dist.constraints.positive)))
    unique_stores = np.unique(storeID)
    n_stores = len(unique_stores)

    ## parameters for individual store intercept distributions
    with numpyro.plate("plate_store", n_stores):
        α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
    μ_clusters = numpyro.sample("μ_clusters", dist.Normal(loc=numpyro.param("loc_μ_clusters", 0.),
                                                          scale=numpyro.param("scale_μ_clusters", 1., constraint=dist.constraints.positive)))
    σ_clusters = numpyro.sample("σ_clusters", dist.HalfNormal(numpyro.param("scale_σ_clusters", 1., constraint=dist.constraints.positive)))

    ## parameters for individual state/cluster intercept distributions
    # 12 clusters (plate dim -1) x 51 states (plate dim -2)
    with numpyro.plate("plate_cluster", 12):
        # initial values of interval(-1, 1) params must lie strictly inside (-1, 1)
        cluster_mus = numpyro.sample("cluster_mus", dist.Normal(loc=numpyro.param("loc_cluster_mus", μ_clusters, constraint=dist.constraints.interval(-1, 1)),
                                                                scale=numpyro.param("scale_cluster_mus", σ_clusters, constraint=dist.constraints.positive)))
        cluster_sigmas = numpyro.sample("cluster_sigmas", dist.HalfNormal(numpyro.param("scale_cluster_sigmas", 0.2, constraint=dist.constraints.positive)))
        with numpyro.plate("plate_state", 51):
            cluster_state_coefs = numpyro.sample("cluster_state_coefs", dist.Normal(loc=numpyro.param("loc_cluster_state_coefs", cluster_mus, constraint=dist.constraints.interval(-1, 1)),
                                                                                    scale=numpyro.param("scale_cluster_state_coefs", cluster_sigmas, constraint=dist.constraints.positive)))

Right now I am applying the same constraint to every state/cluster combination: each coefficient must fall between -1 and 1. Is there a way to pass a unique constraint interval for each state/cluster combination? I could write out each coefficient as a separate site, but I am concerned this would slow processing down considerably.


@tjm I think what you are looking for is the cat/stack constraints/transforms in this PR. Could you open a similar feature request (FR) for NumPyro? If all of your constraints have the same type, say interval transforms, you can set the joint constraint to


Thanks, that is exactly what I needed.