AFAIK the place to start would be `initialize_model` (in `pyro.infer.mcmc.util`), although @fehiepsi might have a better suggestion. Note that it might be a bit complicated to get all the details right, given the complexity of handling transforms (mapping between constrained and unconstrained space) and the like.
Also note that any Pyro implementation of UHA/DAIS is expected to be somewhat slow compared to the NumPyro one, mainly because of the overhead of the Python `for` loop (over the annealing steps), which JAX can compile away. PyTorch also isn't great at optimizing gradient computations. This is one of the reasons we implemented these methods in NumPyro.