Help, how can I use a 2D empirical distribution?

Hello friends,

I am very new at this so any help is highly appreciated.

Task

I want to find the time offset that best aligns the time vs. latitude distribution of a series of solar cycles with the time vs. latitude distribution of a series of reference solar cycles.

Approach
I’m not sure if this is the right approach, but my idea was to create a model that marginalizes the distribution of offsets by comparing the latitude vs. time points of the cycles being fitted against the empirical distribution of the reference cycles. I am using TFP’s Empirical distribution.
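
For context on what the Empirical distribution evaluates: as far as the TFP documentation describes it, Empirical places point masses on the stored samples, so the probability of a value is the fraction of samples exactly equal to it. The snippet below is only a plain NumPy sketch of that rule, not TFP's code, and empirical_log_prob is a hypothetical helper name. It suggests that a continuously shifted time will almost never land exactly on a reference sample.

```python
import numpy as np

def empirical_log_prob(samples, value):
    """Log of the fraction of reference samples that equal `value` exactly.

    samples: (n_samples, 2) array of (time, latitude) reference points.
    value:   (2,) array holding one (time, latitude) query point.
    """
    matches = np.all(samples == value, axis=-1)   # exact equality, point by point
    frac = matches.mean()
    return np.log(frac) if frac > 0 else -np.inf

ref_points = np.array([[0.0, 10.0], [1.0, 12.0], [1.0, 12.0], [2.0, 15.0]])

# A query that coincides with two of the four samples: log(2/4).
on_sample = empirical_log_prob(ref_points, np.array([1.0, 12.0]))
# The same query with a small time offset matches nothing: -inf.
shifted = empirical_log_prob(ref_points, np.array([1.1, 12.0]))
```

If that reading is right, some form of binning or smoothing of the reference points is needed before an offset can be scored.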

Problem
I was thinking I could use two plates, one for the fitted cycles and one for the reference cycles, but I can’t figure out the broadcasting when using TFP’s Empirical distribution.

Here is my model:

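import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

def model(self, 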
            ref_lat:jnp.array, 
            ref_time:jnp.array, 
            fit_lat:jnp.array, 
            fit_time:jnp.array):

    ref_cyc_n = ref_lat.shape[1]
    fit_cyc_n = fit_lat.shape[1]

    # Broadcast ref and fit variables
    ref_lat_b = jnp.swapaxes(jnp.swapaxes(jnp.broadcast_to(ref_lat, (fit_lat.shape[1], ref_lat.shape[0], ref_cyc_n)),0, 1), 1, 2)
    ref_time_b = jnp.swapaxes(jnp.swapaxes(jnp.broadcast_to(ref_time, (fit_lat.shape[1], ref_lat.shape[0], ref_cyc_n)),0, 1), 1, 2)
    ref_data = jnp.concatenate([ref_time_b[:,None,:,:], ref_lat_b[:,None,:,:]], axis=1)

    fit_lat_b = jnp.swapaxes(jnp.broadcast_to(fit_lat, (ref_cyc_n, fit_lat.shape[0], fit_cyc_n)), 0, 1)
    fit_time_b = jnp.swapaxes(jnp.broadcast_to(fit_time, (ref_cyc_n, fit_lat.shape[0], fit_cyc_n)), 0, 1)

    # Define plates
    fit_plate = numpyro.plate("fit_cyc", fit_cyc_n, dim=-1)
    ref_plate = numpyro.plate("ref_cyc", ref_cyc_n, dim=-2)
    with fit_plate:
        with ref_plate:
            offst = numpyro.sample('offst', dist.Normal(0., self.offst_sig))
            offst_time = fit_time_b + offst
            print(offst_time.shape)
            fit_data = jnp.concatenate([offst_time[:,None,:,:], fit_lat_b[:,None,:,:]], axis=1)
            print(fit_data.shape)
            
            numpyro.sample('obs', tfd.Empirical(ref_data, event_ndims=3), obs=fit_data)

I don’t know if it is the right thing to do, but I start by broadcasting the fitted cycle data and the reference cycle data so that both have the following dimensions:

[Number of latitude-time pairs, number of reference cycles, number of fitted cycles]

But since the empirical distribution has to be two dimensional (time, latitude), I thought I needed to concatenate latitude and time like in the line:

fit_data = jnp.concatenate([offst_time[:,None,:,:], fit_lat_b[:,None,:,:]], axis=1)
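
The target layout described above can also be reached by adding a singleton axis and broadcasting over it, which may be easier to read than chaining swapaxes. This is a minimal NumPy sketch with small stand-in sizes and random data; the same indexing and broadcast_to calls exist in jax.numpy.

```python
import numpy as np

n_pts, ref_cyc_n, fit_cyc_n = 6, 4, 5   # small stand-ins for 9680, 4 and 5
rng = np.random.default_rng(0)
ref_time = rng.normal(size=(n_pts, ref_cyc_n))
ref_lat = rng.normal(size=(n_pts, ref_cyc_n))
fit_time = rng.normal(size=(n_pts, fit_cyc_n))
fit_lat = rng.normal(size=(n_pts, fit_cyc_n))

target = (n_pts, ref_cyc_n, fit_cyc_n)

# Reference arrays get a new trailing axis for the fitted cycles.
ref_time_b = np.broadcast_to(ref_time[:, :, None], target)
ref_lat_b = np.broadcast_to(ref_lat[:, :, None], target)
# Fitted arrays get a new middle axis for the reference cycles.
fit_time_b = np.broadcast_to(fit_time[:, None, :], target)
fit_lat_b = np.broadcast_to(fit_lat[:, None, :], target)

# Stack time and latitude on a new axis 1: (n_pts, 2, ref_cyc_n, fit_cyc_n).
ref_data = np.stack([ref_time_b, ref_lat_b], axis=1)
fit_data = np.stack([fit_time_b, fit_lat_b], axis=1)

# The swapaxes chain from the model produces the same array.
ref_lat_swap = np.swapaxes(np.swapaxes(
    np.broadcast_to(ref_lat, (fit_cyc_n, n_pts, ref_cyc_n)), 0, 1), 1, 2)
```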

The problem is that when I initialize the model and try to plot it, I get the message:

ValueError: Incompatible shapes for broadcasting: shapes=[(4, 5), (9680,)]

Here I was trying to fit 5 cycles against 4 reference cycles and each cycle has 9680 points. Any help is super-appreciated!!
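
The two shapes in the message can be reproduced with plain NumPy: (4, 5) looks like the batch shape set by the two plates (reference cycles, fitted cycles), and (9680,) is the number of points per cycle. The sketch below only illustrates the broadcasting rule, aligning shapes from the right; it is not a trace of where the shapes collide inside NumPyro or TFP.

```python
import numpy as np

batch_shape = (4, 5)      # (reference cycles, fitted cycles)
points_shape = (9680,)    # latitude-time points per cycle

try:
    np.broadcast_shapes(batch_shape, points_shape)
    compatible = True
except ValueError:
    compatible = False    # trailing sizes 5 and 9680 cannot be matched

# Giving the points their own leading axis lets the shapes line up.
aligned = np.broadcast_shapes(batch_shape, (9680, 1, 1))
```

Here aligned comes out as (9680, 4, 5), the [points, reference cycles, fitted cycles] layout.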

Just in case anybody else is dealing with a similar issue: in the end I was able to do what I wanted (not 100% sure it is correct, but it’s running) using a categorical distribution. I map my latitude-time points onto a square grid of categorical indices using time_lat_to_categorical, and then calculate the probabilities with a vectorized histogram over those indices.
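
The two helpers are not shown in this post, so here is a rough NumPy sketch of the idea under stated assumptions: time_lat_to_index and cell_probs are hypothetical stand-ins, the grid edges are made up, and index 0 is assumed to be a catch-all bin for points outside the grid (such as a -999 sentinel).

```python
import numpy as np

def time_lat_to_index(time, lat, time_edges, lat_edges):
    """Map (time, lat) points to a flat grid-cell index; 0 means outside the grid."""
    n_t, n_l = len(time_edges) - 1, len(lat_edges) - 1
    it = np.digitize(time, time_edges) - 1      # -1 or n_t when out of range
    il = np.digitize(lat, lat_edges) - 1
    inside = (it >= 0) & (it < n_t) & (il >= 0) & (il < n_l)
    return np.where(inside, 1 + it * n_l + il, 0)

def cell_probs(index, n_cells):
    """Normalised counts per cell, with the outside bin zeroed out."""
    counts = np.bincount(index.ravel(), minlength=n_cells + 1).astype(float)
    counts[0] = 0.0
    counts += 1e-20                             # keeps log(prob) finite
    return counts / counts.sum()

time_edges = np.linspace(0.0, 11.0, 12)   # 11 one-year bins (assumed)
lat_edges = np.linspace(-40.0, 40.0, 9)   # 8 ten-degree bins (assumed)

ref_time = np.array([0.5, 0.5, 3.2, -999.0])
ref_lat = np.array([25.0, 27.0, 12.0, -999.0])

ref_idx = time_lat_to_index(ref_time, ref_lat, time_edges, lat_edges)
ref_prob = cell_probs(ref_idx, n_cells=11 * 8)
```

The first two points share a cell, so that cell ends up with about two thirds of the probability, and the sentinel point falls in bin 0 and contributes essentially nothing.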

I had a problem with my event dimensions, which I solved by expanding the empirical distribution and setting the final dimension as an event dimension. Here is my model. I will update it if I find (and solve) more problems as I go.

def obs_plate_model(self, 
            ref_lat:jnp.array, 
            ref_time:jnp.array, 
            ref_area:jnp.array, 
            fit_lat:jnp.array, 
            fit_time:jnp.array):

    ref_cyc_n = ref_lat.shape[0]
    fit_cyc_n = fit_lat.shape[1]

    ref_lat_b = jnp.expand_dims(jnp.expand_dims(ref_lat, 1), 1)
    ref_time_b = jnp.expand_dims(jnp.expand_dims(ref_time, 1), 1)
    ref_area_b = jnp.expand_dims(jnp.expand_dims(ref_area, 1), 1)

    fit_lat_b = jnp.expand_dims(fit_lat, 0)
    fit_time_b = jnp.expand_dims(fit_time, 0)

    # Define plates
    fit_plate = numpyro.plate("fit_cyc_pl", fit_cyc_n, dim=-1)
    hem_plate = numpyro.plate("hem_pl", 2, dim=-2)
    ref_plate = numpyro.plate("ref_cyc_pl", ref_cyc_n, dim=-3)

    with fit_plate:
        threshold = numpyro.sample('threshld', dist.TruncatedNormal(0., self.thres_sig, low=0))

        # Calculate empirical distribution
        ref_lat_tr = jnp.where(ref_area_b<jnp.expand_dims(threshold,1), -999, ref_lat_b)
        ref_time_tr = jnp.where(ref_area_b<jnp.expand_dims(threshold,1), -999, ref_time_b)

        ref_data = self.time_lat_to_categorical(ref_time_tr, ref_lat_tr)
        ref_prob = vectorized_hist(ref_data, self.ctgrcl_edges)
        ref_prob = ref_prob.at[:,:,:,0].set(0) + 1e-20
        ref_prob = ref_prob/jnp.expand_dims(ref_prob.sum(axis=-1), -1)            
                        
        with hem_plate:
            with ref_plate:

                offst = numpyro.sample('offst', dist.Normal(0., self.offst_sig))
                offst_time = fit_time_b + jnp.expand_dims(offst,3)
                fit_data = self.time_lat_to_categorical(offst_time, fit_lat_b)

                dist_tmp = dist.Categorical(probs=jnp.expand_dims(ref_prob, 3)).expand((ref_cyc_n, 1, fit_cyc_n,fit_data.shape[-1])).to_event(1)

                numpyro.sample('obs', dist_tmp, obs=fit_data)
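
For anyone unsure what declaring the final dimension as an event dimension does: to_event(1) makes log_prob sum over that last axis, so all points of a cycle are scored jointly instead of one by one. A tiny NumPy sketch of the arithmetic, with made-up probabilities and indices rather than NumPyro itself:

```python
import numpy as np

probs = np.array([0.1, 0.6, 0.3])   # cell probabilities for one reference cycle
obs = np.array([1, 1, 2, 0])        # category index of each observed point

# Per-point log-probabilities, as a plain Categorical would return them.
per_point = np.log(probs[obs])
# With the last axis treated as an event dimension, they are summed into one value.
joint = per_point.sum(axis=-1)
```

With this, the plates only need to cover the cycle and hemisphere dimensions, and the points axis is handled by the distribution's event shape.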