I am an astronomer trying to measure the shapes of stars in near-infrared imaging. I have about 3.5 million images with 50–600 stars in each, so runtime is important to me. : )
I have cutouts around well-behaved stars, so my data are a 3D array with shape (num_stars, num_xpix, num_ypix), and the stars are always in or very close to the central pixel. As well as fitting the precise location, flux, and sky/background level for each star, I want to fit a set of 4 parameters that describe the observed shape of all of these stars.
I’m using SVI to do the actual optimisation/regression. This is quite fast, but still not fast enough for me to finish the whole job within my HPC allocation. So I’m trying to get mini-batching working to speed up convergence.
Here is a simplified version of my model:
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def simple_moffat2d_fit( images, errors ):
    # NB. images/errors have shape [num_stars, num_xpix, num_ypix]
    # i.e. many images/errors all with the same shape.
    # define data geometry (fixed)
    num_images = images.shape[0]
    xx, yy = jnp.indices( images.shape[1:] )
    # define central pixel coordinate (used later)
    x0, y0 = images.shape[1]/2., images.shape[2]/2.
    # define PSF parameters (common to all stars)
    b = jnp.exp( numpyro.sample( 'lnb', dist.Normal( 0.7, 0.2 ) ) )
    # the b parameter controls the shape/concentration of the stars
    sx = jnp.exp( numpyro.sample( 'lnsigmax', dist.Normal( 1.0, 1.5 ) ) )
    sy = jnp.exp( numpyro.sample( 'lnsigmay', dist.Normal( 1.0, 1.5 ) ) )
    # the sx/sy parameters control the x/y size of the stars
    rho = numpyro.sample( 'rho', dist.Normal( 0., 0.2 ) )
    # the rho parameter controls the ellipticity/orientation of the stars
    # (sx/y and rho are defined analogously to the multivariate normal)
    # this is the important bit -- vectors of parameters for each star
    with numpyro.plate( 'stars', num_images ):
        flux = 10**numpyro.sample( 'logflux', dist.Normal( 3., 3 ) )
        # logarithmic prior on the total flux
        bkg = numpyro.sample( 'skyflux', dist.Normal( 0, 10 ) )
        # normal prior on the local sky/background level
        xcen = numpyro.sample( 'xcenter', dist.Normal( x0, 2 ) )
        ycen = numpyro.sample( 'ycenter', dist.Normal( y0, 2 ) )
        # tight-ish prior on the x/y positions of each star
    # now compute the parametric models for each star
    # --- i want to move all this under the plate? -->
    # juggle indexing to get pixel coords relative to each centre
    # NB. xx has shape [num_xpix, num_ypix] and xcen has shape [num_stars]
    # (x offsets are scaled by sx, y offsets by sy)
    xterm = ( xx[ None, ... ] - xcen[ :, None, None ] ) / sx
    yterm = ( yy[ None, ... ] - ycen[ :, None, None ] ) / sy
    r2 = (xterm*xterm + yterm*yterm -2.*rho*xterm*yterm) / (1. - rho**2.)
    # r2 now has shape [num_stars, num_xpix, num_ypix]
    # the model is defined according to r2 and ~normalised to unity
    norm = (b - 1.) / ( jnp.pi * sx * sy * jnp.sqrt( 1. - rho**2. ) )
    moffat = norm * ( 1. + r2 )**(-b)
    # at last we have the scene as the sum of sky/background and models
    scene = flux[ :, None, None ] * moffat + bkg[ :, None, None ]
    # NB. scene has shape [num_stars, num_xpix, num_ypix]
    # i.e. same as observed images and errors
    # so finally we can evaluate the loglikelihood
    numpyro.sample( 'obs', dist.Normal( scene, errors ), obs=images )
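As a sanity check on the "~normalised to unity" comment, here is a minimal NumPy-only sketch of the same profile, with x offsets scaled by sx and y offsets by sy. The helper name `moffat2d_numpy`, the centres, and the PSF parameter values are all made up for illustration; they are not taken from the real data or fit.

```python
import numpy as np

def moffat2d_numpy(shape, xcen, ycen, b, sx, sy, rho):
    """Evaluate unit-flux elliptical Moffat profiles on a pixel grid.

    Hypothetical helper mirroring the model above: xcen and ycen have
    shape [num_stars]; the result has shape [num_stars, nx, ny].
    """
    xx, yy = np.indices(shape)
    # offsets from each star's centre, scaled by the matching width
    xterm = (xx[None, ...] - xcen[:, None, None]) / sx
    yterm = (yy[None, ...] - ycen[:, None, None]) / sy
    r2 = (xterm**2 + yterm**2 - 2.0 * rho * xterm * yterm) / (1.0 - rho**2)
    norm = (b - 1.0) / (np.pi * sx * sy * np.sqrt(1.0 - rho**2))
    return norm * (1.0 + r2) ** (-b)

# made-up centres and PSF parameters, for illustration only
xcen = np.array([12.5, 12.0, 13.1])
ycen = np.array([12.5, 12.4, 11.8])
profiles = moffat2d_numpy((25, 25), xcen, ycen, b=2.5, sx=2.0, sy=2.0, rho=0.2)

# fraction of each star's unit flux that lands on its 25x25 cutout
enclosed = profiles.sum(axis=(1, 2))
print(profiles.shape, enclosed)
```

For these values the enclosed fraction comes out slightly below one, as expected for a profile with extended wings on a finite cutout.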
What I understand from the SVI tutorials is that to get mini-batching working properly, I want to have a structure like this:
# global parameters common to all stars
b = jnp.exp( numpyro.sample( 'lnb', dist.Normal( 0.7, 0.2 ) ) )
# etc.
with numpyro.plate( 'stars', num_images, subsample_size=5 ) as ind:
    # parameters specific to each star
    flux = 10**numpyro.sample( 'logflux', dist.Normal( 3., 3 ) )
    # etc.; these should now have shape [subsample_size,]
    scene = flux[ :, None, None ] * moffat + bkg[ :, None, None ]
    # now this should have shape [subsample_size, num_xpix, num_ypix] ?
    # note use of ind, which i think matches errors/images shape to scene?
    numpyro.sample( 'obs', dist.Normal( scene, errors[ind] ), obs=images[ind] )
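To make the intended indexing concrete, here is a plain NumPy sketch of what one mini-batch looks like. It assumes, as I understand the subsampling plate to work, that the batch log-likelihood is rescaled by num_images / subsample_size so that it estimates the full-data total. The arrays are random stand-ins, not real cutouts.

```python
import numpy as np

rng = np.random.default_rng(0)
num_stars, num_pix, subsample_size = 200, 25, 5

# random stand-ins for the cutouts and for each star's log-likelihood term
images = rng.normal(size=(num_stars, num_pix, num_pix))
per_star_loglike = rng.normal(size=num_stars)

# one mini-batch: indices drawn without replacement, playing the role of `ind`
ind = rng.choice(num_stars, size=subsample_size, replace=False)
batch = images[ind]  # shape [subsample_size, num_xpix, num_ypix]

# rescale the batch total so that it estimates the full-data total
estimate = per_star_loglike[ind].sum() * num_stars / subsample_size
full = per_star_loglike.sum()
print(batch.shape, estimate, full)
```

Any single estimate is noisy, but using every star once (a batch of size num_stars) reproduces the full sum exactly.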
With print statements, I can see that scene, errors[ind] and images[ind] all have the same shape, i.e. [subsample_size, num_xpix, num_ypix]. But the numpyro.sample line fails with a ValueError, complaining of Incompatible shapes for broadcasting: shapes=[(5,), (5, 25, 25)]. (This is for subsample_size=5 and num_xpix=num_ypix=25.)
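The pair of shapes in that message can be reproduced with plain NumPy broadcasting, which aligns arrays from the rightmost axis. The sketch below only illustrates the shape rule; the arrays are placeholders. One reading, which is a guess rather than a confirmed diagnosis, is that the plate sits on the rightmost batch dimension by default, so its length-5 axis gets compared with the last pixel axis (25) instead of with the star axis.

```python
import numpy as np

per_star = np.ones(5)          # one value per star in the mini-batch
scene = np.ones((5, 25, 25))   # one cutout per star

# Broadcasting compares trailing axes first: 5 against 25, which fails.
try:
    np.broadcast_shapes(per_star.shape, scene.shape)
    broadcast_ok = True
except ValueError:
    broadcast_ok = False

# With two trailing length-1 axes, the star axis lines up with axis -3.
aligned_shape = np.broadcast_shapes(per_star[:, None, None].shape, scene.shape)
print(broadcast_ok, aligned_shape)
```

If that reading is right, two things that might be worth trying are passing dim=-3 to numpyro.plate (so that per-star samples come out with shape [subsample_size, 1, 1]), or calling .to_event(2) on the Normal in the 'obs' site so that the two pixel axes are treated as event dimensions.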
So … clearly I’m doing something wrong, but I cannot for the life of me find it??? Can anyone set me straight? You will literally save on the order of a CPU-decade’s worth of emissions … !!
Thanks!