`scan` does not work when using `HMCECS`

Hi! It seems that `scan` does not support subsampling, even though I set it up by sizing the `init` argument of `scan` to the subsample size. As a result, HMCECS does not recognize the subsampling, and I have to fall back to the sequential (for-loop) form of my model to make it work properly with HMCECS. Is there any way to make subsampling work inside `scan`? I'd really appreciate it!

Both the scan primitive and HMCECS are complicated to implement, so I think making them compatible is tricky. Can you rewrite your model to use JAX's scan instead?
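For reference, a sequential Python loop and `jax.lax.scan` compute the same recurrence; the loop is executed step by step (and, under `jax.jit`, unrolled into one large graph), while scan compiles the step function once. A minimal sketch, assuming a hypothetical AR(1)-style log-likelihood; `loglik_loop`, `loglik_scan`, and the data are illustrative only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def loglik_loop(y, phi, sigma):
    # Sequential form: a Python for loop over time steps.
    total = 0.0
    for t in range(1, y.shape[0]):
        total = total + norm.logpdf(y[t], loc=phi * y[t - 1], scale=sigma)
    return total


def loglik_scan(y, phi, sigma):
    # Same recurrence with jax.lax.scan: the carry is the previous value,
    # and each step emits one log-density term as its stacked output.
    def step(prev, y_t):
        return y_t, norm.logpdf(y_t, loc=phi * prev, scale=sigma)

    _, terms = jax.lax.scan(step, y[0], y[1:])
    return terms.sum()


# Hypothetical observed series.
y = jnp.array([0.0, 0.5, 0.2, -0.1, 0.3])
print(loglik_loop(y, 0.6, 0.5), jax.jit(loglik_scan)(y, 0.6, 0.5))
```

Keeping the carry to just the previous value and returning the per-step terms keeps the step function small; the terms are summed once after the scan.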

Sure! I can get the `scan` form of my model working without subsampling, and also on a small subsampled dataset. However, my full dataset is large, so I have to consider using HMCECS to scale my model to it.

Which scan operator are you referring to in the title? HMCECS does not work with NumPyro's scan, but I think it will work with JAX's scan.

Sorry, I misunderstood your last reply. Actually, when I used a for loop in my model, HMCECS still worked well and the sampling speed was acceptable. However, I just visualized the sampled results and they do not seem to have converged. I didn't use the Taylor proxy because I'm not familiar with it. In my experience, sampling tends to speed up as the chain gets close to convergence, but in my model the sampling rate stays slow, about 10 s/it, until sampling ends. I set warmup=12000 and sample=3000. Do you think more samples are needed for a complex model?

The scan operator I mentioned is `numpyro.contrib.control_flow.scan`.

I see. If you rewrite your model/likelihood using `jax.lax.scan` (with `numpyro.factor` statements), then I think HMCECS will work.

Re sampling behavior: it's a research topic, I guess. It's better to reach out to the HMCECS authors for recommendations.