Missing a plate statement for a batch dimension on my sample dimension?

Hello friends, I’m starting out with NumPyro and I’m getting a warning I’m not sure how to fix. I have been working on understanding my sample, batch, and event shapes, but the warning still shows up when I train my model. When I trace the model, this is what I get:

   Trace Shapes:             
    Param Sites:             
   Sample Sites:             
fit_cyc_pl plate          1 |
    hem_pl plate          2 |
ref_cyc_pl plate          6 |
   threshld dist          1 |
           value          1 |
      offst dist      6 2 1 |
           value      6 2 1 |
        obs dist      6 2 1 |
           value 9853 6 2 1 |

Note that 9853 is the number of observations I’m fitting. The other dimensions come from the plates I define: 1 cycle that I’m fitting, 2 hemispheres in that cycle, and 6 reference cycles I’m comparing it against. I was assuming that NumPyro would understand that 9853 is the sample shape, but when I train my model I get this warning:

UserWarning: Missing a plate statement for batch dimension -4 at site 'obs'.

Am I wrong to think that the leftmost dimension is my sample dimension and that it should not be visible in the trace?

Thank you in advance if you can help me.

Here is my model, in case it helps:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def plate_model_autobroadcast(self,
            ref_lat: jnp.ndarray,
            ref_time: jnp.ndarray,
            ref_area: jnp.ndarray,
            fit_lat: jnp.ndarray,
            fit_time: jnp.ndarray):

    ref_cyc_n = ref_lat.shape[-1]
    fit_cyc_n = fit_lat.shape[-1]

    # Broadcast ref and fit variables
    ref_lat_b = jnp.expand_dims(jnp.expand_dims(ref_lat, 2), 3)
    ref_time_b = jnp.expand_dims(jnp.expand_dims(ref_time, 2), 3)
    ref_area_b = jnp.expand_dims(jnp.expand_dims(ref_area, 2), 3)

    fit_lat_b = jnp.expand_dims(fit_lat, 1)
    fit_time_b = jnp.expand_dims(fit_time, 1)

    # Define plates
    fit_plate = numpyro.plate("fit_cyc_pl", fit_cyc_n, dim=-1)
    hem_plate = numpyro.plate("hem_pl", 2, dim=-2)
    ref_plate = numpyro.plate("ref_cyc_pl", ref_cyc_n, dim=-3)

    with fit_plate:
        threshold = numpyro.sample('threshld', dist.TruncatedNormal(0., self.thres_sig, low=0))

        # Calculate empirical distribution
        ref_lat_tr = jnp.where(ref_area_b<threshold, -999, ref_lat_b)
        ref_time_tr = jnp.where(ref_area_b<threshold, -999, ref_time_b)

        ref_data = self.time_lat_to_categorical(ref_time_tr, ref_lat_tr)
        ref_prob = jnp.moveaxis(vectorized_hist(jnp.moveaxis(ref_data, [0,1,2,3], [3,0,1,2]), self.ctgrcl_edges), [0,1,2,3], [1,2,3,0])
        ref_prob = ref_prob.at[0,:,:].set(0) + 1e-17
        ref_prob = ref_prob/jnp.broadcast_to(ref_prob.sum(axis=0), ref_prob.shape)

        with hem_plate:
            with ref_plate:
                offst = numpyro.sample('offst', dist.Normal(0., self.offst_sig))
                offst_time = fit_time_b + offst
                fit_data = self.time_lat_to_categorical(offst_time, fit_lat_b)

                numpyro.sample('obs', dist.Categorical(probs=jnp.moveaxis(ref_prob, [0,1,2,3], [3, 0, 1, 2])), obs=fit_data)

Regarding "it should not be visible in the trace?": in a Pyro model, we treat the sample dimension as a batch dimension, so you will need to declare a plate for it.

Thank you @fehiepsi, and thank you so much for all the development work you do. I feel quite honored that you are helping me.

I think I solved the problem by better understanding the dimensions and how they relate to my problem. In the end, I expanded my empirical categorical distribution and added .to_event(1) so that the orphan dimension (the observation axis that had no plate) is treated as an event dimension. I’m still not sure I’m doing everything correctly, but I have definitely learned a lot. Here is my new trace:

  Trace Shapes:                
    Param Sites:                
   Sample Sites:                
fit_cyc_pl plate       1 |      
    hem_pl plate       2 |      
ref_cyc_pl plate      34 |      
   threshld dist       1 |      
           value       1 |      
      offst dist 34 2  1 |      
           value 34 2  1 |      
        obs dist 34 2  1 | 13636
           value 34 2  1 | 13636

I no longer get the warning. Here is my modified model:

def obs_plate_model(self,
            ref_lat: jnp.ndarray,
            ref_time: jnp.ndarray,
            ref_area: jnp.ndarray,
            fit_lat: jnp.ndarray,
            fit_time: jnp.ndarray):

    ref_cyc_n = ref_lat.shape[0]
    fit_cyc_n = fit_lat.shape[1]

    ref_lat_b = jnp.expand_dims(jnp.expand_dims(ref_lat, 1), 1)
    ref_time_b = jnp.expand_dims(jnp.expand_dims(ref_time, 1), 1)
    ref_area_b = jnp.expand_dims(jnp.expand_dims(ref_area, 1), 1)

    fit_lat_b = jnp.expand_dims(fit_lat, 0)
    fit_time_b = jnp.expand_dims(fit_time, 0)

    # Define plates
    fit_plate = numpyro.plate("fit_cyc_pl", fit_cyc_n, dim=-1)
    hem_plate = numpyro.plate("hem_pl", 2, dim=-2)
    ref_plate = numpyro.plate("ref_cyc_pl", ref_cyc_n, dim=-3)

    with fit_plate:
        threshold = numpyro.sample('threshld', dist.TruncatedNormal(0., self.thres_sig, low=0))

        # Calculate empirical distribution
        ref_lat_tr = jnp.where(ref_area_b<jnp.expand_dims(threshold,1), -999, ref_lat_b)
        ref_time_tr = jnp.where(ref_area_b<jnp.expand_dims(threshold,1), -999, ref_time_b)

        ref_data = self.time_lat_to_categorical(ref_time_tr, ref_lat_tr)
        ref_prob = vectorized_hist(ref_data, self.ctgrcl_edges)
        ref_prob = ref_prob.at[:,:,:,0].set(0) + 1e-20
        # Normalize over the category axis
        ref_prob = ref_prob / ref_prob.sum(axis=-1, keepdims=True)

        with hem_plate:
            with ref_plate:

                offst = numpyro.sample('offst', dist.Normal(0., self.offst_sig))
                offst_time = fit_time_b + jnp.expand_dims(offst,3)
                fit_data = self.time_lat_to_categorical(offst_time, fit_lat_b)

                # Observation axis is last: expand it as a batch dim, then move it into the event shape
                dist_tmp = (
                    dist.Categorical(probs=jnp.expand_dims(ref_prob, 3))
                    .expand((ref_cyc_n, 1, fit_cyc_n, fit_data.shape[-1]))
                    .to_event(1)
                )

                numpyro.sample('obs', dist_tmp, obs=fit_data)
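One detail in the modified model is jnp.expand_dims(threshold, 1). Assuming the layout suggested by the expand_dims calls at the top of the model (observation axis last, fitted cycles at axis -2), a per-fit-cycle threshold of shape (fit_cyc_n,) would otherwise be aligned with the observation axis. A minimal sketch with plain jax.numpy and hypothetical sizes:

```python
import jax.numpy as jnp

# Hypothetical sizes, chosen only for illustration
ref_cyc_n, fit_cyc_n, n_ref = 3, 2, 7

# Reference data laid out as (ref_cyc, 1, 1, data), mirroring ref_area_b
ref_area_b = jnp.ones((ref_cyc_n, 1, 1, n_ref))

# One threshold per fitted cycle, shape (fit_cyc_n,)
threshold = jnp.array([0.5, 2.0])

# A direct comparison aligns threshold with the last (data) axis: 2 vs 7 fails
try:
    _ = ref_area_b < threshold
    broadcast_failed = False
except (TypeError, ValueError):
    broadcast_failed = True

# A trailing axis gives shape (2, 1), so the threshold varies along axis -2
mask = ref_area_b < jnp.expand_dims(threshold, 1)
print(broadcast_failed, mask.shape)
```

With fit_cyc_n = 1, as in the trace above, the unexpanded comparison happens to broadcast anyway, so the extra axis only matters once more than one cycle is fitted.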