Hi all, I am using Pyro version 1.5.0 to implement AIR (Attend, Infer, Repeat). I hope this topic will start a discussion and give some advice to others re-implementing it.
I am not following the Pyro tutorial for it strictly, but I do use its idea of converting the categorical distribution over the number of objects into per-step Bernoulli probabilities, so that I can use a plate for minibatching. Each example image in my dataset contains between 0 and 2 objects (inclusive), and I allow the model to infer at most 5.
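For anyone unfamiliar with the conversion, here is a minimal sketch of how I understand it, not the tutorial's exact code. The function names are my own (hypothetical), and it assumes the categorical probabilities are ordered as counts 0..N along the last dimension. The Bernoulli probability at step t is the probability of emitting a t-th object given that all earlier steps emitted one, i.e. P(n >= t) / P(n >= t-1).

```python
import torch

def categorical_to_bernoulli(count_probs, eps=1e-8):
    """count_probs: (..., N+1) probabilities of 0..N objects.
    Returns (..., N): P(z_pres_t = 1 | all earlier z_pres = 1)."""
    # tail[..., k] = P(n >= k), computed as a reversed cumulative sum
    tail = count_probs.flip(-1).cumsum(-1).flip(-1)
    # eps guards against division by zero when a tail has no mass
    return tail[..., 1:] / tail[..., :-1].clamp(min=eps)

def bernoulli_to_categorical(step_probs):
    """Inverse mapping: (..., N) step probabilities -> (..., N+1) count probabilities."""
    ones = torch.ones_like(step_probs[..., :1])
    # cont[..., k] = probability of having continued through the first k steps
    cont = torch.cat([ones, step_probs.cumprod(-1)], dim=-1)
    # stop[..., k] = probability of stopping at step k+1 (forced after the last step)
    stop = torch.cat([1 - step_probs, ones], dim=-1)
    return cont * stop

count_probs = torch.tensor([0.2, 0.3, 0.5])   # P(0), P(1), P(2) objects
step_probs = categorical_to_bernoulli(count_probs)
recovered = bernoulli_to_categorical(step_probs)
```

The round trip is a handy sanity check: if `recovered` does not match `count_probs`, the prior the model actually sees is not the one you intended.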
I use TraceEnum_ELBO with parallel enumeration over the ‘z_pres’ sample site. Initially, I find that the model always assigns overwhelming probability mass to 5 objects. Even when there are actually 2 objects in the scene, it then reconstructs only one of them and spends all 5 steps on it.
My experiments are then as follows:
- Changing the categorical prior over object number to prefer smaller numbers - same outcome.
- Adding an extra loss term proportional to the sum of inverse pairwise distances between the object positions inferred by the model. This does discourage the model from putting the objects strictly on top of each other, but it still uses the maximum possible number of steps. Some hyperparameter tuning to strengthen this loss term might help.
- Using the “masked air” version from the paper - masking the region selected after each step and then encoding the whole image again to choose the next glimpse. This version of the AIR model actually learns to see more than one object in my experiments.
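The pairwise-distance penalty from the second experiment can be sketched roughly as follows. This is an illustration under assumptions rather than my exact code: the function name, the `(batch, steps, 2)` layout of the positions, and the `eps` value are all hypothetical, and the result would still need to be scaled by a weight hyperparameter before being added to the loss.

```python
import torch

def inverse_pairwise_distance_penalty(positions, z_pres, eps=1e-6):
    """positions: (batch, steps, 2) inferred object centres.
    z_pres: (batch, steps) with 1 where an object was emitted, else 0.
    Returns (batch,): sum over emitted pairs i < j of 1 / distance(i, j)."""
    # Pairwise differences, shape (batch, steps, steps, 2)
    diff = positions.unsqueeze(-2) - positions.unsqueeze(-3)
    # eps inside the sqrt keeps the gradient finite at zero distance
    dist = (diff.pow(2).sum(-1) + eps).sqrt()
    # Only count pairs where both steps actually emitted an object
    pair_mask = z_pres.unsqueeze(-1) * z_pres.unsqueeze(-2)
    # Strict upper triangle: each pair once, no self-pairs
    upper = torch.triu(torch.ones_like(dist), diagonal=1)
    return (pair_mask * upper / dist).sum(dim=(-1, -2))

positions = torch.tensor([[[0.0, 0.0], [3.0, 4.0], [0.0, 0.0]]], requires_grad=True)
z_pres = torch.tensor([[1.0, 1.0, 0.0]])   # third step emitted nothing
penalty = inverse_pairwise_distance_penalty(positions, z_pres)
penalty.sum().backward()
```

Masking by `z_pres` matters here: without it, steps that emit nothing still contribute to the penalty, which is probably not what you want when the number of steps is itself being inferred.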
Still, some problems are common to all, such as:
- Using all steps (the maximum number) even if the prior discourages this. I don’t want to put a stronger prior on this simply because it does not make intuitive sense; ideally I would either be uninformed about the number of objects or have the categorical probabilities as a parameter that I learn from my data stream.
- A step emitting an empty object
If anyone has had other interesting thoughts or experiments on this model, I would love to hear them! I am also unsure about the hyperparameter tuning to “strengthen” the pairwise distance loss term and would be interested if anyone had any ideas how to implement a prior on positions of objects that would place low probability density on regions where objects were already emitted in previous steps of the generation process.