Hi,
I am working with a large dataset consisting of sales data. For simplicity, the data I am focused on includes a store ID, the state, and a cluster ID. My guide looks like this:
def guide(storeID, clusterID, state, volume_sales_obs=None):
##parameters for global store intercept distribution
μ_α = numpyro.sample("μ_α", dist.Normal(loc = numpyro.param("loc_μ_α",0.),
scale = numpyro.param("scale_μ_α",1., constraint=dist.constraints.positive)))
σ_α = numpyro.sample("σ_α", dist.HalfNormal(numpyro.param("scale_σ_α",1.)))
unique_stores= np.unique(storeID)
n_stores = len(unique_stores)
##parameters for individual store intercept distributions
with numpyro.plate("plate_store", n_stores ):
α = numpyro.sample("α", dist.Normal(μ_α, σ_α))
μ_clusters = numpyro.sample("μ_clusters", dist.Normal(loc = numpyro.param("loc_μ_clusters",0.),
scale = numpyro.param("scale_μ_clusters",1., constraint=dist.constraints.positive)))
σ_clusters = numpyro.sample("σ_clusters", dist.HalfNormal(numpyro.param("scale_σ_clusters",1.)))
##parameters for individual state/cluster intercept distributions
with numpyro.plate("plate_cluster", 12):
cluster_mus = numpyro.sample("cluster_mus", dist.Normal(loc = numpyro.param("loc_cluster_mus", μ_clusters, constraint=dist.constraints.interval(-1, 1)),
scale = numpyro.param("scale_cluster_mus",σ_clusters, constraint=dist.constraints.positive)))
cluster_sigmas = numpyro.sample("cluster_sigmas", dist.HalfNormal(numpyro.param("scale_cluster_sigmas", 0.2, constraint=dist.constraints.positive)))
with numpyro.plate("plate_state", 51):
cluster_state_coefs = numpyro.sample("cluster_state_coefs", dist.Normal(loc = numpyro.param("loc_cluster_state_coefs", cluster_mus, constraint=dist.constraints.interval(-1, 1)),
scale = numpyro.param("scale_cluster_state_coefs",cluster_sigmas, constraint=dist.constraints.positive)))
Right now I am applying the same constraints to every store/cluster combination. In this case, that the coef falls between -1 and 1. Is there a way to pass unique constraint intervals for each state/cluster combination? I could write up each coefficient separately, but I am concerned this would decrease processing speed greatly.
Thanks!