Mini-batching a particular (many-image) problem structure?

I am an astronomer trying to measure the shapes of stars in near-infrared imaging. I have about 3.5 million images with 50–600 stars in each, so runtime is important to me. : )

I have cutouts around well-behaved stars, so that my data are a 3d array with shape (num_stars, num_xpix, num_ypix), and the stars are always in or very close to the central pixel. As well as fitting the precise location, flux, and sky/background level for each star, i want to fit a set of 4 parameters that describe the observed shape of all of these stars.

I’m using SVI to do the actual optimisation/regression. This is quite fast, but still not fast enough for me to finish the whole job within my HPC allocation. So i’m trying to get mini-batching working to speed convergence.

Here is a simplified version of my model:

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def simple_moffat2d_fit( images, errors ):
    # NB. images/errors have shape [num_stars, num_xpix, num_ypix]
    # i.e. many images/errors all with the same shape.
    
    # define data geometry (fixed)
    num_images = images.shape[0]
    xx, yy = jnp.indices( images.shape[1:] )
    # define central pixel coordinate (used later)
    x0, y0 = images.shape[1]/2., images.shape[2]/2.

    # define PSF parameters (common to all stars)
    b = jnp.exp( numpyro.sample( 'lnb', dist.Normal( 0.7, 0.2 ) ) )
    # the b parameter controls the shape/concentration of the stars
    sx = jnp.exp( numpyro.sample( 'lnsigmax', dist.Normal( 1.0, 1.5 ) ) )  
    sy = jnp.exp( numpyro.sample( 'lnsigmay', dist.Normal( 1.0, 1.5 ) ) )
    # the sx/sy parameters control the x/y size of the stars
    rho = numpyro.sample( 'rho', dist.Normal( 0., 0.2 ) )
    # the rho parameter controls the ellipticity/orientation of the stars
    # (sx/y and rho are defined analogously to the multivariate normal)

    # this is the important bit -- vectors of parameters for each star
    with numpyro.plate( 'stars', num_images ):
        flux = 10**numpyro.sample( f'logflux', dist.Normal( 3., 3 ) )
        # logarithmic prior on the total flux
        bkg = numpyro.sample( 'skyflux', dist.Normal( 0, 10 ) ) 
        # normal prior on the local sky/background level
        xcen = numpyro.sample( f'xcenter', dist.Normal( x0, 2 ) ) 
        ycen = numpyro.sample( f'ycenter', dist.Normal( y0, 2 ) )
        # tight-ish prior on the x/y positions of each star

    # now compute the parametric models for each star
    # --- i want to move all this under the plate? -->
                                        
    # juggle indexing to get pixel coords relative to each centre 
    # NB. xx has shape [num_xpix, num_ypix] and xcen has shape [num_stars]
    xterm = ( xx[ None, ... ] - xcen[ :, None, None ] ) / sx
    yterm = ( yy[ None, ... ] - ycen[ :, None, None ] ) / sy
    r2 = (xterm*xterm + yterm*yterm -2.*rho*xterm*yterm) / (1. - rho**2.)
    # r2 now has shape [num_stars, num_xpix, num_ypix]

    # the model is defined according to r2 and ~normalised to unity
    norm = (b - 1.) / ( jnp.pi * sx * sy * jnp.sqrt( 1. - rho**2. ) )        
    moffat = norm * ( 1. + r2 )**(-b)

    # at last we have the scene as the sum of sky/background and models
    scene = flux[ :, None, None ] * moffat + bkg[ :, None, None ]
    # NB. scene has shape [num_stars, num_xpix, num_ypix]
    # i.e. same as observed images and errors

    # so finally we can evaluate the loglikelihood
    numpyro.sample( 'obs', dist.Normal( scene, errors ), obs=images ) 

What i understand from the SVI tutorials is that to get mini-batching working properly, i want to have a structure like this:

    # global parameters common to all stars
    b = jnp.exp( numpyro.sample( 'lnb', dist.Normal( 0.7, 0.2 ) ) )
    # etc.

    with numpyro.plate( 'stars', num_images, subsample_size=5 ) as ind:
        # parameters specific to each star
        flux = 10**numpyro.sample( f'logflux', dist.Normal( 3., 3 ) )
        # etc.; these should now have shape [subsample_size,]

        scene = flux[ :, None, None ] * moffat + bkg[ :, None, None ]
        # now this should have shape [subsample_size, num_xpix, num_ypix] ?

        # note use of ind, which i think matches errors/images shape to scene?
        numpyro.sample( 'obs', dist.Normal( scene, errors[ind] ), obs=images[ind] ) 

With print statements, i can see that scene, errors[ind] and images[ind] all have the same shape; i.e. [subsample_size, num_xpix, num_ypix]. But the numpyro.sample line fails with a ValueError, complaining of Incompatible shapes for broadcasting: shapes=[(5,), (5, 25, 25)]. (This is for subsample_size=5 and num_xpix=num_ypix=25.)

So … clearly i’m doing something wrong, but i cannot for the life of me find it??? Can anyone set me straight? You will literally save of order a CPU-decade’s worth of emissions … !!

Thanks!
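
A side note on reading that error message, offered as a sketch rather than a diagnosis: the message lists the shapes that failed to broadcast, and NumPy-style broadcasting (which JAX follows) lines shapes up starting from the trailing axis. The plain-Python helper below, broadcast_shapes_sketch, is a hypothetical stand-in written for illustration and is not numpyro's own code. It shows that (5,) against (5, 25, 25) fails because the 5 gets compared with the last pixel axis, whereas (5, 1, 1) broadcasts without trouble.

```python
from itertools import zip_longest


def broadcast_shapes_sketch(*shapes):
    """Right-aligned broadcasting of shapes; raises ValueError on a mismatch."""
    result = []
    # Walk every shape from its trailing axis backwards, padding short shapes with 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(
                f"Incompatible shapes for broadcasting: shapes={list(shapes)}"
            )
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# The star axis written as (5, 1, 1) lines up with the leading axis of the images.
print(broadcast_shapes_sketch((5, 1, 1), (5, 25, 25)))  # (5, 25, 25)

# A bare (5,) is aligned with the *last* axis (size 25), so it cannot broadcast.
try:
    broadcast_shapes_sketch((5,), (5, 25, 25))
except ValueError as err:
    print(err)
```

Read this way, the error suggests that something in the model expects the star axis to be the rightmost axis of the observation, whereas in scene it is the leftmost.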

it’s useful to pepper everything with shape assert statements. both to make sure you know what’s happening but also to communicate to the reader of this forum who isn’t familiar with your model. so try adding that everywhere and share again?

e.g. assert flux.shape == ind.shape and assert scene.shape == ? and assert moffat.shape == ?

Thanks for the suggestion. Here is another go at trying to simplify/clarify my presentation of the problem:

def simple_moffat2d_fit( images, errors, subsample_size=5 ):
    
    # define data geometry (fixed)
    (num_images, num_xpix, num_ypix) = images.shape
    # i.e. a stack of many image/error arrays all with the same pixel-shape.
    xx, yy = jnp.indices( (num_xpix, num_ypix) )
    x0, y0 = num_xpix/2., num_ypix/2.
    
    # define PSF parameters (common to all stars)
    b = jnp.exp( numpyro.sample( 'lnb', dist.Normal( 0.7, 0.2 ) ) ) 
    sx = jnp.exp( numpyro.sample( 'lnsigmax', dist.Normal( 1.0, 1.5 ) ) )
    sy = jnp.exp( numpyro.sample( 'lnsigmay', dist.Normal( 1.0, 1.5 ) ) )
    rho = numpyro.sample( 'rho', dist.Normal( 0., 0.2 ) )

    with numpyro.plate( 'stars', num_images, subsample_size=subsample_size ) as ind:

        # this is the important bit -- vectors of parameters for each star
        flux = 10**numpyro.sample( f'logflux', dist.Normal( 3., 3 ) )
        bkg = numpyro.sample( 'skyflux', dist.Normal( 0, 10 ) )
        xcen = numpyro.sample( f'xcenter', dist.Normal( x0, 2 ) ) 
        ycen = numpyro.sample( f'ycenter', dist.Normal( y0, 2 ) )
        
        # one value of each parameter for each star                                                                        
        assert flux.shape == bkg.shape == xcen.shape == ycen.shape == (subsample_size,)
    
        # the star profiles are radial; this evaluates elliptical R^2                                     
        xterm = ( xx[ None, ... ] - xcen[ :, None, None ] ) / sx
        yterm = ( yy[ None, ... ] - ycen[ :, None, None ] ) / sy
        r2 = (xterm*xterm + yterm*yterm -2.*rho*xterm*yterm) / (1. - rho**2.)

        # note for fixed pixel geometry, R2 varies per star depending on its x/ycen.
        assert xterm.shape == yterm.shape == r2.shape == (subsample_size, num_xpix, num_ypix)

        # now use the R2 values to compute model for each star
        norm = (b - 1.) / ( jnp.pi * sx * sy * jnp.sqrt( 1. - rho**2. ) )        
        moffat = norm * ( 1. + r2 )**(-b)
        scene = flux[ :, None, None ] * moffat + bkg[ :, None, None ]
        
        # scene has shape directly comparable to data/error arrays
        assert moffat.shape == scene.shape == (subsample_size, num_xpix, num_ypix)
        assert scene.shape == errors[ind].shape == images[ind].shape

        # now evaluate the model    
        numpyro.sample( 'obs', dist.Normal( scene, errors[ind, ...] ), obs=images[ind, ...] ) 

When i try to mcmc.run with this model, i can see (with print statements) that it reaches the numpyro.sample line, but then dies with

TypeError: broadcast_shapes got incompatible shapes for broadcasting: (1, 1, 5), (5, 25, 25).

Tracing this back, i think this shows where things are going wrong:

in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
    700 map_args = (rng_key, init_state, init_params)
    701 if self.num_chains == 1:
--> 702     states_flat, last_state = partial_map_fn(map_args)
    703     states = jax.tree.map(lambda x: x[jnp.newaxis, ...], states_flat)

in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields, remove_sites)
    463 # Check if _sample_fn is None, then we need to initialize the sampler.
    464 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None):
--> 465     new_init_state = self.sampler.init(
    466         rng_key,
    467         self.num_warmup,
    468         init_params,
    469         model_args=args,
    470         model_kwargs=kwargs,
    471     )

in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    678 model_kwargs = {} if model_kwargs is None else model_kwargs
    679 substituted_model = substitute(
    680     seed(model, rng_key if is_prng_key(rng_key) else rng_key[0]),
    681     substitute_fn=init_strategy,
    682 )
    683 (
    684     inv_transforms,
    685     replay_model,
    686     has_enumerate_support,
    687     model_trace,
--> 688 ) = _get_model_transforms(substituted_model, model_args, model_kwargs)

So do i need to do something special to initialise the model/sampler if i then want to run with subsampling?

i thought you were using SVI. mcmc isn’t really compatible with mini-batching.

D’oh!! I had been testing with mcmc thinking to use trace_plots to see how things behaved.

I’ve shifted my testing to SVI, using the following pattern:

import numpyro
from jax.random import PRNGKey
from numpyro.infer import SVI, Trace_ELBO, TraceMeanField_ELBO, autoguide

# define the essential elements of the SVI computation
guide = autoguide.AutoDelta(simple_moffat2d_fit)
optimizer = numpyro.optim.Adam( step_size=optim_step_size)
svi = SVI(simple_moffat2d_fit, guide, optimizer, 
          loss=TraceMeanField_ELBO(),)

# function args and kwargs
args = (images, stderrs)
kwargs = {'subsample_size': 5}

# initialise the SVI with subsampling: this fails! 
rng_key = PRNGKey(0)
svi.init( rng_key, *args, init_params=init_pars, **kwargs )

# then run in blocks of steps_per_iter and monitor convergence
svi_result = svi.run(rng_key, steps_per_iter, *args, **kwargs )
for niter in range(max_num_iters):
    svi_result = svi.run(rng_key, steps_per_iter, *args, **kwargs,
                         init_state=svi_result.state )

With this change, i run into a very similar if not exactly the same error:

ValueError: Incompatible shapes for broadcasting: shapes=[(5,), (5, 25, 25)]

in SVI.init(self, rng_key, init_params, *args, **kwargs)
    182 if init_params is not None:
    183     guide_init = substitute(guide_init, init_params)
--> 184 guide_trace = trace(guide_init).get_trace(*args, **kwargs, **self.static_kwargs)
    185 init_guide_params = {
    186     name: site["value"]
    187     for name, site in guide_trace.items()
    188     if site["type"] == "param"
    189 }

... skipping a few ...

in AutoGuide._setup_prototype(self, *args, **kwargs)
    153 rng_key = numpyro.prng_key()
    154 with handlers.block():
    155     (
    156         init_params,
    157         self._potential_fn_gen,
    158         postprocess_fn_gen,
    159         self.prototype_trace,
--> 160     ) = initialize_model(
    161         rng_key,
    162         self.model,
    163         init_strategy=self.init_loc_fn,
    164         dynamic_args=True,
    165         model_args=args,
    166         model_kwargs=kwargs,
    167     )

in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad)
    678 model_kwargs = {} if model_kwargs is None else model_kwargs
    679 substituted_model = substitute(
    680     seed(model, rng_key if is_prng_key(rng_key) else rng_key[0]),
    681     substitute_fn=init_strategy,
    682 )
    683 (
    684     inv_transforms,
    685     replay_model,
    686     has_enumerate_support,
    687     model_trace,
--> 688 ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
    690 for name, site in model_trace.items():
    691     if (
    692         site["type"] == "sample"
    693         and isinstance(site["fn"], dist.Delta)
    694         and not site["is_observed"]

So … not sure where to go next!

I was just running into this error earlier today lol. The subsample size comes from the plate (5), but your images / errors are indexed over the first dimension (but still have 2 remaining dimensions). In your sample statement, the dist.Normal() is expecting only a size of 5 (due to the subsample size in the plate) in all its parameters, but you’re passing in a (5, 25, 25) value. You should add nested plating so the size of scene matches the size of errors[ind].

for example:

with numpyro.plate( 'stars', num_images, subsample_size=5 ) as ind:
     with numpyro.plate( 'x', 25, dim=-2 ):
          with numpyro.plate( 'y', 25, dim=-1 ):
               ....

That should work?

Thanks, abhi. I’ve tried this, but it doesn’t seem to work. I get an assertion error that gets triggered at the with numpyro.plate( 'y' ) statement, which … i just don’t understand!??

Bearing in mind that it is only the xterm, yterm and lower variables that are three dimensional, i have also tried shifting those two new lines around, but with no luck. If only to try to test my understanding of how things should work, and hopefully also to help make clear the nature of the problem, i have tried to be explicit about the axes, like this:

def simple_moffat2d_fit( images, errors, subsample_size=10 ):
    
    # ...

    with numpyro.plate( 'stars', size=num_images, 
                        subsample_size=subsample_size ) as ind:

        # this is the important bit -- vectors of parameters for each star
        flux = 10**numpyro.sample( f'logflux', dist.Normal( 3., 3 ) )
        bkg = numpyro.sample( 'skyflux', dist.Normal( 0, 10 ) )
        xcen = numpyro.sample( f'xcenter', dist.Normal( x0, 2 ) ) 
        ycen = numpyro.sample( f'ycenter', dist.Normal( y0, 2 ) )
        
        # the star profiles are radial; this evaluates elliptical R^2                                     
        _xterm = ( xx[ None, ... ] - xcen[ :, None, None ] ) / sx                    
        _yterm = ( yy[ None, ... ] - ycen[ :, None, None ] ) / sy

        norm = (b - 1.) / ( jnp.pi * sx * sy * jnp.sqrt( 1. - rho**2. ) )
        with numpyro.plate( 'x', 25, dim=-2 ) as xi:
          with numpyro.plate( 'y', 25, dim=-1 ) as yi:      

            xterm = _xterm[ :, xi, yi ]                    
            yterm = _yterm[ :, xi, yi ]    
            r2 = (xterm*xterm + yterm*yterm -2.*rho*xterm*yterm) / (1. - rho**2.)

            # now use the R2 values to compute model for each star
            moffat = norm * ( 1. + r2 )**(-b)        
            scene = flux * moffat + bkg
            
            # now evaluate the model
            numpyro.sample( 'obs', dist.Normal( scene, errors[ind,xi,yi] ), obs=images[ind,xi,yi] ) 

but i see the same assertion error being thrown at the same with numpyro.plate( 'y' ) line.

Does that give any ideas?

 with numpyro.plate( 'stars', num_images, subsample_size=subsample_size ) as ind:

        # this is the important bit -- vectors of parameters for each star
        flux = 10**numpyro.sample( f'logflux', dist.Normal( 3., 3 ) )
        bkg = numpyro.sample( 'skyflux', dist.Normal( 0, 10 ) )
        xcen = numpyro.sample( f'xcenter', dist.Normal( x0, 2 ) ) 
        ycen = numpyro.sample( f'ycenter', dist.Normal( y0, 2 ) )
        

In the snippet above, i’m understanding that you want, for each star (or image), a 1-D normal vector called flux, a scalar called bkg, and two scalars xcen and ycen. So we want flux to be [N_star], and bkg, xcen, ycen to be [N_star] as well? If so, i’d sample those outside the plate, and then index into them inside the plate when you compute your likelihood.

If i were to do an inventory of parameters, it would look like this:

  • b – scalar – a shape parameter; common to all stars
  • sx – scalar – the size in the x direction; common to all stars
  • sy – scalar – the size in the y direction; common to all stars
  • rho – scalar – the x/y covariance; common to all stars
  • flux – vector (num_stars,) – the flux of each individual star
  • xcen – vector (num_stars,) – the x center of each individual star
  • ycen – vector (num_stars,) – the y center of each individual star
  • skyflux – vector (num_stars,) – the local sky/background around each star.

So the flux, xcen, ycen, and skyflux parameters all should share the same per-star plate. Then, for each star in the plate, i want to create 2d (num_xpix, num_ypix) image to compare back to the data that i have.

Naively, i would think i could call the numpyro.sample in which all of scene, images, and errors have the same size and shape without adding the extra x and y plates. What is the role/value of adding those?

I’m not 100% sure i understand this suggestion, but just to say that what seems to be expensive is the model generation, more so than the log-likelihood evaluation. So i would like to be able to minibatch both the model generation and the evaluation if i can.

My understanding of SVI is that in each iteration we get a noisy approximation of the ELBO by sampling from our approximation to the posterior. If we are doing traditional SVI with no subsampling (i.e. batch_size = num_stars), we must compute the gradient with respect to all num_stars observations. When we subsample, we can choose batch_size << num_stars and still get a decent enough approximation to the gradient, but the cost per iteration (or update to the guide’s parameters) is much cheaper.

I do think that in your case, allocating several arrays with on the order of 3.5 million entries each might be infeasible for memory reasons. In the traditional case, where we want to scale with the number of data points and that number is unrelated to the number of parameters in the model (e.g. linear regression), we wouldn’t have to store all the data in memory, just batch_size data points. In your case the parameters also scale with the number of data points. I think my suggestion “works” in the sense that in each iteration of batch SVI you’re only updating a certain subset of parameters, but the overall issue of having a large number of parameters is not fixed. I’m not sure how best to achieve that, because at some point you will have to store, for example, some (num_stars,)-shaped arrays in memory. I get what you’re saying about minibatching both the model generation and evaluation; maybe you could consider marginalizing out those latent parameters if everything is Gaussian?
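
To make the "decent enough approximation" point concrete, here is a small stdlib-only sketch. The per-star log-likelihood values are made-up numbers, and the names (per_star_loglik, minibatch_estimate) are hypothetical. It checks that a mini-batch sum rescaled by num_stars / batch_size averages out to the full-data sum over all possible mini-batches, which is the sense in which the subsampled estimate is unbiased. As far as i understand it, a numpyro plate with subsample_size set applies this same rescaling to the sites inside it.

```python
from itertools import combinations
from statistics import mean

# Hypothetical per-star log-likelihood terms (made-up numbers).
per_star_loglik = [-1.2, -0.7, -3.1, -0.4, -2.2, -1.6]
num_stars = len(per_star_loglik)
batch_size = 2

# Full-data quantity that the mini-batch estimate is standing in for.
full_sum = sum(per_star_loglik)


def minibatch_estimate(batch):
    """Mini-batch sum rescaled by num_stars / batch_size."""
    return num_stars / batch_size * sum(per_star_loglik[i] for i in batch)


# Enumerate every possible mini-batch of the chosen size and average the estimates.
estimates = [
    minibatch_estimate(batch)
    for batch in combinations(range(num_stars), batch_size)
]
print(full_sum, mean(estimates))
```

Each individual estimate is noisy, but the average matches the full sum, so the cheaper per-iteration cost is paid for in variance rather than bias.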

Sorry, my suggestion there was unclear, it should be more like:

 with numpyro.plate( 'stars', num_images, subsample_size=subsample_size ) as ind:

        # this is the important bit -- vectors of parameters for each star
        flux = 10**numpyro.sample( f'logflux', dist.Normal( 3., 3 ) )
        bkg = numpyro.sample( 'skyflux', dist.Normal( 0, 10 ) )
        xcen = numpyro.sample( f'xcenter', dist.Normal( x0, 2 ) ) 
        ycen = numpyro.sample( f'ycenter', dist.Normal( y0, 2 ) )
        with numpyro.plate( 'x', 25, dim=-2 ):
            with numpyro.plate( 'y', 25, dim=-1 ):
                ...
                numpyro.sample( 'obs', dist.Normal( scene, errors[ind, ...] ), obs=images[ind, ...] )

The role of this is to match the dimensions in the lowermost sample call, because within the plate this “obs” variable will now expect (5, 25, 25) instead of (5,) previously.