Thanks for the reply, @fritzo.
The shape of the guide parameters is just the flattened shapes of the model parameters concatenated together. So why does the guide need to see (*args, **kwargs) to get its parameters initialized?
Also, is the guide parameters’ initialization dependent on the data we pass in?
why does [an autoguide instance] need to see (*args, **kwargs) to get its parameters initialized?
Autoguides determine the latent shapes of their models by running and tracing the model once. To trace the model, the guide needs to pass the model the proper *args and **kwargs.
Is the guide parameters’ initialization dependent on the data we pass in?
For many models, autoguide parameter initialization does not depend on the data passed into the initial guide call. However, parameters can depend on the initial data if you use a parameter-dependent initialization strategy such as init_to_median or init_to_sample, and if the parameters of your model depend on input data (e.g. if input data are covariates or features describing distribution parameters).
Note that autoguides require full data to be passed in, and do not support amortization, i.e. passing in different minibatches each learning step. To support subsampling / amortization you should either write a custom guide or use an EasyGuide.
I experimented with EasyGuide and AutoNormal on a small linear regression problem, but I find that AutoNormal seems to work on minibatched data.
Here is the experiments notebook. I am not sure whether AutoNormal follows the behaviour you described above. Could you have a look?
Hi @sayam049, I cannot read your linked notebook (error 503), but if you’d like to use AutoNormal with subsampling, I believe you’ll need to pass in a create_plates argument that returns a plate with the subsample argument set just as you do in the model; that way the guide and model plates are properly aligned. We should definitely have better docs on this (feel free to submit a PR)… here’s an example:
import pyro
from pyro.infer.autoguide import AutoNormal

def model(full_data, subsample_indices):
    with pyro.plate("my_plate", len(full_data), subsample=subsample_indices):
        batch = full_data[subsample_indices]
        ...

def create_plates(full_data, subsample_indices):  # same args as model
    # this line is copied from the model
    return pyro.plate("my_plate", len(full_data), subsample=subsample_indices)

guide = AutoNormal(model, create_plates=create_plates)
Hi @sayam049, thanks for linking your notebook. It looks like in your case you don’t need create_plates because there are no latent variables inside the plate. A vanilla AutoNormal(model) should work fine.