Source code for jax_galsim.fits

from contextlib import ExitStack, contextmanager

import galsim as _galsim
import jax
import jax.numpy as jnp
import numpy as np
from galsim.fits import FitsHeader, closeHDUList, readFile, writeFile  # noqa: F401
from galsim.utilities import galsim_warn

from jax_galsim.core.utils import implements
from jax_galsim.image import Image

# We wrap the galsim FITS read functions to return jax_galsim Image objects.


def _maybe_convert_and_warn(image):
    """Return ``image`` in a dtype that jax_galsim's ``Image`` supports.

    If ``image.array``'s dtype is already in ``Image.valid_dtypes`` the input
    object is returned unchanged. Otherwise a warning is emitted through
    ``galsim_warn`` and a float64 copy is returned; when the input carries a
    ``header`` attribute, it is attached to the copy as well.
    """
    if image.array.dtype.type in Image.valid_dtypes:
        return image

    galsim_warn(
        "The dtype of the input image is not supported by jax_galsim. "
        "Converting to float64."
    )
    converted = image.copy(dtype=jnp.float64)
    # ``copy`` does not carry over the FITS header, so re-attach it by hand.
    if hasattr(image, "header"):
        converted.header = image.header
    return converted


@implements(_galsim.fits.read)
def read(*args, **kwargs):
    # galsim does the FITS parsing; the result is then rewrapped as a
    # jax_galsim Image.
    jgimage = Image.from_galsim(_galsim.fits.read(*args, **kwargs))
    # galsim validated the dtype against its own Image class, so the check is
    # repeated here against jax_galsim's Image.
    return _maybe_convert_and_warn(jgimage)
@implements(_galsim.fits.readMulti)
def readMulti(*args, **kwargs):
    # Read every HDU with galsim, then convert each galsim Image in order,
    # re-checking dtypes against jax_galsim's supported set.
    converted = []
    for gsimage in _galsim.fits.readMulti(*args, **kwargs):
        converted.append(_maybe_convert_and_warn(Image.from_galsim(gsimage)))
    return converted
@implements(_galsim.fits.readCube)
def readCube(*args, **kwargs):
    # galsim splits the cube into a list of galsim Images; each slice is
    # rewrapped as a jax_galsim Image and its dtype re-checked.
    cube_slices = _galsim.fits.readCube(*args, **kwargs)
    return list(
        _maybe_convert_and_warn(Image.from_galsim(gs_slice))
        for gs_slice in cube_slices
    )
# We wrap the galsim FITS write functions to accept jax_galsim Image objects.


@contextmanager
def _image_as_numpy(image):
    """Temporarily present ``image`` in a form galsim's FITS writers accept.

    For a jax_galsim ``Image``, the same object is yielded after two
    temporary changes:

    * its ``_array`` is replaced by a NumPy copy with the original dtype, so
      astropy receives a NumPy array;
    * its ``__class__`` is set to ``galsim.Image``, so galsim's ``isinstance``
      checks pass.

    Both changes are undone on exit, even if the ``with`` body raises.

    Any other input is yielded as ``np.array(image, dtype=image.dtype)``;
    nothing needs restoring in that case.
    """
    if not isinstance(image, Image):
        yield np.array(image, dtype=image.dtype)
        return

    # Capture the original state *before* mutating anything. Then the
    # ``finally`` below never references an unbound name. Previously a
    # failure in the array conversion raised UnboundLocalError from the
    # cleanup, which masked the real error.
    orig_array = image._array
    orig_class = image.__class__
    try:
        # Convert to NumPy so astropy doesn't complain.
        image._array = np.array(image.array, dtype=orig_array.dtype)
        # HACK: some galsim writers check for galsim Image instances, so the
        # class is swapped on the way in and restored on the way out.
        image.__class__ = _galsim.Image
        yield image
    finally:
        image.__class__ = orig_class
        image._array = orig_array
@implements(_galsim.fits.write)
def write(*args, **kwargs):
    # The image may arrive as the first positional argument (this is also how
    # ``Image.write`` binds it) or as the ``image`` keyword. In either case a
    # jax_galsim Image is temporarily converted for galsim. Previously the
    # keyword form was forwarded unconverted.
    if args and isinstance(args[0], Image):
        with _image_as_numpy(args[0]) as image:
            _galsim.fits.write(image, *args[1:], **kwargs)
    elif isinstance(kwargs.get("image"), Image):
        with _image_as_numpy(kwargs["image"]) as image:
            _galsim.fits.write(*args, **{**kwargs, "image": image})
    else:
        # Not a jax_galsim Image (or no image at all): let galsim handle it.
        _galsim.fits.write(*args, **kwargs)
@implements(_galsim.fits.writeMulti)
def writeMulti(*args, **kwargs):
    # Locate the image list, whether passed positionally or as the
    # ``image_list`` keyword. Previously the keyword form skipped conversion.
    if args:
        image_list, rest = args[0], args[1:]
    elif "image_list" in kwargs:
        # ``kwargs`` is this call's own dict, so popping does not affect the
        # caller.
        image_list, rest = kwargs.pop("image_list"), ()
    else:
        _galsim.fits.writeMulti(*args, **kwargs)
        return

    # ExitStack keeps every converted Image in its temporary galsim-compatible
    # state until the write finishes, then restores all of them.
    with ExitStack() as stack:
        gsimage_list = [
            (
                stack.enter_context(_image_as_numpy(image))
                if isinstance(image, Image)
                else image
            )
            for image in image_list
        ]
        _galsim.fits.writeMulti(gsimage_list, *rest, **kwargs)
@implements(_galsim.fits.writeCube)
def writeCube(*args, **kwargs):
    # Locate the cube input, whether passed positionally or as the
    # ``image_list`` keyword. Previously the keyword form skipped conversion.
    if args:
        image_list, rest = args[0], args[1:]
    elif "image_list" in kwargs:
        # ``kwargs`` is this call's own dict, so popping does not affect the
        # caller.
        image_list, rest = kwargs.pop("image_list"), ()
    else:
        _galsim.fits.writeCube(*args, **kwargs)
        return

    with ExitStack() as stack:
        if isinstance(image_list, jax.Array):
            # galsim only recognises an ``np.ndarray`` as a ready-made 3-D
            # cube, so a bare jax array must be converted as a whole.
            gsimage_list = np.array(image_list, dtype=image_list.dtype)
        elif isinstance(image_list, (list, tuple)):
            # Convert each jax_galsim Image / jax array slice. ExitStack keeps
            # the Images in their temporary galsim-compatible state until the
            # write finishes.
            gsimage_list = [
                (
                    stack.enter_context(_image_as_numpy(image))
                    if isinstance(image, (Image, jax.Array))
                    else image
                )
                for image in image_list
            ]
        else:
            # e.g. an np.ndarray cube: already acceptable to galsim.
            gsimage_list = image_list
        _galsim.fits.writeCube(gsimage_list, *rest, **kwargs)
# Expose the FITS writer as a method, so ``img.write(file_name, ...)`` calls
# ``write(img, file_name, ...)``.
setattr(Image, "write", write)