Using TransformReparam with to_event

I am trying to use TransformReparam in order to develop a non-centered version of my model. The relevant code is:

with numpyro.handlers.reparam(config={"user_prefs": TransformReparam()}):
    user_prefs = numpyro.sample(
        "user_prefs",
        dist.TransformedDistribution(
            dist.Normal(jnp.zeros((n_users, n_categories)), 1.0),
            transforms=[
                dist.transforms.AffineTransform(user_mus, user_sigmas),
                dist.transforms.ExpTransform(),
            ],
        ).to_event(1),
    )  

and I got the error:

  File "/Users/fwilhelm/.miniconda3/envs/cf-model/lib/python3.8/site-packages/numpyro/handlers.py", line 497, in process_message
    new_fn, value = reparam(msg["name"], msg["fn"], msg["value"])
  File "/Users/fwilhelm/.miniconda3/envs/cf-model/lib/python3.8/site-packages/numpyro/infer/reparam.py", line 131, in __call__
    assert isinstance(fn, dist.TransformedDistribution)
AssertionError

which is due to the fact that to_event converts the TransformedDistribution into a dist.Independent. The thing is that I need to call to_event on the transformed distribution because of the plate I am using. Reformulating everything without the numpyro.handlers.reparam context works. Am I doing something wrong here, or is this a conceptual error? Could I somehow apply to_event right after the context is left?
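
For illustration, here is a minimal sketch (made-up shapes) of the conversion that trips the assertion:

import jax.numpy as jnp
import numpyro.distributions as dist

d = dist.TransformedDistribution(
    dist.Normal(jnp.zeros((2, 3)), 1.0),
    [dist.transforms.ExpTransform()],
).to_event(1)
# .to_event() wraps the result in an Independent, so the isinstance check
# inside TransformReparam fails with the AssertionError above
print(type(d))                                      # Independent
print(isinstance(d, dist.TransformedDistribution))  # False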

Hi @FlorianWilhelm,
That looks like a bug in TransformReparam; I’ll try to reproduce and fix it. You may be able to work around it by moving the .to_event() onto your Normal base distribution:

with numpyro.handlers.reparam(config={"user_prefs": TransformReparam()}):
    user_prefs = numpyro.sample(
        "user_prefs",
        dist.TransformedDistribution(
            dist.Normal(jnp.zeros((n_users, n_categories)), 1.0).to_event(1),
            transforms=[
                dist.transforms.AffineTransform(user_mus, user_sigmas),
                dist.transforms.ExpTransform(),
            ],
        ),
    )  

Cheers.

Thanks, @fritzo, for the fast response. Is it in general equivalent to put to_event on the inner distribution rather than on the outer TransformedDistribution? Especially when using discrete latent variables, I get a lot of shape errors like:

ValueError: Incompatible shapes for broadcasting

with to_event(1) on the inner distribution, which are resolved by moving it to the outside.

Hi @FlorianWilhelm,
In some cases it is equivalent to move .to_event() inside the TransformedDistribution, and I believe it is correct in the example from your question. In particular, it is ok if all your transforms are scalar transforms like ExpTransform() and not batched like AffineTransform(an_array_with_bigger_shape_than_the_base_dist, 1.).
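
You can sanity-check the equivalence directly; a quick sketch with made-up shapes:

import jax.numpy as jnp
from jax import random
import numpyro.distributions as dist

base = dist.Normal(jnp.zeros((5, 3)), 1.0)
exp_t = dist.transforms.ExpTransform()

inner = dist.TransformedDistribution(base.to_event(1), [exp_t])
outer = dist.TransformedDistribution(base, [exp_t]).to_event(1)

x = inner.sample(random.PRNGKey(0))
# for scalar transforms both log densities agree, shape (5,)
print(inner.log_prob(x))
print(outer.log_prob(x))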

BTW here are the PRs fixing the bug you found. Thanks for reporting.


Hi @fritzo,

thanks! Wow, that was very, very fast. Do you happen to know when the next release including this fix will be?

Regarding the position of to_event, I played around a bit more with your suggestion of moving it to the inner distribution. Using a centered version works for me:

user_prefs = numpyro.sample(
    Site.user_prefs,
    dist.TransformedDistribution(
        dist.Normal(loc=user_mus, scale=user_sigmas).to_event(1),
        transforms=[
            dist.transforms.ExpTransform(),
        ],
    ),
)

Now replacing this code with:

with numpyro.handlers.reparam(config={Site.user_prefs: TransformReparam()}):
    user_prefs = numpyro.sample(
        Site.user_prefs,
        dist.TransformedDistribution(
            dist.Normal(jnp.zeros((n_users, n_categories)), 1.0).to_event(1),
            transforms=[
                dist.transforms.AffineTransform(user_mus, user_sigmas),
                dist.transforms.ExpTransform(),
            ],
        ),
    )  # dim: (n_users | n_categories)

results in the incompatible-shape ValueError, but later in my code. Judging from the shape, it only happens in the part of the compilation step where discrete latent variables are added as additional shape dimensions (at least that’s my interpretation from trying to debug it). This is why I think that when latent discrete variables are marginalized out, there might be a difference in where to_event is applied.

I would also like to add that what I said before applies to running MCMC:

mcmc = MCMC(NUTS(model), num_warmup=0, num_samples=10)
rng_key = random.PRNGKey(0)
mcmc.run(...)

When I do SVI, even moving to_event in the centered version (first code block) from outer to inner results in an incompatible-shape error. I would really love to understand this better. Why does it make a difference here whether I am doing SVI or MCMC? Thanks for any pointers!

Hi @FlorianWilhelm, could you post repro code on GitHub so we can test it? Currently, NumPyro does not support enumeration with SVI, so I guess you didn’t use enumeration in SVI, right? For MCMC, it looks like n_users is a plate dimension? For models with discrete latent variables, you would need to declare all plate dimensions.

FYI, we intend to release numpyro tomorrow. :slight_smile: I hope to fix your issue before that.

Hi @fehiepsi, thank you so much for looking into this! I granted you and @fritzo access to the repo https://github.com/FlorianWilhelm/cf-model/ The code there is GPL, and my plan is to make the repo public later (if my idea works, I want to publish something first), so it’s currently private. Let me know if I should give others access to the codebase.

Currently, to debug I added the scripts:

  • scripts/run_handler.py, which uses SVI
  • scripts/run_mcmc.py, which uses MCMC

The main model code resides in cf_model/bayes/model.py. If you have any comments or feedback on my code or in general about my Bayesian modelling skills, don’t hesitate to tell me :slight_smile: I would be very grateful to learn from an expert like you, as I am still quite new to NumPyro and maybe there are some quite stupid errors in my code.

Hi @FlorianWilhelm, I took a look at the model and got a few comments:

  • It is better to explicitly specify dim in squeeze(...). For marginalization, we usually add singleton dimensions to the priors for broadcasting, so using squeeze() without a dim can mess up those dimensions.
  • Enumeration does not work in some situations like this one. In your model, e.g., the observation lies in a different plate from the discrete site Site.item_cat_idx but depends on that discrete site.

I think you can use DiscreteHMCGibbs for your model to avoid shape issues. :slight_smile:
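
Roughly like this (a toy sketch, not your actual model):

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, DiscreteHMCGibbs

def toy_model(data, K=3):
    # continuous sites are updated by NUTS...
    mu = numpyro.sample("mu", dist.Normal(jnp.zeros(K), 1.0).to_event(1))
    with numpyro.plate("users", data.shape[0]):
        # ...while discrete sites are updated by Gibbs steps, so no
        # enumeration (and none of its restrictions) is involved
        z = numpyro.sample("z", dist.Categorical(probs=jnp.ones(K) / K))
        numpyro.sample("obs", dist.Normal(mu[z], 1.0), obs=data)

mcmc = MCMC(DiscreteHMCGibbs(NUTS(toy_model)), num_warmup=200, num_samples=200)
mcmc.run(random.PRNGKey(0), jnp.zeros(10))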

The best way for me to decide whether enumeration will work in a model is to draw a plate diagram of that model. If there is a link from a discrete variable that goes outside the plates containing that variable, then enumeration won’t work. Looking at many plate diagrams here, you will see that there is no such link, hence enumeration works for those models. Could you post the plate diagram of your model somewhere?

Hi @fehiepsi, thanks a lot for your feedback!

I started by converting my hand-written model notes and drawings to Markdown/LaTeX and Matplotlib in a Jupyter notebook. You can find the generative process and the plate diagram here. Also, thanks for pointing out the paper Comparing Bayesian Models of Annotation; it was really helpful for me.

Also, thanks for referring me to the various restrictions of enumeration. To my understanding, “Restriction 2: no downstream coupling” applies in my case, and thus the code should not run, right? But it does run now, using SVI and a centered version that avoids reparam. Does that mean the model trains but the results will just be wrong? I have had no time to check them thoroughly, but they don’t look so good anyway.

I haven’t used DiscreteHMCGibbs yet, but would you recommend MCMC in my use case? Since my data will be quite big in the end, I need mini-batches, and thus SVI was much more appealing to me than an MCMC approach.

My model is basically a hierarchical Gaussian-like mixture model. The way I understand enumeration in NumPyro is that the ELBO is calculated for each possible category (thus generating one additional dimension per enumerated variable). In the end, this additional dimension is collapsed by calculating a kind of expected ELBO using the probabilities for each category. Is my understanding correct? And if so, could I just do this explicitly by using the categorical probabilities directly for weighting, instead of generating a discrete latent variable with the Categorical distribution?
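
Something like this hand-rolled marginalization is what I have in mind (a sketch with made-up names, for a simple 1-D mixture):

import jax.numpy as jnp
from jax.scipy.special import logsumexp
import numpyro
import numpyro.distributions as dist

def model(x, probs, mus, sigma=1.0):
    # log p(x) = logsumexp_k( log probs[k] + log Normal(x | mus[k], sigma) ),
    # i.e. weight each component by its categorical probability instead of
    # sampling a discrete z ~ Categorical(probs)
    log_comp = dist.Normal(mus, sigma).log_prob(x[..., None])  # shape (..., K)
    log_mix = logsumexp(jnp.log(probs) + log_comp, axis=-1)
    numpyro.factor("obs", log_mix.sum())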

Sorry for asking so many questions, but as wonderful as NumPyro is, it is sometimes also quite difficult :slight_smile:

Best regards, Florian

Yes, the arrow from the discrete latent variable phi going to r moves outside of plate U, so enumeration won’t work. If U is small, you can use a for loop instead of the plate U, like in this model. For example:
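
(A toy sketch with made-up names and shapes:)

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(data, K=3):
    mu = numpyro.sample("mu", dist.Normal(jnp.zeros(K), 1.0).to_event(1))
    # a Python loop instead of plate U: each z_u becomes its own site, so
    # enumeration adds one dimension per user (feasible only for small U)
    for u in range(data.shape[0]):
        z_u = numpyro.sample(f"z_{u}", dist.Categorical(probs=jnp.ones(K) / K))
        numpyro.sample(f"obs_{u}", dist.Normal(mu[z_u], 1.0), obs=data[u])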

The reason why enumeration won’t work here is: for each user there are K possible types, so the number of possible values of the U-length vector x is K^U. If the number of users is 20 and the number of types is 10, then there are 10^20 possible values of x, which is quite large.

If x were not used for anything else (which is not true in your model, because r uses it), then for each user we could enumerate over all K possible type values to calculate some of his/her statistics, e.g. the log probability. The size of those statistics is the number of users, 20, which is pretty small compared to 10^20. This is also in line with your explanation of enumeration.
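
In numbers, with the figures above:

K, U = 10, 20
print(U * K)   # 200 per-user statistics, cheap
print(K ** U)  # 100000000000000000000 joint assignments of x, intractable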

NumPyro does not support enumeration in SVI yet (see this issue). Your code runs because it does not use enumeration: the ELBO formula used for discrete variables is just the same as for continuous variables.

it is sometimes also quite difficult

Could you be more explicit about which difficulties you have? Doing Bayesian inference is sometimes difficult, and we hope that Pyro/NumPyro will make some of the difficult stuff easier. Your feedback would be very valuable!

The problem with your model is that, in principle, we can’t enumerate it, because the number of possible values to enumerate would be very high (this is independent of the PPL that you use).

recommend MCMC in my use case

I’m not sure. You can try HMC within Gibbs to see if it is fast enough for your problem and gives any meaningful results. There is HMCECS, which supports subsampling in HMC but does not support discrete latent variables yet (I guess combining them would not be so complicated, so please make a feature request for it if you want). Using HMC for models with discrete latent variables and subsampling are research problems, and I’m not sure whether they will work for your problem.

If you use SVI, I would recommend using Pyro, which has some helpful categorical distributions and better support for gradient estimation and for reducing the variance of the ELBO.

Okay, now I fully understand restriction 2 of enumeration; thanks for explaining!

Regarding the point that NumPyro does not support enumeration in SVI: sorry for my mistake here; I was mixing up the Pyro and NumPyro documentation and was under the assumption that my code uses enumeration automatically since I use several Categorical distributions. This brings me to one more question:

What does SVI do in the case of categorical distributions in NumPyro? There is no corresponding sample site in the guide, so I assume they are still marginalized out, albeit without enumeration. My take from the last link in your post is that some surrogate objective is used. I assume this could lead to unintended results, hence your pointer towards using Pyro instead of NumPyro for my use case.

I didn’t want to sound impolite by saying that

[NumPyro] is sometimes also quite difficult

As you say, Bayesian inference is quite hard in itself, and as awesome as NumPyro is, it cannot abstract away the whole complexity that comes with it. But I think that in general NumPyro is doing a great job at that, and I deeply appreciate all the hard work you and the other contributors put into NumPyro. Thanks for that! I will try to give as much constructive feedback as possible so that the perspective of a NumPyro user is also known for future development decisions.


Hi @FlorianWilhelm, unfortunately the surrogate objective is not implemented yet, so the gradient of the ELBO will be off for non-reparameterized distributions. I agree that it could lead to unexpected results. We should add a warning in the SVI docs.

it cannot abstract away the whole complexity that comes with it

Agreed! I really hope that the Pyro community can iteratively improve the framework through feedback, documentation, examples, tutorials, blog posts, etc., and share (practical) experience with other users. Thanks for your words!

By the way, you are so fast in understanding that enumeration restriction. It literally took me years to figure it out. Haha


@FlorianWilhelm It looks like it is not so complicated to implement a surrogate loss in NumPyro. Could you make a FR for it (and maybe also for a combination of HMCECS + DiscreteHMCGibbs) if you think it is necessary for your model?

Thanks, @fehiepsi, that would be cool! I created an [FR] issue.

I haven’t looked into HMCECS + DiscreteHMCGibbs yet. From your earlier comment, I understood that HMCECS allows me to do subsampling. The docs of HMCECS say that this needs to be done with a plate directive. In my model, the negative samples, i.e. non-interactions, are drawn anew in each epoch, so my batching logic is outside the model. To make up for this I use:

 with numpyro.handlers.scale(scale=n_interactions / n_batch):
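
i.e., roughly this pattern (a simplified sketch with toy names):

import numpyro
import numpyro.distributions as dist

def model(batch, n_interactions):
    n_batch = batch.shape[0]
    theta = numpyro.sample("theta", dist.Normal(0.0, 1.0))
    # the batch is drawn outside the model, so the likelihood is rescaled
    # to account for the full data set
    with numpyro.handlers.scale(scale=n_interactions / n_batch):
        with numpyro.plate("batch", n_batch):
            numpyro.sample("obs", dist.Normal(theta, 1.0), obs=batch)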

I am thus not sure that my use case would work with HMCECS.

Do you think the combination of HMCECS+DiscreteHMCGibbs would work? You said that

Using HMC for models with discrete latent variables and subsampling are research problems and I’m not sure if they will work for your problem.

So DiscreteHMCGibbs would take care of my discrete latent variables, and HMCECS would take care of the subsampling? But since you said that HMCECS does not allow discrete variables, how would I combine these? Sorry, I am a bit confused about that point.

In general, I guess my model will have quite a lot of latent variables. If the small data set works, the numbers of users and items could easily be in the millions, each with many latent variables. Thus I assumed that SVI would be the only way to handle something that big.

Thanks, @FlorianWilhelm!

Re DiscreteHMCECS (let’s just call it that :smiley: ): Currently, in an MCMC step of DiscreteHMCGibbs, we update the discrete latent variables first, then run HMC. In an MCMC step of HMCECS, we update the subsample indices first, then run HMC. But we haven’t combined them. We can implement DiscreteHMCECS by either

  • updating discrete -> updating subsample -> running HMC
  • or updating subsample -> updating discrete -> running HMC

I am not sure which order makes more sense (ideally we should refactor the implementation so that users can freely compose kernels, like HMCECS(DiscreteHMCGibbs(model)) or DiscreteHMCGibbs(HMCECS(model)), but for now I’m not sure how complicated the implementation would be), but both are pretty straightforward to implement.

this needs to be done with a plate directive

Good point! For it, we require users to provide the full data and subsample it under the plate:

with plate("data", size, subsample_size) as subsample_idx:
    data = data[subsample_idx]
    ...

and behind the scenes, we will resample a new subsample_idx in each MCMC step.
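
Putting it together, a minimal runnable sketch (toy model and sizes):

import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, HMCECS

def model(data):
    theta = numpyro.sample("theta", dist.Normal(0.0, 1.0))
    # subsample_size controls the minibatch drawn in each MCMC step
    with numpyro.plate("data", data.shape[0], subsample_size=100) as idx:
        numpyro.sample("obs", dist.Normal(theta, 1.0), obs=data[idx])

data = random.normal(random.PRNGKey(1), (10000,))
mcmc = MCMC(HMCECS(NUTS(model)), num_warmup=200, num_samples=200)
mcmc.run(random.PRNGKey(0), data)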

If the small data set works, the numbers of users and items could easily be in the millions, each with many latent variables. Thus I assumed that SVI would be the only way to handle something that big.

I think so too. I don’t have experience with this kind of complicated model, but if I were facing it, I would try SVI first. I’m curious which methods people use in those situations. You could create a new forum topic to discuss practical tips for that, I guess. :slight_smile:
