Benchmark SVI

How can I find the bottleneck in the model / guide in SVI?

Related to this is the control flow for SVI. Using a logger / print statements, it seems the control flow is handled by the model: when it hits a sample statement, it goes to the guide and runs things in the guide until it gets a sample of the named random variable. How exactly do things happen internally in the backend of Pyro's SVI?

the basic logic is explained here

Basically, the guide is run forward and then the model is run forward. When the model is run forward, any sample sites it encounters are "replayed" from the guide samples.

Ok, so the guide sets the control flow. Thanks for pointing me to that!

Are there any examples where the model and guide take some params when they are passed to pyro.infer.SVI?

I have been doing something like this

model_wrapped = model_wrapper(model_maker, params)

svi = pyro.infer.SVI(model=model_wrapped,
                     guide=guide,
                     ...)

Hopefully I can pass in some params to model and guide somewhere with *args and **kwargs.

Yes, the SVI step method takes args and kwargs and passes them on to both the model and the guide. There are examples of this just about everywhere in the docs/examples; it's hard to miss: e.g. here

It seems that the model and guide need to both be able to accept the argument. So even if I don’t use it, it needs to be able to read it in.

Any suggestions on the software engineering side, so that I don’t just have dummy variables that are tossed aside?

Right now I’m doing the following

def model(...,
          for_model_and_guide1,
          for_model_and_guide2,
          for_model_1,
          for_model_2):
    # use for_model_1 and for_model_2
    ...
    return ...


def guide(...,
          for_model_and_guide1,
          for_model_and_guide2,
          for_model_1,
          for_model_2):
    # don't use for_model_1 or for_model_2
    ...
    return ...

svi.step(data,
         for_model_and_guide1=...,
         for_model_and_guide2=...,
         for_model_1=...,
         for_model_2=...)

Pass a dict, and let each model/guide pair pull out what it needs from the dict:

def model1(dataset):
    x = dataset["x"]
    ...

def model2(dataset):
    y = dataset["y"]
    ...