Quick Start

This guide builds a complete galaxy image simulation, then applies the JAX transformations jit, grad, and vmap to it.

A Simple Simulation

A Gaussian galaxy is convolved with a Gaussian PSF, drawn to an image, and given Gaussian noise. This is equivalent to GalSim’s demo1.py.

import os

import jax_galsim

# Galaxy parameters
gal_flux = 1e5      # total counts
gal_sigma = 2.0     # arcsec
psf_sigma = 1.0     # arcsec
pixel_scale = 0.2   # arcsec/pixel
noise_sigma = 30.0  # counts per pixel

# Define profiles
gal = jax_galsim.Gaussian(flux=gal_flux, sigma=gal_sigma)
psf = jax_galsim.Gaussian(flux=1.0, sigma=psf_sigma)

# Convolve galaxy with PSF
final = jax_galsim.Convolve([gal, psf])

# Draw the image
image = final.drawImage(scale=pixel_scale)

# Add Gaussian noise
image.addNoise(jax_galsim.GaussianNoise(sigma=noise_sigma))

# Write to FITS, creating the output directory if it does not exist yet
os.makedirs("output", exist_ok=True)
image.write("output/demo1.fits")

Most GalSim code translates directly by replacing import galsim with import jax_galsim.

JIT Compilation

Wrap your simulation in jax.jit to compile it into an optimised XLA computation:

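import jax

# slen and fft_size fix array shapes, so they are passed as static arguments.
# On JAX versions where jax.jit requires the function as its first argument,
# use functools.partial(jax.jit, static_argnames=[...]) as the decorator instead.
@jax.jit(static_argnames=['slen', 'fft_size'])
def simulate(flux, sigma, *, slen, fft_size):
    gsparams = jax_galsim.GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size)
    gal = jax_galsim.Gaussian(flux=flux, sigma=sigma)
    psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0)
    # Attach the fixed FFT size to the convolved profile so it takes effect
    final = jax_galsim.Convolve([gal, psf]).withGSParams(gsparams)
    return final.drawImage(nx=slen, ny=slen, scale=0.2)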

# First call compiles; subsequent calls are fast
image = simulate(1e5, 2.0, slen=21, fft_size=128)

Note

Arguments that determine array shapes or Python control flow (such as the image size) must be listed in static_argnames for JIT to work. JAX recompiles the function for each new value of a static argument.

Alternatively, bind the static arguments with functools.partial and jit the resulting function:

from jax import jit
from functools import partial

def simulate(flux, sigma, *, slen, fft_size):
    gsparams = jax_galsim.GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size)
    gal = jax_galsim.Gaussian(flux=flux, sigma=sigma)
    psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0)
    # Attach the fixed FFT size to the convolved profile so it takes effect
    final = jax_galsim.Convolve([gal, psf]).withGSParams(gsparams)
    return final.drawImage(nx=slen, ny=slen, scale=0.2)

# slen and fft_size are bound as plain Python values before tracing
simulate_jitted = jit(partial(simulate, slen=21, fft_size=128))
image = simulate_jitted(1e5, 2.0)
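
The static-argument rule from the note above can be checked with plain JAX, without jax_galsim. The sketch below is illustrative only: pixel_centres and trace_count are hypothetical names introduced here, and the function merely builds a coordinate axis whose length depends on slen.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0  # incremented only while JAX traces the function


@partial(jax.jit, static_argnames=("slen",))
def pixel_centres(scale, *, slen):
    """Hypothetical helper: pixel-centre coordinates for an axis of slen pixels."""
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    # jnp.arange needs a concrete Python int, so slen cannot be a traced value
    return (jnp.arange(slen) - (slen - 1) / 2) * scale


a = pixel_centres(0.2, slen=21)  # first call: traces and compiles
b = pixel_centres(0.3, slen=21)  # same static value: reuses the compiled function
c = pixel_centres(0.2, slen=32)  # new static value: JAX traces and compiles again
print(a.shape, c.shape, trace_count)
```

Changing scale reuses the compiled function, while changing slen triggers a new trace, so it pays to keep the set of distinct static values small.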

Automatic Differentiation

Compute gradients of any scalar output with respect to parameters:

def total_flux(gal_sigma, psf_sigma):
    gal = jax_galsim.Gaussian(flux=1e5, sigma=gal_sigma)
    psf = jax_galsim.Gaussian(flux=1.0, sigma=psf_sigma)
    final = jax_galsim.Convolve([gal, psf])
    image = final.drawImage(scale=0.2)
    return image.array.sum()

# Gradient of total image flux with respect to both sigmas
grad_fn = jax.grad(total_flux, argnums=(0, 1))
d_gal, d_psf = grad_fn(2.0, 1.0)

This is useful for fitting galaxy models to data with gradient-based optimisation.
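
As a rough illustration of why gradients help with fitting, the sketch below uses a pure-JAX stand-in for a drawn image rather than jax_galsim itself. The render and loss functions are assumptions made for this example: render samples a circular Gaussian on a pixel grid, and loss is the summed squared residual against a target rendered with a known size.

```python
import jax
import jax.numpy as jnp


def render(sigma, *, flux=1.0, slen=64, scale=0.2):
    """Stand-in model (not jax_galsim): pixel-sampled circular Gaussian."""
    coords = (jnp.arange(slen) - (slen - 1) / 2) * scale
    x, y = jnp.meshgrid(coords, coords)
    profile = jnp.exp(-0.5 * (x**2 + y**2) / sigma**2) / (2 * jnp.pi * sigma**2)
    return flux * scale**2 * profile


true_sigma = 2.0
target = render(jnp.asarray(true_sigma))  # noiseless "data" with a known size


def loss(sigma):
    # Summed squared residual between the model image and the target
    return jnp.sum((render(sigma) - target) ** 2)


dloss = jax.grad(loss)
g_low = dloss(1.5)   # model too small: negative slope, so increase sigma
g_high = dloss(2.5)  # model too large: positive slope, so decrease sigma
g_true = dloss(2.0)  # at the true size the residual vanishes, so the slope does too
print(g_low, g_high, g_true)
```

The sign of the gradient points toward the true size from either side, which is the information a gradient-based optimiser uses to update the parameter.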

Vectorization with vmap

Batch-simulate galaxies with different parameters without explicit loops:

import jax.numpy as jnp

sigmas = jnp.linspace(1.0, 4.0, 10)

@jax.jit
@jax.vmap
def batch_simulate(sigma):
    gsparams = jax_galsim.GSParams(minimum_fft_size=128, maximum_fft_size=128)
    gal = jax_galsim.Gaussian(flux=1e5, sigma=sigma)
    psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0)
    final = jax_galsim.Convolve([gal, psf]).withGSParams(gsparams)
    return final.drawImage(scale=0.2, nx=64, ny=64).array

# Simulate all 10 galaxies in parallel
images = batch_simulate(sigmas)  # shape: (10, 64, 64)
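
When only some parameters vary across the batch, the in_axes argument of jax.vmap selects which ones are mapped. The following is a minimal sketch with a pure-JAX stand-in renderer (render is a hypothetical function written for this example, not part of jax_galsim); it shares one flux across the batch and compares the batched result with an explicit loop.

```python
import jax
import jax.numpy as jnp


def render(flux, sigma, slen=32, scale=0.2):
    """Stand-in model (not jax_galsim): pixel-sampled circular Gaussian."""
    coords = (jnp.arange(slen) - (slen - 1) / 2) * scale
    x, y = jnp.meshgrid(coords, coords)
    profile = jnp.exp(-0.5 * (x**2 + y**2) / sigma**2) / (2 * jnp.pi * sigma**2)
    return flux * scale**2 * profile


sigmas = jnp.linspace(1.0, 4.0, 10)

# in_axes=(None, 0): flux is shared by every galaxy, sigma is mapped over axis 0
batch_render = jax.jit(jax.vmap(render, in_axes=(None, 0)))
batched = batch_render(1e5, sigmas)

# Reference result from an explicit Python loop
looped = jnp.stack([render(1e5, s) for s in sigmas])
print(batched.shape, bool(jnp.allclose(batched, looped, rtol=1e-4)))
```

The batched call gives the same images as the loop while tracing the function only once.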

Next Steps