Bivariate distribution from two independent random variables

Hey, I would like to implement a MixtureGeneral model on two-dimensional data.

My models work just fine when using the builtin multidimensional distributions like MultivariateNormal or SineBivariateVonMises. But in addition to that, I would also like to use two univariate distributions that are independent of each other (so have both kinds of components in one model).

Q: What I don’t know though is how I would model a bivariate distribution (dist_IND) composed of two independent random variables that sample from univariate builtin distributions, like a Normal (dist_Ix) and an Exponential (dist_Iy).

I guess what would work is to manually implement a bivariate distribution that samples from the two univariate distributions individually and puts them together. But maybe there is something builtin that I did not find so far?

What I found was the Independent class, but from what I understood this is only used to basically reduce the correlation matrix of a MultivariateNormal distribution.

Thanks in advance.
P.S. This is my first post, so sorry if this is not the best description. Don’t hesitate to ask for any further information; I am more than willing to provide it.

Some non-working code:

# shape of data would be: (num_samples, 2)
def general_mixture_independent(data, num_components):
    weights = numpyro.sample('weights', dist.Dirichlet(jnp.ones(num_components)))
    
    # multivariate builtin distribution
    mu_MVN = numpyro.sample('mu_MVN', dist.MultivariateNormal(jnp.zeros(2), jnp.eye(2)))
    dist_MVN = dist.MultivariateNormal(mu_MVN, jnp.eye(2))
    
    # two univariate builtin distributions
    # x-dimension
    mu_x = numpyro.sample('mu_x', dist.Normal(0, 1))
    sigma_x = numpyro.sample('sigma_x', dist.HalfNormal(1))
    dist_Ix = dist.Normal(mu_x, sigma_x)
    
    # y-dimension
    lambda_y = numpyro.sample('lambda_y', dist.HalfNormal(3))
    dist_Iy = dist.Exponential(lambda_y)
    
    # put together -> that's basically what I don't know how to do
    # and it of course does not work like that
    dist_IND = jnp.column_stack([dist_Ix, dist_Iy])
    mixture = dist.MixtureGeneral(dist.Categorical(probs=weights), [dist_MVN, dist_IND])
    numpyro.sample('obs', mixture, obs=data)

the easiest way to do this is MixtureSameFamily

I think I don’t quite understand. In the MixtureSameFamily description it says:

The MixtureSameFamily distribution implements a (batch of) mixture distribution where all components are from different parameterizations of the same distribution type

In my example I don’t have the same distribution type, as the MultivariateNormal is a different type than a bivariate of two independent univariates, e.g. an Exponential and a Normal. Or do I misunderstand something?

Thank you!

if you want to use MixtureGeneral then afaik the second argument needs to be a list of distribution objects. if you want those distribution objects to have various batch dimensions you should pass parameters with appropriate shape when instantiating those distribution objects

afaik i don’t think you can easily define a mixture distribution where individual component distributions are themselves composed of products of different distribution families (if that’s what you wanted to do)

Okay, then at least I did not overlook anything obvious.

I am not bound to a MixtureGeneral to be honest, I just want to fit something like the following mixture model:

from numpyro import distributions as dist
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt

key = random.PRNGKey(42)
key_biv, key_vm, key_exp = random.split(key, 3)

# Bivariate von Mises distribution
samples_vm = dist.SineBivariateVonMises(2, 2, 4, 2, -2.5).sample(key_biv, (1500, ))

# Bivariate distribution of two independent univariate distributions (von Mises, shifted Exponential)
samples_vm_1d = dist.VonMises(0.5, 1).sample(key_vm, (1000, ))
samples_exp = dist.Exponential(3).sample(key_exp, (1000, )) - jnp.pi  # shift so the support starts at -pi (but maybe another distribution would also work)

samples_bivInd = jnp.column_stack([samples_vm_1d, samples_exp])

# Mixture model
samples = jnp.concatenate([samples_vm, samples_bivInd])

Density plots:

Do you see another way than MixtureGeneral?
Thanks a lot!

are you going to be doing HMC inference or something else?

Yes HMC was on my mind.

so one option would be to define a custom VonMisesPlusExponential distribution and then use MixtureGeneral to combine that with SineBivariateVonMises

another is to use a factor statement where you explicitly compute the log_prob you’re interested in, something vaguely like

def model():
    my_log_prob1a = dist.VonMises(...).log_prob(data1)
    my_log_prob1b = dist.Exponential(...).log_prob(data1)
    my_log_prob2 = dist.SineBivariateVonMises(...).log_prob(data2)
    # define mixture log prob using logsumexp etc
    my_log_prob = f(my_log_prob1a, my_log_prob1b, my_log_prob2) 
    numpyro.factor("my_log_prob", my_log_prob)

Thank you very much. I decided to write a custom distribution which worked out quite nice and I can run my code with it already.

I still have one question though, and that is with regard to the support, which is set as a static class variable for all distributions. My problem is that my support is different for the two dimensions (VonMises on the x-dim and a shifted Exponential on the y-dim).

I defined a ShiftedExponential distribution that shifts the Exponential by -pi, which I use for the y-dim.
So my support would be [-pi, pi] x [-pi, Inf). Is there a way to model this? Currently I just have:

support = constraints.independent(constraints.circular, 1)

Thank you very much

i’m not sure this matters if the random variable is used as a likelihood (i.e. the random variable is observed). is that true in your case? @fehiepsi would know better

Yes, support is unnecessary for the likelihood. We don’t have an api for Distribution with joint support. If you are using HMC, you can use enumeration as in this example.

I don’t think I quite understand how we got from support to enumeration. What exactly would I use that for?

For the above model, you can rewrite the mixture distribution into something like

index = sample("index", Categorical(weights))
sample("obs1", d1.mask(index == 0), obs=...)
sample("obs2", d2.mask(index == 1), obs=...)

where d1, d2 are component distributions.
As mentioned above, you can use the mixture distribution because your variable is observed. You might want to add additional logic to mask out the observations that do not belong to the support, something like

d1 = d1.mask(d1.support(data))
d2 = d2.mask(d2.support(data))

before feeding them into the mixture distribution.

I think my problem is first of all understanding why all of this would be necessary. What I mean is: what is the advantage over using MixtureGeneral?

Maybe I’ll share my current working code:

@numpyro.handlers.reparam(
    config={"phi_loc_dep": CircularReparam(), 
            "psi_loc_dep": CircularReparam(),
            "phi_loc_dep2": CircularReparam(), 
            "psi_loc_dep2": CircularReparam(),
            "phi_loc_vmExp": CircularReparam(),
            "phi_loc_park": CircularReparam(),
            "psi_loc_park": CircularReparam(),
           }
)

# ensure that order of samples is: arrival, parking, departure
def tst3D(samples):
    # cluster weights
    weights = sample('weights', Dirichlet(jnp.ones(4)))
    # Independent Departure Time/ over-night cluster
    phi_loc_dep = sample('phi_loc_dep', VonMises(1, 10))
    psi_loc_dep = sample('psi_loc_dep', VonMises(-1, 10))
    
    phi_conc_dep = sample('phi_conc_dep', Beta(2, 1))
    psi_conc_dep = sample('psi_conc_dep', Beta(1, 1))
    
    depInd = BivariateDepartureIndependentVonMises(phi_loc_dep, psi_loc_dep, 70 * phi_conc_dep, 70 * psi_conc_dep)
    
    # Second independent Departure Time / over-night cluster
    phi_loc_dep2 = sample('phi_loc_dep2', VonMises(1.5, 10))
    psi_loc_dep2 = sample('psi_loc_dep2', VonMises(-0.75, 10))
    
    phi_conc_dep2 = sample('phi_conc_dep2', Beta(1, 1))
    psi_conc_dep2 = sample('psi_conc_dep2', Beta(2, 1))
    
    depInd2 = BivariateDepartureIndependentVonMises(phi_loc_dep2, psi_loc_dep2, 70 * phi_conc_dep2, 70 * psi_conc_dep2)
    
    
    # Independent VM and Exp / short time parker
    phi_loc_vmExp = sample('phi_loc_vmExp', VonMises(1.25, 10))
    phi_conc_vmExp = sample('phi_conc_vmExp', Beta(1, 1))
    exp_rate = sample('exp_rate', Beta(10, 10))
    
    ExpVmInd = IndependentVonMisesExponential3D(phi_loc_vmExp, phi_conc_vmExp, 30 * exp_rate)
    
    # Arrival time - parking duration cluster / during the day parker
    phi_loc_park = sample('phi_loc_park', VonMises(0.3, 0.1))
    psi_loc_park = sample('psi_loc_park', VonMises(-2.5, 0.1))
    
    phi_conc_park = sample('phi_conc_park', Beta(1, 1))
    psi_conc_park = sample('psi_conc_park', Beta(1, 1))
    
    corr = sample('corr', Beta(1, 1))
    
    ArrPark = BivariateParkingVonMises3D(phi_loc_park, psi_loc_park, 70 * phi_conc_park, 70 * psi_conc_park, corr)

    # Mixture
    dists = [depInd, depInd2, ExpVmInd, ArrPark]
    mixture = MixtureGeneral(Categorical(probs=weights), dists)
    
    with numpyro.plate('samples', len(samples)):        
        sample("obs", mixture, obs=samples)

There is no advantage in your use case.


Thank you very much @martinjankowiak and @fehiepsi, you helped me a lot!
The thread could be closed, I just don’t know how :see_no_evil: