Batch processing Pyro models

Does anyone know if it is possible to run batches of Pyro model runs in parallel, using tools like Dask? By this, I mean the same model is being passed different subsets of data. With over 100 of these batches, I am keen to parallelize them if possible. I did a simple trial run using dask.delayed, but am getting some fairly inscrutable errors from Pyro. I just wanted to find out whether this is feasible and worth pursuing further, or if I am wasting my time.

Thanks,
cf

Can you provide some details about your problem and the errors you’re seeing? Do you mean that these runs are entirely independent, or are you attempting to parallelize part of a large hierarchical model? There is no barrier in principle to parallelizing over independent model runs. You might also try Ray as an alternative to Dask, since it might work better with PyTorch, or jax.pmap if you are using NumPyro.

Yes, sorry for the lack of detail. I have a Pyro model encapsulated in a function called run_model that is estimated using SVI:

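import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
from pyro.optim import ClippedAdam

def run_model(inputs):

    def my_model(inputs):
        ...  # <specify model here>

    guide = AutoNormal(my_model)
    # initial_lr, lrd and N_STEPS are defined elsewhere
    optimizer = ClippedAdam({"lr": initial_lr, "lrd": lrd})
    svi = SVI(my_model, guide, optimizer, loss=Trace_ELBO())

    pyro.clear_param_store()
    for j in range(N_STEPS):
        # arguments to svi.step are passed through to my_model and guide
        loss = svi.step(inputs)

    # <build and return a DataFrame of estimates here; the caller combines runs with pd.concat>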

This function is then called with different inputs in a loop, with the intent that Dask computes the runs in parallel via dask.delayed:

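import dask
import pandas as pd

results = []

for inputs in input_list:
    # dask.delayed only builds a lazy task; nothing runs until dask.compute
    estimates = dask.delayed(run_model)(inputs)
    results.append(estimates)

# dask.compute returns a tuple with one result per delayed task
output_df = pd.concat(dask.compute(*results))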

However, this fails with the error below, even though run_model runs without error when called on its own (i.e. outside the Dask loop):

Msg 39019, Level 16, State 2, Line 1
An external script error occurred: 
  File "C:\Program Files\Microsoft SQL Server\MSSQL14.MSSQLSERVER\PYTHON_SERVICES.3.7\lib\site-packages\pyro\poutine\plate_messenger.py", line 21, in __enter__
    super().__enter__()
  File "C:\Program Files\Microsoft SQL Server\MSSQL14.MSSQLSERVER\PYTHON_SERVICES.3.7\lib\site-packages\pyro\poutine\indep_messenger.py", line 84, in __enter__
    self.dim = _DIM_ALLOCATOR.allocate(self.name, self.dim)
  File "C:\Program Files\Microsoft SQL Server\MSSQL14.MSSQLSERVER\PYTHON_SERVICES.3.7\lib\site-packages\pyro\poutine\runtime.py", line 35, in allocate
    raise ValueError('duplicate plate "{}"'.format(name))
ValueError: duplicate plate "levels"
Trace Shapes:      
 Param Sites:      
Sample Sites:      
      mu dist | 415
        value | 415
Trace Shapes:
 Param Sites:
Sample Sites:

Note that the error claims there are multiple plates named "levels", but there is only one in the model. This leads me to suspect that the runs are somehow not being executed independently.

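I see. My guess is that there are several pieces of global state in Pyro’s internals, such as the _DIM_ALLOCATOR in your traceback, that were not designed with Dask-based concurrency in mind, and you ran into an error with one of them. By default, dask.delayed tasks run on a thread pool inside a single process, so concurrent runs share that state. I would also not be surprised to see parameters from different runs overwriting each other in the global parameter store, or one run’s pyro.clear_param_store() call wiping out another run’s parameters.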
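To make the suspected failure mode concrete, here is a stdlib-only toy. It is not Pyro’s actual code: ToyDimAllocator is a hypothetical stand-in for a process-wide allocator like the _DIM_ALLOCATOR in the traceback, which refuses to hand out a plate name that is already in use. Two runs that enter and leave plate "levels" one after another are fine, but when two threads in the same process are inside that plate at the same time, the second one hits the same kind of duplicate plate error.

```python
import threading


class ToyDimAllocator:
    """Hypothetical stand-in for a process-wide plate allocator (not Pyro's code)."""

    def __init__(self):
        self._active = set()
        self._lock = threading.Lock()

    def allocate(self, name):
        with self._lock:
            if name in self._active:
                raise ValueError('duplicate plate "{}"'.format(name))
            self._active.add(name)

    def free(self, name):
        with self._lock:
            self._active.discard(name)


# A single module-level instance, shared by every thread in the process
_ALLOCATOR = ToyDimAllocator()


def demo_sequential_runs():
    """Two runs that enter and leave plate "levels" one after another."""
    errors = []
    for _ in range(2):
        try:
            _ALLOCATOR.allocate("levels")
            _ALLOCATOR.free("levels")
        except ValueError as exc:
            errors.append(str(exc))
    return errors


def demo_two_overlapping_runs():
    """Two threads that are inside plate "levels" at the same time."""
    errors = []
    first_inside = threading.Event()
    second_done = threading.Event()

    def first_run():
        _ALLOCATOR.allocate("levels")    # enter the plate
        first_inside.set()
        try:
            second_done.wait(timeout=5)  # stay inside until the second run has tried
        finally:
            _ALLOCATOR.free("levels")    # leave the plate

    def second_run():
        first_inside.wait(timeout=5)     # start only once the first run is inside
        try:
            _ALLOCATOR.allocate("levels")
            _ALLOCATOR.free("levels")
        except ValueError as exc:
            errors.append(str(exc))
        finally:
            second_done.set()

    threads = [threading.Thread(target=first_run), threading.Thread(target=second_run)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return errors


print(demo_sequential_runs())       # []
print(demo_two_overlapping_runs())  # ['duplicate plate "levels"']
```

The Events only force the overlap deterministically; with a real thread pool the collision depends on timing, which would also explain why single runs always succeed.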

I’m not familiar with Dask, and AFAIK none of the Pyro team are regular Dask users, so unfortunately I can’t immediately pin down the root cause or suggest a workaround other than trying a different library for parallelism. If you are OK with running on a single machine, I believe multiprocessing and torch.multiprocessing do not have this kind of issue, since each worker process gets its own copy of Pyro’s global state. I suspect Ray is also compatible, but I have not tried it.
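A minimal stdlib-only sketch of the process-based route follows, untested with Pyro itself. toy_fit and _GLOBAL_STATE are hypothetical stand-ins for run_model and Pyro’s global state; in practice you would pass a module-level version of run_model to pool.map. Because each worker is a separate process, one fit clearing its state cannot affect another.

```python
import multiprocessing as mp

# Hypothetical process-wide state, standing in for Pyro's global param store
_GLOBAL_STATE = {}


def toy_fit(batch):
    """Hypothetical stand-in for run_model: resets global state, then writes to it."""
    _GLOBAL_STATE.clear()  # analogous to calling pyro.clear_param_store()
    _GLOBAL_STATE["mean"] = sum(batch) / len(batch)
    return _GLOBAL_STATE["mean"]


def fit_all(batches, processes=2, start_method=None):
    """Run toy_fit on every batch in a pool of worker processes.

    Each worker is a separate process with its own copy of _GLOBAL_STATE,
    so concurrent fits cannot clear or overwrite each other's state.
    start_method=None uses the platform default ("spawn" on Windows).
    """
    ctx = mp.get_context(start_method)
    with ctx.Pool(processes=processes) as pool:
        # map blocks until every batch is done and returns results in input order
        return pool.map(toy_fit, batches)
```

On Windows, where workers are started with "spawn" and re-import your script, the fitting function must live at module level and the call must sit under `if __name__ == "__main__":`, e.g. `print(fit_all([[1, 2, 3], [4, 5, 6]]))`, which prints `[2.0, 5.0]`. torch.multiprocessing is a wrapper around multiprocessing with the same Pool API, if you later need to share tensors between processes.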

It would also be really helpful if you could open a Pyro GitHub issue with a small runnable script that reproduces your error, along with your machine configuration, so that we can look into what it would take to support Dask properly. It might be fairly easy to fix, or we could at least suggest a workaround that requires less effort on your part.