Sampling over possible data splits for change point detection

I’m writing an inference model for change point detection, so I need to randomly split the data at an unknown index T:

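```python
import numpy
import numpyro
import numpyro.distributions as dist

# Synthetic data: the mean shifts from 1 to 2 at index 1000
goals_0 = numpy.random.normal(1, 1, size=(1000,))
goals_1 = numpy.random.normal(2, 1, size=(1000,))
goals = numpy.concatenate((goals_0, goals_1))


def model(vars_dict, goals, validate_args):
    # Change point T, sampled as a continuous value over the index range
    T = numpyro.sample("T", dist.Uniform(0, len(goals), validate_args=validate_args))
    num_points = len(goals)
    with numpyro.plate('num_points', num_points) as index:
        # Python loops and `if` tests on T: T is a JAX tracer during inference
        for i in range(len(index)):
            if index[i] < T and index[i] + 1 > T:
                # model0 is my per-segment sub-model (not shown here)
                label_t0 = 't0_%d' % index[i]
                goals_t0 = goals[:index[i]]
                model0(label_t0, vars_dict, goals_t0, validate_args)

                label_t1 = 't1_%d' % index[i]
                goals_t1 = goals[index[i]:]
                model0(label_t1, vars_dict, goals_t1, validate_args)
```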


This gives me a plethora of errors.

What’s the easiest way to sample over splits of the data?

Have you tried `poutine.mask` (`numpyro.handlers.mask` in NumPyro)? You can generate a mask for one part of the data and then invert it for the other, building the mask from the realisation of T.

Hi @pavleb, thanks for the response. I guess that’s really the heart of my question: how do you use the realisation of T to set the mask?

```python
T = numpyro.sample("T", dist.Uniform(0, len(goals) - 1, validate_args=validate_args))
num_points = len(goals)
goals_mask = numpy.ones(num_points, dtype=numpy.bool_)
goals_mask[:T] = 0
```

This does not work because T is an abstract tracer in NumPyro. Any help is appreciated.

You are right, but when you run inference you pass in a random key from JAX, and that is what makes `numpyro.sample` produce actual values. Here is a small example:

```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO


def model(vars_dict, goals, validate_args):
    T = numpyro.sample("T", dist.Uniform(0, len(goals)))
    # Concrete value during SVI initialisation (an eager call);
    # inside the jit-compiled update step, T is a tracer instead.
    print(T)


def guide(vars_dict, goals, validate_args):
    # Empty guide: enough to run the model and inspect T
    pass


data = jnp.concatenate([jnp.ones(6), jnp.zeros(4)])
optimizer = numpyro.optim.Adam(step_size=0.0005)
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
# Arguments after num_steps (here 1) are passed on to model and guide
svi_result = svi.run(random.PRNGKey(0), 1, {}, data, None)
```

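This prints T, which shows that the sampled value is available when you build your mask. During SVI initialisation it is a concrete number, but inside the jit-compiled update step it is a tracer, so the mask has to be built with traceable operations rather than Python slicing.
In particular, avoid plain `numpy` inside the model and use `jax.numpy` instead. Check the sharp bits here.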
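A small sketch of the failing `goals_mask[:T] = 0` rewritten with `jax.numpy`, assuming T is a non-negative float and the number of points is known at trace time (`goals_mask_from_T` is a hypothetical helper name): instead of assigning into a slice, compare `jnp.arange` against `floor(T)`, which also works under `jax.jit`.

```python
import jax
import jax.numpy as jnp


def goals_mask_from_T(T, num_points):
    # Traceable version of: mask = ones(num_points, bool); mask[:int(T)] = 0
    # An elementwise comparison needs no Python integer, so T may be a tracer.
    return jnp.arange(num_points) >= jnp.floor(T)


# num_points fixes the output shape, so it must be static under jit
goals_mask_jit = jax.jit(goals_mask_from_T, static_argnums=1)

mask = goals_mask_jit(jnp.float32(3.7), 6)
# Indices 0-2 fall before the split and are False; indices 3-5 stay True
```

The mask for the other segment is then simply `~mask`.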

In my example above I use SVI, but the same applies to MCMC.
Beware that T is a float, so convert it to an integer (for example with `jnp.floor(T).astype(jnp.int32)`) before using it as an index.
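A sketch of that conversion inside jitted code, assuming T is non-negative; `split_index` and `zero_at_split` are hypothetical helper names. JAX arrays are immutable, so `.at[...].set(...)` returns an updated copy:

```python
import jax
import jax.numpy as jnp


@jax.jit
def split_index(T):
    # Floor the float sample and cast it to an integer dtype before indexing
    return jnp.floor(T).astype(jnp.int32)


@jax.jit
def zero_at_split(mask, T):
    # A scalar (possibly traced) integer index is fine inside .at[];
    # this changes exactly one element and leaves the input array untouched.
    return mask.at[split_index(T)].set(False)


mask = jnp.ones(6, dtype=bool)
idx = split_index(jnp.float32(3.7))
new_mask = zero_at_split(mask, jnp.float32(3.7))
```

Note that `.at[idx].set(...)` updates a single entry. To mask everything before the split, a traced slice such as `mask.at[:idx]` is expected to fail because JAX requires static slice bounds, which is why an elementwise comparison with `jnp.arange` is the usual workaround.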

Hi @pavleb, THANK YOU SO MUCH. Using `mask.at[T].set(0)` worked perfectly. My problem was really a JAX problem. Thanks again for the guidance.