Scan + Condition + TransformReparam

Following the sgt example I set my observed within the scan using handlers.condition, this works well for a standard distribution but fails for a transformed distribution. Is it possible with NUTS to set the sample for an observed, transformed distribution within a scan?

import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.infer import NUTS, MCMC, init_to_median, Predictive
from numpyro.distributions.transforms import AffineTransform
from numpyro.infer.reparam import TransformReparam
from jax import random
import arviz as az
import jax.numpy as jnp
import jax
import pandas as pd
numpyro.set_host_device_count(4)
np.set_printoptions(suppress=True)

sample_params = {
    'num_chains':4, 
    'num_warmup':1000, 
    'num_samples':1000, 
    'progress_bar':True,
}

#generate AR1 process
N, K = 200, 5
scale = 2.
z = np.random.RandomState(8).normal(scale=scale, size=(N, K))
theta = -0.2
y = np.zeros_like(z)
y[0, :] = z[0, :]
y[1:, :] = z[:-1, :]*theta + z[1:, :]
y = jnp.array(y, dtype=jnp.float32)

def model1(N, K, N_future, y):
    scale = numpyro.sample('scale', dist.Exponential(1))
    theta = numpyro.sample('theta', dist.Normal(0, 1))
    
    def transition(y_prev, t):
        loc = y_prev*theta
        y_ = numpyro.sample('y', dist.Normal(loc=loc, scale=scale), )
        return y_, y_

    init_0 = y[0]
    with numpyro.handlers.condition(data={'y': y}):
        _, ys = scan(transition, init_0, jnp.arange(N + N_future), )
    if N_future > 0:
        numpyro.deterministic("y_forecast", ys[-N_future:])

hmc = MCMC(NUTS(model1, target_accept_prob=0.9, init_strategy=init_to_median), **sample_params)
hmc.run(random.PRNGKey(0), extra_fields=("z", "energy", "diverging"), N=N, K=K, y=y, N_future=0)
idata = az.from_numpyro(hmc, )
az.summary(idata)

def model2(N, K, N_future, y):
    scale = numpyro.sample('scale', dist.Exponential(1))
    theta = numpyro.sample('theta', dist.Normal(0, 1))
    
    def transition(y_prev, t):
        loc = y_prev*theta
        y_ = numpyro.sample('y', 
                            dist.TransformedDistribution(
                                dist.Normal(loc=loc, scale=1.), 
                                AffineTransform(0, scale)
                            )
                           )
        return y_, y_

    init_0 = y[0]
    with numpyro.handlers.condition(data={'y': y}):
        _, ys = scan(transition, init_0, jnp.arange(N + N_future), )
    if N_future > 0:
        numpyro.deterministic("y_forecast", ys[-N_future:])

hmc = MCMC(NUTS(model2, target_accept_prob=0.9, init_strategy=init_to_median), **sample_params)
hmc.run(random.PRNGKey(0), extra_fields=("z", "energy", "diverging"), N=N, K=K, y=y, N_future=0)
idata = az.from_numpyro(hmc, )
az.summary(idata)


def model3(N, K, N_future, y):
    scale = numpyro.sample('scale', dist.Exponential(1))
    theta = numpyro.sample('theta', dist.Normal(0, 1))
    
    def transition(y_prev, t):
        loc = y_prev*theta
        with numpyro.handlers.reparam(config={"y": TransformReparam()}):
            y_ = numpyro.sample('y', 
                                dist.TransformedDistribution(
                                    dist.Normal(loc=loc, scale=1.), 
                                    AffineTransform(0, scale)
                                )
                               )
        return y_, y_

    init_0 = y[0]
    with numpyro.handlers.condition(data={'y': y}):
        _, ys = scan(transition, init_0, jnp.arange(N + N_future), )
    if N_future > 0:
        numpyro.deterministic("y_forecast", ys[-N_future:])

hmc = MCMC(NUTS(model3, target_accept_prob=0.9, init_strategy=init_to_median), **sample_params)
hmc.run(random.PRNGKey(0), extra_fields=("z", "energy", "diverging"), N=N, K=K, y=y, N_future=0)
idata = az.from_numpyro(hmc, )
az.summary(idata)
  • model1() runs fine and and recovers the true parameters in my minimal working example.
  • model2() fails with error:
    NotImplementedError: Flatenning TransformedDistribution is only supported for some specific cases. Consider using TransformReparam to convert this distribution to the base_dist, which is supported in most situtations. In addition, please reach out to us with your usage cases.
  • model3() runs without error but the condition handler doesn’t work to set the observed for y

Here I am using AffineTransform which I could easily do without, as in model1(), but it’s just to display the error.

We currently do not have sort of “jit” rule for bijector transform, so we can’t use transform distribution in scan without reparam handler. Please make a feature request for it if you think it is important to have. For model3, it is a pending issue. As a temporary solution, you can condition on y_base using data={'y_base': AffineTransform(0, scale).inv(y)}.

Thanks for the quick reply, however something isn’t quite working

It runs without error and converges but it doesn’t recover the good values

def model3(N, K, N_future, y):
    scale = numpyro.sample('scale', dist.Exponential(1))
    theta = numpyro.sample('theta', dist.Normal(0, 1))
    
    def transition(y_prev, t):
        loc = y_prev*theta
        with numpyro.handlers.reparam(config={"y": TransformReparam()}):
            y_ = numpyro.sample('y', 
                                dist.TransformedDistribution(
                                    dist.Normal(loc=loc, scale=1.), 
                                    AffineTransform(0, scale)
                                )
                               )
        return y_, y_

    init_0 = y[0]
    with numpyro.handlers.condition(data={'y_base': AffineTransform(0, scale).inv(y)}):
        _, ys = scan(transition, init_0, jnp.arange(N + N_future), )
    if N_future > 0:
        numpyro.deterministic("y_forecast", ys[-N_future:])

hmc = MCMC(NUTS(model3, target_accept_prob=0.9, init_strategy=init_to_median), **sample_params)
hmc.run(random.PRNGKey(0), extra_fields=("z", "energy", "diverging"), N=N, K=K, y=y, N_future=0)
idata = az.from_numpyro(hmc, )
az.summary(idata, var_names=['scale', 'theta'])

the results are:
image

Taking a simpler example to see if conditioning on the inverse transform works and it seems to not.

X = np.random.RandomState(8).normal(loc=0, scale=8, size=200)
def m(X):
    loc = numpyro.sample('loc', dist.Normal(0, 1))
    scale = numpyro.sample('scale', dist.Exponential(1))
    with numpyro.handlers.condition(data={'X': AffineTransform(loc, scale).inv(X)}):
        obs = numpyro.sample('X', dist.Normal(0, 1), )

hmc = MCMC(NUTS(m, ), **sample_params)
hmc.run(random.PRNGKey(0), extra_fields=("z", "energy", "diverging"), X=X)
idata = az.from_numpyro(hmc, )
az.summary(idata, var_names=['loc', 'scale'])

It fails to recover the true scale of 8, though it does converge and also move away from the prior
image

Oops, sorry, I guess we need an additional jacobian term?

    X_base = AffineTransform(loc, scale).inv(X)
    with numpyro.handlers.condition(data={'X': X_base}):
        obs = numpyro.sample('X', dist.Normal(0, 1), )
        sign = 1  # or -1, I'm not sure for now
        numpyro.factor('logdet', sign * AffineTransform(loc, scale).log_det_abs_jacobian(X_base, X))
1 Like

You are a maestro, thanks very much. Sign should be -1.

Let me try applying on my real model now.

I have a follow up question related to the above

Here is the new model code

import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan
from numpyro.infer import NUTS, MCMC, init_to_median, Predictive
from numpyro.distributions.transforms import AffineTransform
from numpyro.infer.reparam import TransformReparam
from jax import random
import arviz as az
import jax.numpy as jnp
import jax
import pandas as pd
numpyro.set_host_device_count(4)
np.set_printoptions(suppress=True)


sample_params = {
    'num_chains':4, 
    'num_warmup':1000, 
    'num_samples':1000, 
    'progress_bar':False,
}

#generate AR1 process
N = 200
scale = 0.64
z = np.random.RandomState(8).normal(scale=scale, size=N)
theta = -0.1
y = np.zeros_like(z)
y[0] = z[0]
y[1:] = z[:-1]*theta + z[1:]
y = jnp.array(y, dtype=jnp.float32)


def model4(N, N_future, y):
    scale0 = numpyro.sample('scale0', dist.Exponential(1))
    theta = numpyro.sample('theta', dist.Normal(0, 1))
    beta = numpyro.sample('beta', dist.Normal(0, 1))
    
    def transition(carry, t):
        scale_prev, y_prev = carry
        loc = y_prev*theta
        scale_curr = jnp.exp(jnp.log(scale_prev) + beta*y_prev)
        
        with numpyro.handlers.reparam(config={"y": TransformReparam()}):
            y_ = numpyro.sample('y', 
                                dist.TransformedDistribution(
                                    dist.Normal(loc=loc, scale=1.), 
                                    AffineTransform(0, scale_curr)
                                )
                               )
        return (scale_curr, y_), scale_curr

    init_0 = scale0, y[0]
    transform = AffineTransform(0, scale0)
    y_base = transform.inv(y)
    with numpyro.handlers.condition(data={'y_base': y_base}):
        _, scale_t = scan(transition, init_0, jnp.arange(N + N_future), )
    numpyro.factor("logdet", -1 * transform.log_abs_det_jacobian(y_base, y))
    
    numpyro.deterministic('scale_t', scale_t)
    
    if N_future > 0:
        numpyro.deterministic("y_forecast", ys[-N_future:])
        
        
        
hmc = MCMC(NUTS(model4, target_accept_prob=0.95, init_strategy=init_to_median), **sample_params)
hmc.run(random.PRNGKey(0), extra_fields=("z", "energy", "diverging"), N=N, y=y, N_future=0)
idata = az.from_numpyro(hmc, )
az.summary(idata, var_names=['scale0', 'theta', 'beta'])

The above runs however it is wrong. Within the scan the AffineTransform is using scale_curr whereas outside it is using scale0. The scale_curr is collected into scale_t variable however I need the transform before the scan not after.

Is there a way around this restriction? I guess it would require #2878 which I see is still open.

Using reparam, in this case, seems to be tricky. How about using TFP’s TransformedDistribution without reparam instead? I believe TFP bijectors are jittable.

Thanks for checking, though not sure I follow what to do exactly. I tried the below but it failed

from tensorflow_probability.substrates.jax import bijectors as tfb
import numpyro.contrib.tfp.distributions as tfd

def model5(N, N_future, y):
    scale0 = numpyro.sample('scale0', dist.Exponential(1))
    theta = numpyro.sample('theta', dist.Normal(0, 1))
    beta = numpyro.sample('beta', dist.Normal(0, 1))
    
    def transition(carry, t):
        scale_prev, y_prev = carry
        loc = y_prev*theta
        scale_curr = jnp.exp(jnp.log(scale_prev) + beta*y_prev)
        
        y_ = numpyro.sample('y', 
                dist.TransformedDistribution(
                    dist.Normal(loc=loc, scale=1.), 
                    tfd.BijectorTransform(tfb.Chain([
                        tfb.Shift(0.), 
                        tfb.Scale(scale_curr)
                    ]))
                )
               )
        return (scale_curr, y_), scale_curr

    init_0 = scale0, y[0]
    with numpyro.handlers.condition(data={'y': y}):
        _, scale_t = scan(transition, init_0, jnp.arange(N + N_future), )
    
    numpyro.deterministic('scale_t', scale_t)
    
    if N_future > 0:
        numpyro.deterministic("y_forecast", ys[-N_future:])
hmc = MCMC(NUTS(model5, target_accept_prob=0.9, init_strategy=init_to_median), **sample_params)
hmc.run(random.PRNGKey(0), extra_fields=("z", "energy", "diverging"), N=N, y=y, N_future=0)
idata = az.from_numpyro(hmc, )
az.summary(idata, var_names=['scale0', 'theta', 'beta'])

error message is:

NotImplementedError: Flatenning TransformedDistribution is only supported for some specific cases. Consider using TransformReparam to convert this distribution to the base_dist, which is supported in most situtations. In addition, please reach out to us with your usage cases.

which is the same error message I get if I use numpyro transforms rather than TFP bijectors

I also tried the below

from tensorflow_probability.substrates.jax import distributions as tfp_dist
def model6(N, N_future, y):
    scale0 = numpyro.sample('scale0', dist.Exponential(1))
    theta = numpyro.sample('theta', dist.Normal(0, 1))
    beta = numpyro.sample('beta', dist.Normal(0, 1))
    
    def transition(carry, t):
        scale_prev, y_prev = carry
        loc = y_prev*theta
        scale_curr = jnp.exp(jnp.log(scale_prev) + beta*y_prev)
        
        y_ = numpyro.sample('y', 
                tfp_dist.TransformedDistribution(
                    tfp_dist.Normal(loc=loc, scale=1.), 
                    tfd.BijectorTransform(tfb.Chain([
                        tfb.Shift(0.), 
                        tfb.Scale(scale_curr)
                    ]))
                )
               )
        return (scale_curr, y_), scale_curr

    init_0 = scale0, y[0]
    with numpyro.handlers.condition(data={'y': y}):
        _, scale_t = scan(transition, init_0, jnp.arange(N + N_future), )
    
    numpyro.deterministic('scale_t', scale_t)
    
    if N_future > 0:
        numpyro.deterministic("y_forecast", ys[-N_future:])
hmc = MCMC(NUTS(model6, target_accept_prob=0.9, init_strategy=init_to_median), **sample_params)
hmc.run(random.PRNGKey(0), extra_fields=("z", "energy", "diverging"), N=N, y=y, N_future=0)
idata = az.from_numpyro(hmc, )
az.summary(idata, var_names=['scale0', 'theta', 'beta'])

but this failed with message:
'BijectorTransform' object has no attribute 'name'