Pyro community, I am working on a multitask learning problem for classification and would like to use SVI to learn the model parameters. I am posting here because the only reference I could find on Dirichlet processes fell short of answering my questions. I have coded the model part of SVI and have some questions. The details are below.
Overview
The data generation process is as per the graphical model below.
At the core I have a classification model, P(Y|X), which uses weights, W, to classify. For now we can assume that this is a binary classification problem, e.g., logistic regression. The second part is that there are M of these classifiers (tasks). A priori I don't know the groupings of these M tasks, so I would like to infer the groupings as well. A natural parametrization for learning groupings is to put a Dirichlet process prior (parametrized by \pi_k and \alpha) on them and let the data inform us about the number of groups among the M tasks. Here c_m is a categorical variable that is a one-hot encoding of the task groupings.
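For reference, the generative process described above can be written out roughly as follows. This is only a sketch: it assumes a stick-breaking construction for \pi (with stick fractions \beta_k, a symbol not used elsewhere in this post), a zero-mean normal prior on each W_k with a generic scale \sigma, and it omits the bias term. The index i runs over the examples of task m.

```latex
\begin{align*}
\beta_k &\sim \mathrm{Beta}(1, \alpha), \quad k = 1, 2, \dots \\
\pi_k &= \beta_k \prod_{j<k} (1 - \beta_j) \\
W_k &\sim \mathcal{N}(0, \sigma^2 I) \\
c_m \mid \pi &\sim \mathrm{Categorical}(\pi), \quad m = 1, \dots, M \\
Y_{mi} \mid X_{mi}, W, c_m &\sim \mathrm{Bernoulli}\big(\mathrm{sigmoid}(W_{c_m}^\top X_{mi})\big)
\end{align*}
```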
Pyro Code
Here is the code for the model part of SVI:
import torch
import pyro
import pyro.distributions as dist

def model(X, Y, tasks):
    N, p = X.size()
    K = 100
    M = len(tasks.unique())
    alpha0 = 1.0
    crp_tables = []
    crp_pi = []
    priors = []
    # Simple one layer NN with sigmoid activation i.e. logistic regression
    nn_model = neural_network_model(p)
    for k in range(K):
        # CRP weights: existing table counts plus alpha0 for a new table
        crp_pi = torch.Tensor(crp_tables + [alpha0])
        crp_pi /= crp_pi.sum()
        assigned_table = pyro.sample("assigned_table_{}".format(k), dist.Categorical(crp_pi))
        if len(crp_tables) == 0 or assigned_table.item() > len(crp_tables) - 1:
            crp_tables.append(1.0)
        else:
            crp_tables[assigned_table.item()] += 1.0
        # Create zero-mean normal priors (scale 2) over the parameters
        loc = torch.zeros(1, p)
        scale = 2 * torch.ones(1, p)
        w_prior = dist.Normal(loc, scale).independent(1)
        bias_loc = torch.zeros(1)
        bias_scale = 2 * torch.ones(1)
        b_prior = dist.Normal(bias_loc, bias_scale).independent(1)
        priors.append({"linear.weight_{}".format(k): w_prior, "linear.bias_{}".format(k): b_prior})
    for m in range(M):
        c_m = pyro.sample("c_{}".format(m), dist.Categorical(crp_pi))
        # lift module parameters to random variables sampled from the priors
        lifted_module = pyro.random_module("module_{}".format(m), nn_model, priors[c_m.item()])
        # sample a neural network (which also samples w and b)
        lifted_reg_model = lifted_module()
        prediction_mean = lifted_reg_model(X[torch.ByteTensor(tasks == m), :].float()).squeeze(1)
        pyro.sample("obs_{}".format(m), dist.Bernoulli(prediction_mean), obs=Y[torch.ByteTensor(tasks == m)])
A couple of points:

Here X is an N x p matrix, Y is an N x 1 vector of binary variables, and tasks is an N x 1 vector of task indicators. For example, if M = 3 then tasks will hold entries [0, …, 0, 1, …, 1, 2, …, 2].
The first for loop is over K groups and simulates the Chinese restaurant process. Inside the same loop I also sample K weights.

The second for loop is over M tasks. Here I sample the c_m categorical variables from my Dirichlet distribution. Here I also instantiate a one-layer NN with sigmoid activation, which is equivalent to logistic regression. This NN code is modified from this Pyro example. The last line treats the observations as sampled from a Bernoulli.
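For concreteness, the Chinese restaurant process seating rule mentioned in the points above can be sketched in plain Python, outside of Pyro. This is a minimal illustration rather than a drop-in replacement for the model code: the helper names `crp_seating_probs` and `simulate_crp` are hypothetical, and it assumes the "customers" are the M tasks, which is one reading of the intended model.

```python
import random

def crp_seating_probs(table_counts, alpha):
    """Probabilities of joining each existing table; the last entry is a new table."""
    total = sum(table_counts) + alpha
    return [c / total for c in table_counts] + [alpha / total]

def simulate_crp(num_customers, alpha, rng):
    """Seat customers one at a time and return their table indices and the table sizes."""
    table_counts = []
    assignments = []
    for _ in range(num_customers):
        probs = crp_seating_probs(table_counts, alpha)
        table = rng.choices(range(len(probs)), weights=probs)[0]
        if table == len(table_counts):
            # The customer opened a new table
            table_counts.append(1)
        else:
            table_counts[table] += 1
        assignments.append(table)
    return assignments, table_counts

# Hypothetical usage: 10 tasks, concentration alpha = 1.0, fixed seed for repeatability
assignments, table_counts = simulate_crp(10, alpha=1.0, rng=random.Random(0))
```

The number of occupied tables grows with the number of customers and with alpha, so the number of groups is not fixed in advance.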
Questions

Does this Model look alright?

I assume there is no better way to do logistic regression in Pyro than via a NN.

I have appended all my weights, W, into a priors list. Is this problematic during inference? Is Pyro smart enough to match variables by label names no matter how they are organized?
In my model I would like to learn \pi_k (crp_pi in the model code above) as well. I am not sure how to set this up as a pyro.sample primitive, since it involves the Chinese restaurant process.

Lastly, I am not sure how to set K for the Chinese restaurant process when K is independent of my data examples. I believe the right way to do this is to use the stick-breaking construction with Beta distributions. For now I am not too worried about hard-coding K to be 100.
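As a rough illustration of the stick-breaking idea in the last question, here is a minimal PyTorch sketch of a truncated construction. It assumes a truncation level K and draws the stick fractions from Beta(1, alpha) using `torch.distributions` directly. The helper name `stick_breaking_weights` is hypothetical, and inside a Pyro model the fractions would presumably be drawn with `pyro.sample` instead of `.sample()`.

```python
import torch

def stick_breaking_weights(beta):
    """Map K-1 stick fractions in (0, 1) to K mixture weights that sum to 1."""
    # Stick length left after each break: (1-b1), (1-b1)(1-b2), ...
    remaining = torch.cumprod(1.0 - beta, dim=-1)
    one = torch.ones_like(beta[..., :1])
    # Weight k takes fraction beta_k of what is left; the last weight takes the rest
    return torch.cat([beta, one], dim=-1) * torch.cat([one, remaining], dim=-1)

torch.manual_seed(0)
K, alpha = 100, 1.0  # truncation level and concentration (illustrative values)
beta = torch.distributions.Beta(torch.ones(K - 1), alpha * torch.ones(K - 1)).sample()
pi = stick_breaking_weights(beta)
```

With a truncation like this, K only needs to be large enough that the trailing weights are negligible, so it acts as an upper bound on the number of groups and not as the number of groups itself.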
Any help is appreciated. I have been at this for two days now.