AutoGuide for a deep learning MLP model

Hello,
I am not sure I understand how the AutoDiagonalNormal AutoGuide works. I have an MLP model and I want to run SVI on it.
Given that my model uses PyroSample statements (which trigger pyro.sample statements), I guess that my guide has the same structure (except that I can't give these PyroSample statements a name, as I usually do with pyro.sample). But my real question is about the activation functions. Between each layer I use a GELU, but are these really "copied" into my guide, or is my guide more like a mean-field guide without any activation functions? Thank you in advance.

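AutoDiagonalNormal approximates the posterior of all of the model's parameters with a Normal distribution (in unconstrained space). More importantly, it assumes the parameters are independent, hence the "Diagonal" in the name. The diagonal refers to the covariance matrix: since it is diagonal, there is no link among the parameters, and every off-diagonal entry is zero. After training, if you check the param store you will find a loc and a scale for each parameter in your model (in this case the NN). For AutoDiagonalNormal these are concatenated into two flat vectors, AutoDiagonalNormal.loc and AutoDiagonalNormal.scale, and guide.median() splits the locs back into per-site values.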

Okay, thank you very much for your answer, but I don't see the point of having a neural network if the guide has no nonlinearities from activation functions. In that case the forward of my guide would just be a linear regression, while my model, and hence my prior, would still be very complex because of all the layers and activation functions. So I would like to define a more complex guide with activation functions. I thought of an EasyGuide, but here is my main problem: I am using a PyroModule that uses PyroSample (and not pyro.sample, where you can name your parameters). How should I build a "better" guide? (Here is the link to the detailed question I posted in the forum: Create an EasyGuide for a PyroModule model, SVI.)
Have a great day,
Rémi

PS: When I look at the parameters of my guide after training, the paramX_loc values seem to follow some kind of Gaussian distribution, but my paramX_scale values are all the same. This isn't the case if I use an AutoLowRankMultivariateNormal, and I don't understand why. How can adding dependencies let all the standard deviations differ, while removing the dependencies between my weights gives them all the same value?

Let's do it one step at a time, starting with the NN. If you ignore VB (variational Bayes) and simply train the NN, you end up with a point estimate for each parameter in your NN, whatever the form of the NN. VB changes this: instead of point estimates, you get a distribution for each parameter in your NN. Then, when you make predictions, you don't get fixed values but samples from the predictive distribution.

If your NN has parameter set \theta and your data set is X, then your learning step estimates the \hat{\theta} that best describes your data. In the VB sense you are not looking for point values but for the posterior distribution p(\theta|X). In simple words: what is the distribution of my parameters given the observed data?

Computing the posterior is difficult (intractable in the general case), so VB approximates p(\theta|X) with a variational distribution q(\theta). This is your guide.

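So your model is the NN with all its nonlinearities, and the guide is the approximate distribution of the posterior. The guide does not need the activations. It only produces values for the weights and biases, and the GELUs are applied when those values are plugged into the model's forward.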

PyroModule has automatic naming. Check here https://pyro.ai/examples/modules.html#How-naming-works.

Can you post your code, maybe I can help you better?

Thank you very much for your answer. I understand that the guide is a tractable approximation of the posterior distribution, but what I don't understand is how it is defined concretely.
My model has a forward method, and I guess my guide does too, but what concretely happens when I use Predictive? I understand every term of the formula, and I guess there must be some kind of forward in my guide. Is going through that forward like going through the forward of a neural network with activation functions? When I call pyro.get_param_store(), my posterior has the locs and scales of all the weights of an MLP, but does it have the same activation functions too? If my model is the following, does my guide have the same shape (i.e. does my guide have activation functions too)?

```python
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from pyro.infer.autoguide import AutoDiagonalNormal


class MLP(PyroModule):
    def __init__(self, in_dim=1, out_dim=1, hid_dim=4, n_hid_layers=5):
        super().__init__()
        self.activation = nn.GELU()  # could also be ReLU
        self.layer_sizes = [in_dim] + n_hid_layers * [hid_dim] + [out_dim]
        # nn.Linear takes (in_features, out_features)
        layer_list = [PyroModule[nn.Linear](self.layer_sizes[idx - 1], self.layer_sizes[idx])
                      for idx in range(1, len(self.layer_sizes))]
        self.layers = PyroModule[nn.ModuleList](layer_list)
        # Replace every weight and bias with a standard Normal prior
        for layer_idx, layer in enumerate(self.layers):
            layer.bias = PyroSample(
                dist.Normal(0., 1.).expand([self.layer_sizes[layer_idx + 1]]).to_event(1))
            layer.weight = PyroSample(
                dist.Normal(0., 1.).expand(
                    [self.layer_sizes[layer_idx + 1], self.layer_sizes[layer_idx]]).to_event(2))

    def forward(self, x, y=None):
        x = x.reshape(-1, self.layer_sizes[0])  # (N, in_dim)
        h = self.activation(self.layers[0](x))
        for layer in list(self.layers)[1:-1]:  # hidden layers, each followed by GELU
            h = self.activation(layer(h))
        mu = self.layers[-1](h).squeeze(-1)  # (N,) for out_dim=1
        sigma = pyro.sample("sigma", dist.Gamma(.5, 1.))

        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Normal(mu, sigma * sigma), obs=y)
        return mu


model = MLP()
guide = AutoDiagonalNormal(model)
```

Second question: the AutoDiagonalNormal guide is defined by AutoDiagonalNormal(model, init_loc_fn=init_to_median, init_scale=0.1). Why this value of 0.1 for the scale? Isn't there some kind of init_to_median for the scale too? Suppose new data arrives and I want to update my model (as in online learning) and redefine my posterior afterwards. Then I lose all the information I had about my standard deviations. In online learning I keep redefining my prior and then my posterior at each learning step, but my posterior's scale never learns and is always 0.1...

The easiest way is to look at the NumPyro source; it is much simpler. The predictive can be found here. In essence, the guide is first sampled for each parameter, and then the model is evaluated using these samples.

In Bayesian terms, the predictive distribution is the probability of observing a new x given the previously observed training data \mathbf{X}:

p(x|\mathbf{X}) = \int_z p(x|z)\,p(z|\mathbf{X})\,dz.

We don't know the true posterior p(z|\mathbf{X}), only its approximation q(z) given by the guide, so the relation above becomes:

p(x|\mathbf{X}) = \int_z p(x|z)\,p(z|\mathbf{X})\,dz \approx \int_z p(x|z)\,q(z)\,dz.

To evaluate this integral, Pyro uses Monte Carlo through the Predictive class. As I wrote in the first paragraph, it first samples from the guide, i.e. from q(z), and then runs these samples through the model, which evaluates p(x|z). The resulting draws are samples from the approximate predictive distribution, and that is how the integral is approximated.
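Written out, the Monte Carlo step replaces the integral with an average over S draws from the guide. Here S is notation introduced for illustration; it corresponds to num_samples in Predictive.

```latex
p(x|\mathbf{X}) \approx \int_z p(x|z)\,q(z)\,dz \approx \frac{1}{S}\sum_{s=1}^{S} p\big(x|z^{(s)}\big), \qquad z^{(s)} \sim q(z).
```

In practice Predictive does not evaluate this density. It returns one draw x^{(s)} \sim p(x|z^{(s)}) per z^{(s)}, and these draws are samples from the approximate predictive distribution.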

So the model's forward function, activations included, is actually called during prediction. See poutine.condition for how this can be done by hand.

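Regarding the second question: for a Gaussian distribution the scale is the standard deviation. init_scale=0.1 is the initial standard deviation of every (unconstrained) latent variable. It is an absolute value, not 10% of the mean, and a small value keeps the initial guide close to its initial loc, which makes it a reasonable rule-of-thumb starting choice. You can initialize however you wish and customize this; see https://docs.pyro.ai/en/stable/infer.autoguide.html. As you can see there, you can pass a function (init_loc_fn) that determines the initial values.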

Thank you very much, it is clearer now. I find it strange that the AutoGuides give more importance to the loc than to the scale.