SVGD with batch training?

Hi!

I would like to use SVGD as an alternative to SVI for inference. Unfortunately, my model has to be trained in batches, and SVGD does not seem to support that, since I get the following error:

  File "/home/user/anaconda3/lib/python3.7/site-packages/pyro/poutine/broadcast_messenger.py", line 34, in _pyro_sample
    f.name, msg['name'], f.dim, f.size, target_batch_shape[f.dim]))
ValueError: Shape mismatch inside plate('num_particles_vectorized') at site z_1 dim -1, 2 vs 200    

My SVGD implementation is:

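   from pyro.infer import SVGD, RBFSteinKernel, Trace_ELBO, JitTrace_ELBO
   from pyro.optim import ClippedAdam

   adam_params = {"lr": args.learning_rate, "betas": (args.beta1, args.beta2),
                  "clip_norm": args.clip_norm, "lrd": args.lr_decay,
                  "weight_decay": args.weight_decay}
   adam = ClippedAdam(adam_params)
   # note: this ELBO is left over from the SVI setup and is never passed to SVGD
   elbo = JitTrace_ELBO(num_particles=2) if args.jit else Trace_ELBO(num_particles=2)
   kernel = RBFSteinKernel()

   # SVGD takes the model itself (no guide); max_plate_nesting=0 declares no plates
   svgd = SVGD(model, kernel, adam, num_particles=2, max_plate_nesting=0)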
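For reference, here is the standard SVGD update written with a multivariate kernel (Pyro's default `mode="univariate"` uses per-dimension kernels instead, so treat this as a sketch). The data enter only through the score term, which is why subsampling can work when all latents are global: with a minibatch S of size B drawn from N observations, the score can be replaced by an unbiased estimate, which is what the N/B rescaling of a subsampled plate provides. The symbols (n particles, kernel k, prior p_0, observations y_i) are my own notation, and in Pyro the optimizer passed to SVGD (ClippedAdam here) roughly takes the place of the plain step size ε.

```latex
\begin{aligned}
x_i &\leftarrow x_i + \epsilon\,\hat\phi(x_i), \qquad
\hat\phi(x) = \frac{1}{n}\sum_{j=1}^{n}\Big[k(x_j, x)\,\nabla_{x_j}\log p(x_j, \mathcal{D}) + \nabla_{x_j} k(x_j, x)\Big] \\
\nabla_{x}\log p(x, \mathcal{D}) &= \nabla_{x}\log p_0(x) + \sum_{i=1}^{N}\nabla_{x}\log p(y_i \mid x)
\;\approx\; \nabla_{x}\log p_0(x) + \frac{N}{B}\sum_{i \in S}\nabla_{x}\log p(y_i \mid x)
\end{aligned}
```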

My batch size is 200. I have tried matching the number of particles to the batch size, but without success.
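Matching the two sizes cannot fix this, because the problem is which tensor dimension each one occupies, not their sizes. A plain PyTorch sketch of the shape arithmetic (the sizes and the Normal likelihood are illustrative assumptions, not the asker's model): if the particle dimension and the batch dimension share dim -1, equal sizes broadcast without error but pair particle i with datapoint i, and unequal sizes (2 vs 200) fail outright. Moving the particle dimension one slot to the left scores every particle against every datapoint, which is roughly what declaring one extra plate via `max_plate_nesting` arranges for SVGD's particle plate.

```python
import torch
from torch.distributions import Normal

P, B = 200, 200            # number of particles and batch size (illustrative)
loc = torch.randn(P)       # one latent value per particle
data = torch.randn(B)      # one observation per datapoint in the batch

# Particle dim and batch dim both at -1: shapes "match", but broadcasting
# pairs particle i with datapoint i instead of forming all P x B combinations.
wrong = Normal(loc, 1.0).log_prob(data)

# Particle dim moved to -2: every particle is scored against every datapoint.
right = Normal(loc.unsqueeze(-1), 1.0).log_prob(data)

# With unequal sizes on the same dim (2 vs 200, as in the traceback),
# broadcasting fails instead of silently misaligning.
mismatch = False
try:
    Normal(torch.randn(2), 1.0).log_prob(torch.randn(200))
except (RuntimeError, ValueError):
    mismatch = True

print(wrong.shape, right.shape, mismatch)
```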

Thank you very much for your attention and time :slight_smile:

It's hard to answer your question without more details. SVGD should be compatible with data subsampling (at least if there are no local latent variables) as long as you use plates correctly. The change may be as simple as passing max_plate_nesting=1 to SVGD instead of 0 (hard to say, since I can't see your model). The error message points that way: with max_plate_nesting=0, SVGD places its num_particles_vectorized plate at dim -1, where it collides with the size-200 batch dimension of site z_1. In any case, I recommend you familiarize yourself with the logic of our tensor shapes tutorial.