JAX-GalSim

Warning

This project is still in an early development phase. Please use the reference GalSim implementation for any scientific applications.

JAX-GalSim is a JAX re-implementation of the GalSim galaxy image simulation toolkit. It exposes (nearly) the same API as GalSim while enabling automatic differentiation, JIT compilation, and hardware acceleration via JAX.

Why JAX-GalSim?

⚑ JIT Compilation

Compile simulation pipelines with jax.jit for significant speedups, especially on GPU.

πŸ” Automatic Differentiation

Compute gradients of simulation outputs with respect to galaxy parameters using jax.grad.

🔀 Vectorization

Batch simulations over parameter grids with jax.vmap, with no explicit Python loops needed.

Quick Install

pip install jax-galsim
conda install -c conda-forge jax-galsim

See Installation for GPU support and development setup.

Minimal Example

import jax
import jax_galsim

# Define a galaxy and PSF
gal = jax_galsim.Gaussian(flux=1e5, sigma=2.0)
psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0)

# Convolve and draw
final = jax_galsim.Convolve([gal, psf])
image = final.drawImage(scale=0.2)

JAX-GalSim objects are JAX pytrees, so you can JIT-compile and differentiate the entire pipeline:

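from functools import partial

# slen and fft_size determine array shapes, so they are marked static for jax.jit
@partial(jax.jit, static_argnames=['slen', 'fft_size'])
def simulate(flux, sigma, *, slen=21, fft_size=128):
    # Pin the FFT size so that array shapes do not depend on traced parameters
    gsparams = jax_galsim.GSParams(minimum_fft_size=fft_size, maximum_fft_size=fft_size)
    gal = jax_galsim.Gaussian(flux=flux, sigma=sigma)
    psf = jax_galsim.Gaussian(flux=1.0, sigma=1.0)
    final = jax_galsim.Convolve([gal, psf]).withGSParams(gsparams)
    # Reduce the image to a scalar so that jax.grad can differentiate it
    return final.drawImage(nx=slen, ny=slen, scale=0.2).array.sum()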

# Compute gradients with respect to galaxy parameters
dflux, dsigma = jax.grad(simulate, argnums=(0, 1))(1e5, 2.0)

Getting Started

📖 API Reference

Auto-generated documentation for every public class, function, and module in jax_galsim.

🔗 GalSim upstream

The original GalSim documentation. Many docstrings in JAX-GalSim are derived from GalSim and expanded with JAX-specific notes.

https://galsim-developers.github.io/GalSim/_build/html

🚀 Quick Start

Walk through a complete simulation with JIT, grad, and vmap.

🔪 JAX-GalSim - The Sharp Bits 🔪

What changes when GalSim runs on JAX: immutability, tracing, PyTrees, and more.

About the Documentation

Each class and function that mirrors an upstream GalSim object is annotated with jax_galsim.core.utils.implements(). This decorator copies the original GalSim docstring and prepends any JAX-specific caveats. In the API Reference you will therefore find:

  • A summary and optional 🔪 JAX-GalSim - The Sharp Bits 🔪 block at the top of each entry highlighting important caveats.

  • An explicit Parameters table derived from the original GalSim documentation.

  • A collapsible Original GalSim Documentation block containing the full upstream narrative.
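The docstring-copying mechanism described above can be illustrated with a plain-Python sketch. `implements_sketch`, `UpstreamGaussian`, and the caveat text are hypothetical. The real `jax_galsim.core.utils.implements()` may have a different signature, and it also builds the summary, Parameters table, and collapsible upstream block listed above, which this sketch does not attempt.

```python
import inspect


def implements_sketch(original, caveats=""):
    """Hypothetical docstring-copying decorator (not the jax_galsim API).

    Copies the cleaned-up docstring of ``original`` onto the decorated
    object and, if given, prepends a block of JAX-specific caveats.
    """

    def decorator(obj):
        # inspect.cleandoc strips the common indentation of docstrings
        upstream_doc = inspect.cleandoc(original.__doc__ or "")
        parts = [inspect.cleandoc(caveats), upstream_doc]
        # Join the non-empty pieces, caveats first, separated by a blank line
        obj.__doc__ = "\n\n".join(part for part in parts if part)
        return obj

    return decorator


# Hypothetical stand-in for an upstream GalSim class
class UpstreamGaussian:
    """A 2D Gaussian surface brightness profile."""


@implements_sketch(UpstreamGaussian, caveats="Parameters may be JAX tracers.")
class Gaussian:
    pass


print(Gaussian.__doc__)
```

Here the caveat line is printed first, followed by a blank line and the upstream summary.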