JAX-GalSim
Warning
This project is still in an early development phase. Please use the reference GalSim implementation for any scientific applications.
JAX-GalSim is a JAX re-implementation of the GalSim galaxy image simulation toolkit. It exposes (nearly) the same API as GalSim while enabling automatic differentiation, JIT compilation, and hardware acceleration via JAX.
Why JAX-GalSim?
Compile simulation pipelines with jax.jit for significant speedups, especially on GPU.
Compute gradients of simulation outputs with respect to galaxy parameters using jax.grad.
Batch simulations over parameter grids with jax.vmap, with no explicit loops needed.
Quick Install
pip install jax-galsim
conda install -c conda-forge jax-galsim
See Installation for GPU support and development setup.
Minimal Example
import jax
import jax_galsim
# Define a galaxy and PSF
gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0)
psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0)
# Convolve and draw
final = jax_galsim.Convolve([gal, psf])
image = final.drawImage(scale=0.2)
JAX-GalSim objects are JAX pytrees, so you can JIT-compile and differentiate the entire pipeline:
import functools

@functools.partial(jax.jit, static_argnames=["slen", "fft_size"])
def simulate(flux, sigma, *, slen=21, fft_size=128):
    # Equal minimum/maximum FFT sizes keep array shapes fixed while tracing
    gsparams = jax_galsim.GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size)
    gal = jax_galsim.Gaussian(flux=flux, sigma=sigma)
    psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0)
    final = jax_galsim.Convolve([gal, psf]).withGSParams(gsparams)
    # Reduce the slen x slen stamp to a scalar so that jax.grad can be applied
    return final.drawImage(nx=slen, ny=slen, scale=0.2).array.sum()

# Compute gradients with respect to galaxy parameters
dflux, dsigma = jax.grad(simulate, argnums=(0, 1))(1e5, 2.0)
Getting Started
Auto-generated documentation for every public class, function, and module in jax_galsim.
The original GalSim documentation. Many docstrings in JAX-GalSim are derived from GalSim and expanded with JAX-specific notes.
Walk through a complete simulation with JIT, grad, and vmap.
What changes when GalSim runs on JAX: immutability, tracing, pytrees, and more.
About the Documentation
Each class and function that mirrors an upstream GalSim object is annotated with jax_galsim.core.utils.implements(). This decorator copies the original GalSim docstring and prepends any JAX-specific caveats. In the API Reference you will therefore find:
A summary and an optional 🔪 JAX-GalSim - The Sharp Bits 🔪 block at the top of each entry, highlighting important caveats.
An explicit Parameters table derived from the original GalSim documentation.
A collapsible Original GalSim Documentation block containing the full upstream narrative.
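The docstring-copying mechanism described above can be illustrated with a small, self-contained sketch. The decorator implements_sketch and its caveats parameter are hypothetical stand-ins written for this illustration; they are not the signature or code of jax_galsim.core.utils.implements(). The sketch only shows the general idea: take the upstream object's docstring and prepend a JAX-specific note.

```python
import textwrap


def implements_sketch(original, caveats=""):
    """Hypothetical decorator: copy `original.__doc__` and prepend `caveats`."""

    def decorator(wrapped):
        parts = []
        if caveats:
            parts.append(textwrap.dedent(caveats).strip())
        if original.__doc__:
            parts.append(textwrap.dedent(original.__doc__).strip())
        # Overwrite the wrapped object's docstring with the merged text.
        wrapped.__doc__ = "\n\n".join(parts)
        return wrapped

    return decorator


class UpstreamGaussian:
    """A Gaussian surface brightness profile."""


@implements_sketch(UpstreamGaussian, caveats="Parameters may be JAX tracers.")
class Gaussian:
    pass


print(Gaussian.__doc__)
```

Keeping the upstream text verbatim and only prepending caveats means the narrative documentation stays in sync with GalSim, while the JAX-specific notes remain visible at the top.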