User distribution sampling

Dear expert,

Given the class below:

import jax
import jax.numpy as jnp

class MixtureModel_jax():
    def __init__(self, locs, scales, weights, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Store parameters as column vectors of shape (n_components, 1)
        # so that they broadcast against the evaluation points x
        self.loc = jnp.array([locs]).T
        self.scale = jnp.array([scales]).T
        self.weights = jnp.array([weights]).T
        norm = jnp.sum(self.weights)
        self.weights = self.weights/norm  # normalise the mixture weights

        self.num_distr = len(locs)

    def pdf(self, x):
        # probs[k, i] = N(x_i; loc_k, scale_k)
        probs = jax.scipy.stats.norm.pdf(x, loc=self.loc, scale=self.scale)
        return jnp.dot(self.weights.T, probs).squeeze()

    def logpdf(self, x):
        # Numerically stable log(sum_k w_k * N(x; loc_k, scale_k))
        log_probs = jax.scipy.stats.norm.logpdf(x, loc=self.loc, scale=self.scale)
        return jax.scipy.special.logsumexp(jnp.log(self.weights) + log_probs, axis=0)

mixture_gaussian_model = MixtureModel_jax([0,1.5],[0.5,0.1],[8,2])
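A quick sanity check on the class, as a minimal sketch: since the weights are normalised in __init__, the mixture pdf should integrate to about 1, and logpdf should agree with log(pdf) wherever pdf does not underflow. The class below is a trimmed copy (the unused *args/**kwargs are dropped); the grid bounds, grid size and test points are arbitrary choices.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


class MixtureModel_jax:
    """Trimmed copy of the class above (unused *args/**kwargs dropped)."""

    def __init__(self, locs, scales, weights):
        # Column vectors of shape (n_components, 1) broadcast against x
        self.loc = jnp.array([locs]).T
        self.scale = jnp.array([scales]).T
        w = jnp.array([weights]).T
        self.weights = w / jnp.sum(w)  # normalised mixture weights
        self.num_distr = len(locs)

    def pdf(self, x):
        probs = norm.pdf(x, loc=self.loc, scale=self.scale)
        return jnp.dot(self.weights.T, probs).squeeze()

    def logpdf(self, x):
        log_probs = norm.logpdf(x, loc=self.loc, scale=self.scale)
        return logsumexp(jnp.log(self.weights) + log_probs, axis=0)


model = MixtureModel_jax([0, 1.5], [0.5, 0.1], [8, 2])

# Fine grid covering both components; dx computed in Python to avoid float32 rounding
grid = jnp.linspace(-5.0, 5.0, 20001)
dx = 10.0 / (grid.size - 1)
# Riemann sum of the density: should be close to 1 for a normalised mixture
total_mass = jnp.sum(model.pdf(grid)) * dx

# logpdf should match log(pdf) at points where pdf does not underflow
x = jnp.array([-1.0, 0.0, 0.7, 1.5, 2.0])
max_log_gap = jnp.max(jnp.abs(jnp.log(model.pdf(x)) - model.logpdf(x)))
```

logpdf goes through logsumexp rather than log(pdf) so that it stays finite far in the tails, where the narrow component (scale 0.1) underflows to 0.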

I'm able to generate samples using TensorFlow Probability (JAX substrate):

import jax
import jax.numpy as jnp

from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
tfb = tfp.bijectors

num_results = int(1e5)
num_burnin_steps = int(1e3)

adaptive_hmc = tfp.mcmc.SimpleStepSizeAdaptation(
    tfp.mcmc.HamiltonianMonteCarlo(
        target_log_prob_fn=mixture_gaussian_model.logpdf,  # Here we use the log prob of our known distribution
        num_leapfrog_steps=3,
        step_size=1.),
    num_adaptation_steps=int(num_burnin_steps * 0.8))

hmc_samples, is_accepted = tfp.mcmc.sample_chain(
      num_results=num_results,
      num_burnin_steps=num_burnin_steps,
      current_state=jnp.zeros([2]),
      kernel=adaptive_hmc,
      trace_fn=lambda _, pkr: pkr.inner_results.is_accepted,
      seed=jax.random.PRNGKey(1))
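Why current_state=jnp.zeros([2]) runs at all with a 1D model: logpdf is elementwise, so a state of shape (2,) yields two log densities, not one joint value. As far as I understand TFP's batching convention, such a batch of log-probs is treated as independent chains, and the gradient of the summed log density decouples per element. The sketch below checks both properties in plain JAX; mixture_logpdf, LOC, SCALE and LOG_W are hypothetical stand-ins for mixture_gaussian_model.logpdf with the same parameters.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Hypothetical stand-in for mixture_gaussian_model.logpdf with the same
# parameters; (n_components, 1) columns broadcast against x as in the class.
LOC = jnp.array([[0.0], [1.5]])
SCALE = jnp.array([[0.5], [0.1]])
LOG_W = jnp.log(jnp.array([[0.8], [0.2]]))


def mixture_logpdf(x):
    return logsumexp(LOG_W + norm.logpdf(x, loc=LOC, scale=SCALE), axis=0)


state = jnp.array([0.0, 1.5])  # like current_state=jnp.zeros([2]), but distinct values

# One log density per element of the state, not a single joint value
batched = mixture_logpdf(state)
one_by_one = jnp.stack([mixture_logpdf(s)[0] for s in state])

# The gradient of the summed log density decouples: each element only sees
# its own term, so the two entries behave like two independent 1D samplers.
grad_summed = jax.grad(lambda s: mixture_logpdf(s).sum())(state)
grad_each = jax.vmap(jax.grad(lambda s: mixture_logpdf(s)[0]))(state)
```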

But I cannot manage to translate this into NumPyro code. Is it possible?
Thanks

In NumPyro, you can use

from jax.random import PRNGKey
from numpyro.infer import MCMC, NUTS

kernel = NUTS(potential_fn=lambda x: -mixture_gaussian_model.logpdf(x))
mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000)
mcmc.run(PRNGKey(0), init_params=jnp.zeros(2))
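NUTS needs the gradient of potential_fn, and jax.grad only accepts functions with a scalar output, so potential_fn has to return a scalar. With the class as first written, logpdf of a scalar x returns shape (1,). A minimal sketch of the difference, using a hypothetical stand-in mixture_logpdf with the same parameters as mixture_gaussian_model:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Hypothetical stand-in for the original mixture_gaussian_model.logpdf
LOC = jnp.array([[0.0], [1.5]])
SCALE = jnp.array([[0.5], [0.1]])
LOG_W = jnp.log(jnp.array([[0.8], [0.2]]))


def mixture_logpdf(x):
    # For a scalar x this returns shape (1,), not a scalar
    return logsumexp(LOG_W + norm.logpdf(x, loc=LOC, scale=SCALE), axis=0)


def potential_vector(x):
    return -mixture_logpdf(x)  # shape (1,): rejected by jax.grad


def potential_scalar(x):
    return -mixture_logpdf(x).sum()  # shape (): what NUTS can differentiate


try:
    jax.grad(potential_vector)(0.3)
    vector_ok = True
except TypeError:  # jax.grad only supports scalar-output functions
    vector_ok = False

grad_at_0_3 = jax.grad(potential_scalar)(0.3)
```

Reducing with .sum() here has the same effect, for a scalar input, as indexing with [0].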

Thanks @fehiepsi. At first I got an error, and I had to change my class's logpdf method so that it returns a scalar:

    def logpdf(self, x):
        log_probs = jax.scipy.stats.norm.logpdf(x, loc=self.loc, scale=self.scale)
        # For a scalar x the result has shape (1,); [0] extracts the scalar NUTS needs
        return jax.scipy.special.logsumexp(jnp.log(self.weights) + log_probs, axis=0)[0]

and I managed to run this code:

from jax import random
from numpyro.infer import MCMC, NUTS

# Start from this source of randomness. We will split keys for subsequent operations.
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)

# Run NUTS.
kernel = NUTS(potential_fn=lambda x: -mixture_gaussian_model.logpdf(x))
num_samples = 10000
mcmc = MCMC(kernel, num_warmup=1000, num_samples=num_samples, num_chains=2)
# With num_chains=2 the leading axis of init_params indexes the chains:
# jnp.zeros(2) gives each chain a scalar starting value of 0.0
mcmc.run(rng_key_, init_params=jnp.zeros(2))
mcmc.print_summary()
samples_1 = mcmc.get_samples()

And got this 1D distribution:

[image: the sampled 1D distribution]
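One way to sanity-check such samples is to compare their empirical mean and variance with the analytic moments of the mixture. A minimal sketch, using the normalised weights 0.8 and 0.2 of mixture_gaussian_model; the variable names are only illustrative:

```python
import jax.numpy as jnp

# Parameters of mixture_gaussian_model (weights 8 and 2, normalised)
locs = jnp.array([0.0, 1.5])
scales = jnp.array([0.5, 0.1])
weights = jnp.array([8.0, 2.0]) / 10.0

# Mean of a mixture: weighted average of the component means
mean = jnp.sum(weights * locs)
# Variance: E[x^2] - mean^2, with E[x^2] = sum_k w_k * (sigma_k^2 + mu_k^2)
second_moment = jnp.sum(weights * (scales**2 + locs**2))
variance = second_moment - mean**2
```

This gives a mean of 0.3 and a variance of 0.562, to be compared with float(samples_1.mean()) and float(samples_1.var()). A large discrepancy might indicate that the chains are not moving between the two fairly separated modes, which is worth checking before trusting the histogram.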

Fine! Do you know how I can modify my code to get a 2D distribution (x, y), with x and y each following my 1D distribution?

For instance, with TFP I got:
[image: 2D plot obtained with TFP]

Do you have an idea? Thanks
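Reading the question as "x and y independent, each distributed like the 1D mixture" (an assumption), the 2D target factorises, with p_1 the density of mixture_gaussian_model, normalised weights w_k, and K = 2 components here:

```latex
p_1(u) = \sum_{k=1}^{K} w_k \, \mathcal{N}(u;\, \mu_k, \sigma_k^2),
\qquad
p(x, y) = p_1(x)\, p_1(y)

\log p(x, y) = \log p_1(x) + \log p_1(y),
\qquad
U(x, y) = -\log p(x, y)
```

Here U is what NumPyro calls the potential. For D independent coordinates the sum simply runs over all D terms, so a single vector argument covers any number of variables.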

I think you can use code similar to the above. In TFP, you use

target_log_prob_fn=mixture_gaussian_model.logpdf

but in NumPyro, it will be

potential_fn=lambda x: -mixture_gaussian_model.logpdf(x).sum()

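i.e. potential_fn is the negative of target_log_prob_fn. Because NUTS differentiates it with respect to its input, it must return a scalar (hence the .sum()).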

In addition, you need to set init_params in NumPyro to be the same as current_state in TFP.
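A minimal sketch of such a vector-valued potential, assuming independent coordinates and the original elementwise logpdf (without the [0]). mixture_logpdf, LOC, SCALE and LOG_W are hypothetical stand-ins for the class with the same parameters. The point is that one function of a single array argument works for 2 coordinates as well as for 10, so the variables never need to be written out one by one:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm

# Hypothetical stand-in for the original (elementwise) mixture_gaussian_model.logpdf
LOC = jnp.array([[0.0], [1.5]])
SCALE = jnp.array([[0.5], [0.1]])
LOG_W = jnp.log(jnp.array([[0.8], [0.2]]))


def mixture_logpdf(x):
    # x of shape (D,) -> one log density per coordinate, shape (D,)
    return logsumexp(LOG_W + norm.logpdf(x, loc=LOC, scale=SCALE), axis=0)


def potential_fn(x):
    # Independent coordinates: joint log density = sum of the 1D log densities
    return -jnp.sum(mixture_logpdf(x))


xy = jnp.array([0.2, 1.4])
u = potential_fn(xy)
# Same value written out coordinate by coordinate
u_explicit = -(mixture_logpdf(xy[0])[0] + mixture_logpdf(xy[1])[0])

# The same function works unchanged for D = 10 coordinates
x10 = jnp.linspace(-1.0, 2.0, 10)
u10 = potential_fn(x10)
grad10 = jax.grad(potential_fn)(x10)
```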

Thanks @fehiepsi, but

  1. I cannot manage to correctly initialise init_params with 1 chain and (x, y) input variables, nor with several chains.
  2. I'm quite surprised that there is no mechanism similar to TFP's where I do not need to write out all the variables explicitly, since there can be 10 or so.

For multiple chains, you just need to provide a batch of init values, e.g. instead of init_params=jnp.zeros(2), you can provide init_params=jnp.zeros((4, 2)), together with num_chains=4 in MCMC, to specify init values for 4 chains.
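For a 2D target and 3 chains, the batch of init values has shape (num_chains, dim): one row per chain, one column per coordinate. A minimal sketch; the random spread (scale 2.0) is an arbitrary choice, not something NumPyro requires:

```python
import jax
import jax.numpy as jnp

num_chains, dim = 3, 2

# Same starting point for every chain: shape (num_chains, dim)
init_zeros = jnp.zeros((num_chains, dim))

# Dispersed starting points make it easier to spot chains that
# end up stuck in different modes (scale 2.0 is arbitrary)
init_random = 2.0 * jax.random.normal(jax.random.PRNGKey(0), (num_chains, dim))
```

Either array would then be passed as mcmc.run(rng_key, init_params=init_random), with MCMC(..., num_chains=num_chains).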

there is no a TFP similar mechanism where I do not need to write explicitly all the variables as the number can be 10 or so

Could you be more explicit?

Well, I realized that I had misinterpreted the 2D plots produced with TFP that I posted. If I'm right, they show the samples of the 1st chain plotted against those of the 2nd chain.

Now I would like to sample a genuine 2D pdf built from my 1D mixture_gaussian_model, using for instance 3 chains (i.e. with proper initialisation).
Thanks for your kind help.