Hello friends,
I am very new to this, so any help is highly appreciated.
Task
I want to find the time offset that best aligns the time vs. latitude points of a series of solar cycles with the time vs. latitude distribution of a series of reference solar cycles.
Approach
I’m not sure if this is the right approach, but my idea was to create a model that marginalizes over the distribution of offsets by comparing the latitude vs. time points of the fitted cycles against the empirical distribution of the reference cycles. I am using TFP’s Empirical distribution.
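To make the goal concrete, here is a minimal sketch of what I mean by "a good match", written with plain NumPy rather than NumPyro. It is not the model I want to end up with: it swaps the marginalization for a simple grid search over candidate offsets, and it swaps TFP's Empirical distribution for a 2-D histogram density of the reference points. The function name `score_offsets`, the bin count, and the synthetic data are all made up for illustration.

```python
import numpy as np

def score_offsets(ref_time, ref_lat, fit_time, fit_lat, offsets, bins=(20, 20)):
    """Sum of log histogram-density values of the shifted fitted points, per offset."""
    floor = 1e-12  # keeps log() finite in empty bins and outside the grid
    # 2-D histogram density of the reference (time, latitude) points.
    density, t_edges, l_edges = np.histogram2d(
        ref_time, ref_lat, bins=bins, density=True
    )
    log_density = np.log(density + floor)

    # Latitude bins do not depend on the offset, so compute them once.
    li = np.searchsorted(l_edges, fit_lat, side="right") - 1
    scores = np.empty(len(offsets))
    for k, off in enumerate(offsets):
        # Time bin of each fitted point after shifting it by the candidate offset.
        ti = np.searchsorted(t_edges, fit_time + off, side="right") - 1
        # Points on or beyond the outer edges count as outside the grid.
        inside = (ti >= 0) & (ti < bins[0]) & (li >= 0) & (li < bins[1])
        ll = np.full(fit_time.shape, np.log(floor))
        ll[inside] = log_density[ti[inside], li[inside]]
        scores[k] = ll.sum()
    return scores

# Synthetic stand-in for one cycle: latitude drifts toward the equator over time.
rng = np.random.default_rng(0)
ref_time = rng.uniform(0.0, 10.0, size=2000)
ref_lat = 30.0 - 2.5 * ref_time + rng.normal(0.0, 2.0, size=2000)

# The "fitted" cycle is the same cloud of points, recorded 2 time units early.
fit_time = ref_time - 2.0
fit_lat = ref_lat

offsets = np.arange(-4.0, 4.5, 0.5)
scores = score_offsets(ref_time, ref_lat, fit_time, fit_lat, offsets)
best = offsets[np.argmax(scores)]
print(best)
```

What I would like from the NumPyro model is the same kind of comparison, but with a posterior over the offset for every (reference cycle, fitted cycle) pair instead of a single best value.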
Problem
I thought I could use two plates, one for the fitted cycles and one for the reference cycles, but I can’t figure out the broadcasting when using TFP’s Empirical distribution.
Here is my model:
def model(self,
          ref_lat: jnp.array,
          ref_time: jnp.array,
          fit_lat: jnp.array,
          fit_time: jnp.array):

    ref_cyc_n = ref_lat.shape[1]
    fit_cyc_n = fit_lat.shape[1]

    # Broadcast ref and fit variables
    ref_lat_b = jnp.swapaxes(jnp.swapaxes(jnp.broadcast_to(ref_lat, (fit_lat.shape[1], ref_lat.shape[0], ref_cyc_n)), 0, 1), 1, 2)
    ref_time_b = jnp.swapaxes(jnp.swapaxes(jnp.broadcast_to(ref_time, (fit_lat.shape[1], ref_lat.shape[0], ref_cyc_n)), 0, 1), 1, 2)
    ref_data = jnp.concatenate([ref_time_b[:, None, :, :], ref_lat_b[:, None, :, :]], axis=1)

    fit_lat_b = jnp.swapaxes(jnp.broadcast_to(fit_lat, (ref_cyc_n, fit_lat.shape[0], fit_cyc_n)), 0, 1)
    fit_time_b = jnp.swapaxes(jnp.broadcast_to(fit_time, (ref_cyc_n, fit_lat.shape[0], fit_cyc_n)), 0, 1)

    # Define plates
    fit_plate = numpyro.plate("fit_cyc", fit_cyc_n, dim=-1)
    ref_plate = numpyro.plate("ref_cyc", ref_cyc_n, dim=-2)

    with fit_plate:
        with ref_plate:
            offst = numpyro.sample('offst', dist.Normal(0., self.offst_sig))
            offst_time = fit_time_b + offst
            print(offst_time.shape)
            fit_data = jnp.concatenate([offst_time[:, None, :, :], fit_lat_b[:, None, :, :]], axis=1)
            print(fit_data.shape)
            numpyro.sample('obs', tfd.Empirical(ref_data, event_ndims=3), obs=fit_data)
I don’t know if it is the right thing to do, but I start by broadcasting the fitted cycle data and the reference cycle data so that they have the following shape:
[Number of latitude-time pairs, number of reference cycles, number of fitted cycles]
But since the empirical distribution has to be two-dimensional (time, latitude), I thought I needed to concatenate time and latitude, as in the line:
fit_data = jnp.concatenate([offst_time[:,None,:,:], fit_lat_b[:,None,:,:]], axis=1)
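To double-check that the shapes are what I think they are, here is a small standalone sketch of the same broadcasting and concatenation with plain NumPy and tiny made-up sizes (6 points instead of 9680). I am assuming the inputs have shape [number of points, number of cycles], which is how the model reads `ref_lat.shape` and `fit_lat.shape`. The indexing with `None` is just a shorter way of writing the `broadcast_to` plus `swapaxes` chain, and as far as I know the same calls exist in `jax.numpy`.

```python
import numpy as np

n_pts, n_ref, n_fit = 6, 4, 5  # small stand-ins for 9680 points, 4 and 5 cycles
rng = np.random.default_rng(0)
ref_time = rng.random((n_pts, n_ref))
ref_lat = rng.random((n_pts, n_ref))
fit_time = rng.random((n_pts, n_fit))
fit_lat = rng.random((n_pts, n_fit))

target = (n_pts, n_ref, n_fit)

# Reference data varies along axis 1: add a trailing axis, then broadcast.
ref_time_b = np.broadcast_to(ref_time[:, :, None], target)
ref_lat_b = np.broadcast_to(ref_lat[:, :, None], target)

# Fitted data varies along axis 2: add a middle axis, then broadcast.
fit_time_b = np.broadcast_to(fit_time[:, None, :], target)
fit_lat_b = np.broadcast_to(fit_lat[:, None, :], target)

# One offset per (reference cycle, fitted cycle) pair, like the two plates give.
# Zeros are used here only as a placeholder for the sampled values.
offst = np.zeros((n_ref, n_fit))
offst_time = fit_time_b + offst  # (n_ref, n_fit) lines up with the last two axes

# Stack time and latitude on a new axis 1, same as concatenating [:, None, :, :].
ref_data = np.stack([ref_time_b, ref_lat_b], axis=1)
fit_data = np.stack([offst_time, fit_lat_b], axis=1)

print(offst_time.shape)
print(ref_data.shape, fit_data.shape)
```

So the arrays themselves come out with matching shapes, [points, 2, reference cycles, fitted cycles], which makes me think the mismatch is between these arrays and the plates or the distribution rather than between the two arrays.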
The problem is that when I initialize the model and try to plot it, I get the following error:
ValueError: Incompatible shapes for broadcasting: shapes=[(4, 5), (9680,)]
Here I was trying to fit 5 cycles against 4 reference cycles and each cycle has 9680 points. Any help is super-appreciated!!
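For reference, the two shapes in the error message can be checked in isolation. This is only a sketch of the broadcasting rule with NumPy; I have not traced where inside NumPyro or TFP the error is actually raised, and reading (4, 5) as the plate shape and (9680,) as the number of points is my assumption. The helper `can_broadcast` is made up for this example.

```python
import numpy as np

plate_shape = (4, 5)    # (reference cycles, fitted cycles), assumed to come from the plates
points_shape = (9680,)  # one entry per latitude-time pair

def can_broadcast(*shapes):
    """Return True if the shapes are compatible under NumPy broadcasting."""
    try:
        np.broadcast_shapes(*shapes)
        return True
    except ValueError:
        return False

# Trailing axes are compared first: 5 vs 9680 cannot be reconciled.
print(can_broadcast(plate_shape, points_shape))

# With the points on a leading axis of their own, the shapes are compatible.
print(can_broadcast(plate_shape, (9680, 1, 1)))
print(np.broadcast_shapes(plate_shape, (9680, 1, 1)))
```

So it looks like the points axis is ending up where one of the plate dimensions is expected, but I don't know how to tell the Empirical distribution which axes are which.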