import galsim as _galsim
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class
from jax_galsim.bounds import Bounds, BoundsD, BoundsI
from jax_galsim.core.utils import (
cast_numpy_array_to_native_byte_order,
ensure_hashable,
has_tracers,
implements,
)
from jax_galsim.errors import GalSimImmutableError
from jax_galsim.position import PositionI
from jax_galsim.utilities import parse_pos_args
from jax_galsim.wcs import BaseWCS, PixelScale
# Extra documentation attached to the Image class through the
# ``lax_description`` argument of the ``implements`` decorator on the class.
# It explains how JAX-GalSim's Image departs from GalSim's: pixel memory is
# never shared, so views are unsupported and copies are made instead.
IMAGE_LAX_DOCS = """\
Contrary to GalSim native Image, this implementation does not support
sharing of the underlying array between different Images or views.
This is due to the fact that in JAX numpy arrays are immutable, so any
operation applied to this Image will create a new ``jnp.ndarray``. Making
a view via ``.view()`` will raise an error. Instead, use the ``.copy()``
method. The ``Image.subImage()`` method will return a copy.
"""
[docs]
@implements(
_galsim.Image,
lax_description=IMAGE_LAX_DOCS,
)
@register_pytree_node_class
class Image(object):
# Maps user-facing dtype specifiers (Python builtins, numpy scalar types and
# jnp.int64) onto the jnp scalar type actually used for storage.  __init__
# looks requested dtypes (and array dtypes) up here before validating them.
_alias_dtypes = {
    int: jnp.int32,  # So that user gets what they would expect
    float: jnp.float64,  # if using dtype=int or float or complex
    complex: jnp.complex128,
    jnp.int64: jnp.int32,  # Not equivalent, but will convert
    np.uint16: jnp.uint16,
    np.uint32: jnp.uint32,
    np.int16: jnp.int16,
    np.int32: jnp.int32,
    np.float32: jnp.float32,
    np.float64: jnp.float64,
    np.complex64: jnp.complex64,
    np.complex128: jnp.complex128,
}
# The storage dtypes an Image accepts; __init__ raises GalSimValueError for
# any other (post-alias) dtype.  The list is also reported in that error.
_valid_dtypes = [
    jnp.int32,
    jnp.float64,
    jnp.uint16,
    jnp.uint32,
    jnp.int16,
    jnp.float32,
    jnp.complex64,
    jnp.complex128,
]
# Public alias of _valid_dtypes (same list object).
valid_dtypes = _valid_dtypes
def __init__(self, *args, **kwargs):
    # Construct an Image from exactly one of:
    #   * ncol, nrow (positional, or ncol=/nrow= keywords), plus optional
    #     xmin/ymin (default 1);
    #   * an array (numpy or jax array, list or tuple; positional or
    #     array=), plus optional xmin/ymin or bounds;
    #   * a BoundsI (positional or bounds=);
    #   * another Image (positional or image=), recast to dtype if given;
    #   * nothing, which gives an image with undefined bounds.
    # Other keywords: dtype, init_value, scale, wcs, make_const, and copy
    # (only copy=True is accepted, since pixel memory is never shared).
    #
    # Raises TypeError for unparseable or unexpected arguments, and GalSim
    # value / incompatible-values errors for inconsistent combinations.

    # this one is specific to jax-galsim and is used to disable bounds checking
    # we use an underscore to denote that it is a private argument
    _check_bounds = kwargs.pop("_check_bounds", True)

    # Parse the args, kwargs
    ncol = None
    nrow = None
    bounds = None
    array = None
    image = None
    if len(args) > 2:
        raise TypeError("Error, too many unnamed arguments to Image constructor")
    elif len(args) == 2:
        ncol = args[0]
        nrow = args[1]
        xmin = kwargs.pop("xmin", 1)
        ymin = kwargs.pop("ymin", 1)
    elif len(args) == 1:
        if isinstance(args[0], np.ndarray):
            array = jnp.array(cast_numpy_array_to_native_byte_order(args[0]))
            array, xmin, ymin = self._get_xmin_ymin(
                array, kwargs, check_bounds=_check_bounds
            )
        elif isinstance(args[0], jnp.ndarray):
            array = args[0]
            array, xmin, ymin = self._get_xmin_ymin(
                array, kwargs, check_bounds=_check_bounds
            )
        elif isinstance(args[0], BoundsI):
            bounds = args[0]
        elif isinstance(args[0], (list, tuple)):
            array = jnp.array(args[0])
            array, xmin, ymin = self._get_xmin_ymin(
                array, kwargs, check_bounds=_check_bounds
            )
        elif isinstance(args[0], Image):
            image = args[0]
        else:
            raise TypeError(
                "Unable to parse %s as an array, bounds, or image." % args[0]
            )
    else:
        if "array" in kwargs:
            array = kwargs.pop("array")
            array, xmin, ymin = self._get_xmin_ymin(
                array, kwargs, check_bounds=_check_bounds
            )
        elif "bounds" in kwargs:
            bounds = kwargs.pop("bounds")
        elif "image" in kwargs:
            image = kwargs.pop("image")
        else:
            ncol = kwargs.pop("ncol", None)
            nrow = kwargs.pop("nrow", None)
            xmin = kwargs.pop("xmin", 1)
            ymin = kwargs.pop("ymin", 1)

    # Pop off the other valid kwargs:
    dtype = kwargs.pop("dtype", None)
    init_value = kwargs.pop("init_value", None)
    scale = kwargs.pop("scale", None)
    wcs = kwargs.pop("wcs", None)
    self._is_const = kwargs.pop("make_const", False)

    # Check that we got them all
    if kwargs:
        if "copy" in kwargs.keys() and not kwargs["copy"]:
            raise TypeError(
                "'copy=False' is not a valid keyword argument for the JAX-GalSim version of the Image constructor"
            )
        else:
            # remove it since we used it
            kwargs.pop("copy", None)
            if kwargs:
                # Format the leftovers into the message (passing them as a
                # second exception argument left "%s" unformatted).
                raise TypeError(
                    "Image constructor got unexpected keyword arguments: %s" % kwargs
                )

    # Figure out what dtype we want:
    dtype = self._alias_dtypes.get(dtype, dtype)
    if dtype is not None and dtype not in self._valid_dtypes:
        raise _galsim.GalSimValueError("Invalid dtype.", dtype, self._valid_dtypes)
    if array is not None:
        if dtype is None:
            # Infer the dtype from the array, mapping it to a storage dtype.
            dtype = array.dtype.type
            if dtype in self._alias_dtypes:
                dtype = self._alias_dtypes[dtype]
                # jax-galsim's rounding of float-to-int is platform dependent
                # so we explicitly round to ints if needed
                array = _safe_cast(array, jnp.issubdtype(dtype, jnp.integer), dtype)
            elif dtype not in self._valid_dtypes:
                raise _galsim.GalSimValueError(
                    "Invalid dtype of provided array.",
                    array.dtype,
                    self._valid_dtypes,
                )
        else:
            # jax-galsim's rounding of float-to-int is platform dependent
            # so we explicitly round to ints if needed
            array = _safe_cast(array, jnp.issubdtype(dtype, jnp.integer), dtype)
        # Be careful here: we have to watch out for little-endian / big-endian issues.
        # The path of least resistance is to check whether the array.dtype is equal to the
        # native one (using the dtype.isnative flag), and if not, make a new array that has a
        # type equal to the same one but with the appropriate endian-ness.
        if not array.dtype.isnative:
            array = array.astype(array.dtype.newbyteorder("="))
        self._dtype = array.dtype.type
    elif image is not None:
        if not isinstance(image, Image):
            raise TypeError("image must be an Image")
        # we do less checking here since we already have a valid image
        if dtype is None:
            self._dtype = image.dtype
        else:
            self._dtype = dtype
    elif dtype is not None:
        self._dtype = dtype
    else:
        self._dtype = jnp.float32

    # Construct the image attribute.
    # init_value: adding a concrete zero is a no-op and is skipped; a traced
    # value cannot be tested for truthiness under jit, so it is always added.
    if ncol is not None or nrow is not None:
        if ncol is None or nrow is None:
            raise _galsim.GalSimIncompatibleValuesError(
                "Both nrow and ncol must be provided", ncol=ncol, nrow=nrow
            )
        if ncol != int(ncol) or nrow != int(nrow):
            raise TypeError("nrow, ncol must be integers")
        ncol = int(ncol)
        nrow = int(nrow)
        self._array = self._make_empty(shape=(nrow, ncol), dtype=self._dtype)
        if not has_tracers(xmin) and not has_tracers(ymin):
            self._bounds = BoundsI(
                xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow, static=True
            )
        else:
            self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow)
        if init_value is not None and (has_tracers(init_value) or init_value):
            self._array = self._array.at[...].add(init_value)
    elif bounds is not None:
        if not isinstance(bounds, BoundsI):
            raise TypeError("bounds must be a galsim.BoundsI instance")
        self._array = self._make_empty(bounds.numpyShape(), dtype=self._dtype)
        self._bounds = bounds
        if init_value is not None and (has_tracers(init_value) or init_value):
            self._array = self._array.at[...].add(init_value)
    elif array is not None:
        self._array = array.view()
        nrow, ncol = array.shape
        if not has_tracers(xmin) and not has_tracers(ymin):
            self._bounds = BoundsI(
                xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow, static=True
            )
        else:
            self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow)
        if init_value is not None:
            raise _galsim.GalSimIncompatibleValuesError(
                "Cannot specify init_value with array",
                init_value=init_value,
                array=array,
            )
    elif image is not None:
        if not isinstance(image, Image):
            raise TypeError("image must be an Image")
        if init_value is not None:
            raise _galsim.GalSimIncompatibleValuesError(
                "Cannot specify init_value with image",
                init_value=init_value,
                image=image,
            )
        if wcs is None and scale is None:
            wcs = image.wcs
        self._bounds = image.bounds
        if dtype is None:
            self._dtype = image.dtype
        else:
            # Allow dtype to force a retyping of the provided image
            # e.g. im = ImageF(...)
            #      im2 = ImageD(im)
            self._dtype = dtype
        # jax-galsim's rounding of float-to-int is platform dependent
        # so we explicitly round to ints if needed
        self._array = _safe_cast(
            image.array, jnp.issubdtype(self._dtype, jnp.integer), self._dtype
        )
    else:
        # Undefined image: a 1x1 placeholder array with default BoundsI().
        self._array = jnp.zeros(shape=(1, 1), dtype=self._dtype)
        self._bounds = BoundsI()
        if init_value is not None:
            raise _galsim.GalSimIncompatibleValuesError(
                "Cannot specify init_value without setting an initial size",
                init_value=init_value,
                ncol=ncol,
                nrow=nrow,
                bounds=bounds,
            )

    # Construct the wcs attribute
    if scale is not None:
        if wcs is not None:
            raise _galsim.GalSimIncompatibleValuesError(
                "Cannot provide both scale and wcs to Image constructor",
                wcs=wcs,
                scale=scale,
            )
        self.wcs = PixelScale(float(scale))
    else:
        if wcs is not None and not isinstance(wcs, BaseWCS):
            raise TypeError("wcs parameters must be a galsim.BaseWCS instance")
        self.wcs = wcs
@staticmethod
def _get_xmin_ymin(array, kwargs, check_bounds=True):
    """Work out the lower-left corner for an Image built from ``array``.

    ``kwargs`` is mutated: ``xmin``/``ymin`` (default 1) and any ``bounds``
    entry are popped from it.  A defined ``bounds`` supplies the corner and,
    when ``check_bounds`` is true and none of its edges are traced, must
    match the array shape.  An undefined ``bounds`` or a zero-width array
    marks the image as undefined: ``(None, None, None)`` is returned and
    ``kwargs["dtype"]`` defaults to the array's scalar type.

    Returns ``(array, xmin, ymin)``.
    """
    if not isinstance(array, (np.ndarray, jnp.ndarray)):
        raise TypeError("array must be a ndarray instance")

    xmin = kwargs.pop("xmin", 1)
    ymin = kwargs.pop("ymin", 1)

    if "bounds" not in kwargs:
        if array.shape[1] != 0:
            return array, xmin, ymin
        # Another way to indicate that we don't have a defined image.
        kwargs.setdefault("dtype", array.dtype.type)
        return None, None, None

    given = kwargs.pop("bounds")
    if not isinstance(given, BoundsI):
        raise TypeError("bounds must be a galsim.BoundsI instance")

    if not given.isDefined():
        # Indication that array is formally undefined, even though provided.
        kwargs.setdefault("dtype", array.dtype.type)
        return None, None, None

    # The shape check needs concrete edges, so it is skipped when jitting.
    if check_bounds and not any(
        has_tracers(edge)
        for edge in (given.xmin, given.ymin, given.xmax, given.ymax)
    ):
        width_mismatch = given.xmax - given.xmin + 1 != array.shape[1]
        if width_mismatch or given.ymax - given.ymin + 1 != array.shape[0]:
            raise _galsim.GalSimIncompatibleValuesError(
                "Shape of array is inconsistent with provided bounds",
                array=array,
                bounds=given,
            )

    return array, given.xmin, given.ymin
def __repr__(self):
    """Constructor-style representation: bounds, the pixel array (only when
    the bounds are defined), the wcs, and ``make_const=True`` if immutable."""
    pieces = ["galsim.Image(bounds=%r" % self.bounds]
    if self.bounds.isDefined():
        pieces.append(", array=\n%r" % (ensure_hashable(np.array(self.array)),))
    pieces.append(", wcs=%r" % self.wcs)
    if self.isconst:
        pieces.append(", make_const=True")
    pieces.append(")")
    return "".join(pieces)
def __str__(self):
    """Short summary: bounds, then the pixel scale for a PixelScale-like wcs
    (otherwise the full wcs), then the dtype name."""
    # str() of a type object looks like "<class 'numpy.float32'>"; keep the
    # quoted part.
    type_name = str(self.dtype).split("'")[1]
    if self.wcs is None or not self.wcs._isPixelScale:
        return "galsim.Image(bounds=%s, wcs=%s, dtype=%s)" % (
            self.bounds,
            self.wcs,
            type_name,
        )
    return "galsim.Image(bounds=%s, scale=%s, dtype=%s)" % (
        self.bounds,
        ensure_hashable(self.scale),
        type_name,
    )
# Not immutable object. So shouldn't be used as a hash.
# Image overrides __eq__ and its pixels/bounds/wcs can be reassigned by the
# mutator methods, so instances are explicitly made unhashable.
__hash__ = None
# Read-only attributes:
@property
@implements(_galsim.Image.dtype)
def dtype(self):
    # Scalar type of the pixels as chosen in __init__; depending on how the
    # image was built this may be a jnp or a numpy scalar type.
    return self._dtype
@property
@implements(_galsim.Image.bounds)
def bounds(self):
    # Integer BoundsI of the image (an undefined BoundsI() for an empty image).
    return self._bounds
@property
@implements(_galsim.Image.array)
def array(self):
    # Underlying jax array of pixel values, indexed as [y - ymin, x - xmin]
    # (rows follow y, columns follow x; see _getValue).
    return self._array
@array.setter
def array(self, other):
    # Replace all pixel values: `other` is passed through _safe_cast towards
    # the current dtype and written into the existing shape with
    # .at[...].set (so it must broadcast to that shape).
    # NOTE(review): unlike resize/fill/setValue, this does not check isconst.
    # jax-galsim's rounding of float-to-int is platform dependent
    # so we explicitly round to ints if needed
    self._array = self._array.at[...].set(
        _safe_cast(other, self.isinteger, self.array.dtype)
    )
@property
@implements(_galsim.Image.nrow)
def nrow(self):
    # Number of pixel rows (the y extent).  An undefined image holds a 1x1
    # placeholder array, so this is 1 in that case.
    return self._array.shape[0]
@property
@implements(_galsim.Image.ncol)
def ncol(self):
    # Number of pixel columns (the x extent).  An undefined image holds a
    # 1x1 placeholder array, so this is 1 in that case.
    return self._array.shape[1]
@property
@implements(_galsim.Image.isconst)
def isconst(self):
    # True for immutable images (make_const=True, or the imag part of a real
    # image); mutating methods then raise GalSimImmutableError.
    return self._is_const
@property
@implements(_galsim.Image.iscomplex)
def iscomplex(self):
    # True when the pixel array has a complex dtype (numpy dtype kind "c").
    return self._array.dtype.kind == "c"
@property
@implements(_galsim.Image.isinteger)
def isinteger(self):
    # True for signed ("i") or unsigned ("u") integer pixel dtypes.
    return self._array.dtype.kind in ("i", "u")
@property
@implements(
    _galsim.Image.iscontiguous, lax_description="In JAX all arrays are contiguous."
)
def iscontiguous(self):
    # Kept for GalSim API compatibility; always True here.
    return True  # In JAX all arrays are contiguous (almost)
# Allow scale to work as a PixelScale wcs.
@property
@implements(_galsim.Image.scale)
def scale(self):
    # Pixel scale of the wcs, or None when no wcs is set.  If reading
    # wcs.scale fails on an existing (truthy) wcs, that wcs is not a simple
    # pixel scale and a GalSimError is raised instead.
    try:
        return self.wcs.scale
    except Exception:
        if not self.wcs:
            return None
        raise _galsim.GalSimError(
            "image.wcs is not a simple PixelScale; scale is undefined."
        )
@scale.setter
def scale(self, value):
    # Install PixelScale(value) as the wcs.  Only allowed when there is no
    # wcs yet or the current one is already a simple pixel scale.
    if self.wcs is None or self.wcs._isPixelScale:
        self.wcs = PixelScale(value)
        return
    raise _galsim.GalSimError(
        "image.wcs is not a simple PixelScale; scale is undefined."
    )
# Convenience functions
@property
@implements(_galsim.Image.xmin)
def xmin(self):
    # Smallest x (column) index, taken from the bounds.
    return self._bounds.xmin
@property
@implements(_galsim.Image.xmax)
def xmax(self):
    # Largest x (column) index, taken from the bounds.
    return self._bounds.xmax
@property
@implements(_galsim.Image.ymin)
def ymin(self):
    # Smallest y (row) index, taken from the bounds.
    return self._bounds.ymin
@property
@implements(_galsim.Image.ymax)
def ymax(self):
    # Largest y (row) index, taken from the bounds.
    return self._bounds.ymax
@property
@implements(_galsim.Image.outer_bounds)
def outer_bounds(self):
    # Bounds of the pixel edges rather than pixel centers: every integer
    # bound is pushed outward by half a pixel and returned as a BoundsD.
    half_pixel = 0.5
    return BoundsD(
        self.xmin - half_pixel,
        self.xmax + half_pixel,
        self.ymin - half_pixel,
        self.ymax + half_pixel,
    )
# real, imag for everything, even real images.
@property
@implements(_galsim.Image.real)
def real(self):
    # New image holding the real part, with the same bounds, wcs and
    # const flag as this one.
    real_part = self.array.real
    return self.__class__(
        real_part,
        bounds=self.bounds,
        wcs=self.wcs,
        make_const=self._is_const,
    )
@property
@implements(_galsim.Image.imag)
def imag(self):
    # New image holding the imaginary part, with the same bounds and wcs.
    imag_part = self.array.imag
    # for real images, the imaginary part is always zero and immutable
    frozen = self._is_const or (not self.iscomplex)
    return self.__class__(
        imag_part,
        bounds=self.bounds,
        wcs=self.wcs,
        make_const=frozen,
    )
@property
@implements(_galsim.Image.conjugate)
def conjugate(self):
    # New image of the complex conjugate, sharing bounds and wcs.  The
    # const flag is not propagated, so the result is mutable.
    conjugated = self.array.conjugate()
    return self.__class__(conjugated, bounds=self.bounds, wcs=self.wcs)
[docs]
@implements(
    _galsim.Image.copy,
    lax_description=(
        "JAX-GalSim supports extra keyword arguments to ``.copy`` so "
        "that users can make copies of images while also changing the image "
        "properties (e.g., the wcs). The extra keywords behave exactly like "
        "those of ``Image.view``."
    ),
)
def copy(
    self,
    scale=None,
    wcs=None,
    origin=None,
    center=None,
    dtype=None,
    make_const=False,
    contiguous=False,
):
    # Return a new Image with a copy of the pixels.  Optionally:
    #   * replace the wcs (via `scale` or `wcs`, not both);
    #   * recast to `dtype`;
    #   * mark the copy immutable;
    #   * move it with `origin` or `center` (not both).
    # `contiguous` is accepted for API compatibility only: jax arrays are
    # always contiguous.
    if origin is not None and center is not None:
        raise _galsim.GalSimIncompatibleValuesError(
            "Cannot provide both center and origin", center=center, origin=origin
        )

    if scale is not None:
        if wcs is not None:
            raise _galsim.GalSimIncompatibleValuesError(
                "Cannot provide both scale and wcs", scale=scale, wcs=wcs
            )
        new_wcs = PixelScale(scale)
    elif wcs is None:
        new_wcs = self.wcs
    elif isinstance(wcs, BaseWCS):
        new_wcs = wcs
    else:
        raise TypeError("wcs parameters must be a galsim.BaseWCS instance")

    # Figure out the dtype for the return Image
    out_dtype = dtype or self.dtype

    # If currently empty, just return a new empty image.
    if not self.bounds.isDefined():
        return Image(wcs=new_wcs, dtype=out_dtype, make_const=make_const)

    pixels = self.array.copy()
    if out_dtype != pixels.dtype:
        # jax-galsim's rounding of float-to-int is platform dependent
        # so we explicitly round to ints if needed
        pixels = _safe_cast(pixels, jnp.issubdtype(out_dtype, jnp.integer), out_dtype)

    result = self.__class__(
        pixels, bounds=self.bounds, wcs=new_wcs, make_const=make_const
    )
    # Update the origin if requested
    if origin is not None:
        result.setOrigin(origin)
    elif center is not None:
        result.setCenter(center)
    return result
[docs]
@implements(_galsim.Image.get_pixel_centers)
def get_pixel_centers(self):
    # Return (x, y): two 2-D float arrays shaped like the image giving the
    # coordinates of each pixel center (x varies along columns, y along rows).
    col_coords = jnp.arange(self.array.shape[1], dtype=float) + self.bounds.xmin
    row_coords = jnp.arange(self.array.shape[0], dtype=float) + self.bounds.ymin
    x, y = jnp.meshgrid(col_coords, row_coords)
    return x, y
def _make_empty(self, shape, dtype):
    """Return a zero-filled jax array of the given ``shape`` and ``dtype``.

    A shape containing no elements is replaced by ``(1, 1)``, because galsim
    forces degenerate images to have at least 1 pixel.
    """
    if np.prod(shape) == 0:
        shape = (1, 1)
    return jnp.zeros(shape=shape, dtype=dtype)
[docs]
@implements(_galsim.Image.resize)
def resize(self, bounds, wcs=None):
    # Reallocate a zero-filled pixel array matching `bounds`, discarding the
    # previous pixel values; optionally install a new wcs.
    if self.isconst:
        raise GalSimImmutableError("Cannot modify an immutable Image", self)
    if not isinstance(bounds, BoundsI):
        raise TypeError("bounds must be a galsim.BoundsI instance")
    fresh = self._make_empty(shape=bounds.numpyShape(), dtype=self.dtype)
    self._array, self._bounds = fresh, bounds
    if wcs is not None:
        self.wcs = wcs
[docs]
@implements(_galsim.Image.subImage)
def subImage(self, bounds):
    # Return a new Image holding a copy of the pixels inside `bounds`.
    # Static bounds use ordinary slicing; otherwise jax.lax.dynamic_slice
    # is used so the offsets may be traced.  The containment check is
    # skipped when any of the edges involved is traced.
    if not isinstance(bounds, BoundsI):
        raise TypeError("bounds must be a galsim.BoundsI instance")
    if not self.bounds.isDefined():
        raise _galsim.GalSimUndefinedBoundsError(
            "Attempt to access subImage of undefined image"
        )

    edges = (
        self.bounds.xmin,
        self.bounds.xmax,
        self.bounds.ymin,
        self.bounds.ymax,
        bounds.xmin,
        bounds.xmax,
        bounds.ymin,
        bounds.ymax,
    )
    if not any(has_tracers(edge) for edge in edges) and not self.bounds.includes(
        bounds
    ):
        raise _galsim.GalSimBoundsError(
            "Attempt to access subImage not (fully) in image", bounds, self.bounds
        )

    # Zero-based offsets of the requested corner inside this image.
    row0 = bounds.ymin - self.ymin
    col0 = bounds.xmin - self.xmin
    if self.bounds.isStatic() and bounds.isStatic():
        row1 = bounds.ymax - self.ymin + 1
        col1 = bounds.xmax - self.xmin + 1
        subarray = self.array[row0:row1, col0:col1]
    else:
        subarray = jax.lax.dynamic_slice(
            self.array, (row0, col0), bounds.numpyShape()
        )

    # NB. The wcs is still accurate, since the sub-image uses the same (x,y) values
    # as the original image did for those pixels. It's only once you recenter or
    # reorigin that you need to update the wcs. So that's taken care of in im.shift.
    return self.__class__(subarray, bounds=bounds, wcs=self.wcs)
[docs]
@implements(_galsim.Image.setSubImage)
def setSubImage(self, bounds, rhs):
    # Overwrite the pixels inside `bounds` with those of the Image `rhs`
    # (same shape required), cast to this image's dtype.  Static bounds use
    # an indexed .at[].set; otherwise jax.lax.dynamic_update_slice is used
    # so offsets may be traced.
    if self.isconst:
        raise GalSimImmutableError("Cannot modify an immutable Image", self)
    if not isinstance(bounds, BoundsI):
        raise TypeError("bounds must be a galsim.BoundsI instance")
    if not self.bounds.isDefined():
        raise _galsim.GalSimUndefinedBoundsError(
            "Attempt to access values of an undefined image"
        )

    edges = (
        self.bounds.xmin,
        self.bounds.xmax,
        self.bounds.ymin,
        self.bounds.ymax,
        bounds.xmin,
        bounds.xmax,
        bounds.ymin,
        bounds.ymax,
    )
    if not any(has_tracers(edge) for edge in edges) and not self.bounds.includes(
        bounds
    ):
        raise _galsim.GalSimBoundsError(
            "Attempt to access subImage not (fully) in image", bounds, self.bounds
        )
    if not isinstance(rhs, Image):
        raise TypeError("Trying to copyFrom a non-image")
    if bounds.numpyShape() != rhs.bounds.numpyShape():
        raise _galsim.GalSimIncompatibleValuesError(
            "Trying to copy images that are not the same shape",
            self_image=self,
            rhs=rhs,
        )

    # Zero-based offsets of the target corner inside this image.
    row0 = bounds.ymin - self.ymin
    col0 = bounds.xmin - self.xmin
    # jax-galsim's rounding of float-to-int is platform dependent
    # so we explicitly round to ints if needed
    values = _safe_cast(rhs.array, jnp.issubdtype(self.dtype, jnp.integer), self.dtype)
    if self.bounds.isStatic() and bounds.isStatic():
        row1 = bounds.ymax - self.ymin + 1
        col1 = bounds.xmax - self.xmin + 1
        self._array = self._array.at[row0:row1, col0:col1].set(values)
    else:
        self._array = jax.lax.dynamic_update_slice(self.array, values, (row0, col0))
def __getitem__(self, *args):
    """Return a sub-image (for a BoundsI key) or a single pixel value (for a
    PositionI or an ``(x, y)`` tuple key).

    For example,::

        >>> subimage = im[galsim.BoundsI(3,7,3,7)]
        >>> value = im[galsim.PositionI(5,5)]
        >>> value = im[5,5]
    """
    nargs = len(args)
    if nargs == 2:
        # Only reachable through an explicit im.__getitem__(x, y) call.
        return self(*args)
    if nargs != 1:
        raise TypeError("image[..] requires either 1 or 2 args")

    key = args[0]
    if isinstance(key, BoundsI):
        return self.subImage(key)
    if isinstance(key, PositionI):
        return self(key)
    if isinstance(key, tuple):
        return self.getValue(*key)
    raise TypeError("image[index] only accepts BoundsI or PositionI for the index")
def __setitem__(self, *args):
    """Set either a subimage or a single pixel to new values.

    For example,::

        >>> im[galsim.BoundsI(3,7,3,7)] = im2
        >>> im[galsim.PositionI(5,5)] = 17.
        >>> im[5,5] = 17.

    The explicit form ``im.__setitem__(x, y, value)`` is also accepted.
    """
    nargs = len(args)
    if nargs == 2:
        key = args[0]
        if isinstance(key, BoundsI):
            self.setSubImage(*args)
        elif isinstance(key, (PositionI, tuple)):
            # A PositionI or an (x, y) tuple: setValue parses either form.
            self.setValue(*args)
        else:
            raise TypeError(
                "image[index] only accepts BoundsI or PositionI for the index"
            )
    elif nargs == 3:
        # Explicit im.__setitem__(x, y, value) call.
        self.setValue(*args)
    else:
        # This method takes (index, value) or (x, y, value); the old message
        # ("1 or 2 args") described __getitem__ and was wrong here.
        raise TypeError("image[..] = value requires either 2 or 3 args")
[docs]
@implements(_galsim.Image.wrap)
def wrap(self, bounds, hermitian=False):
    # Wrap (fold) this image onto `bounds` and return that sub-image; this
    # image's pixels are overwritten by _wrap.  `hermitian` may be False,
    # "x" or "y"; the Hermitian modes require the minimum along that axis to
    # be 0 for both this image and `bounds` (checks are skipped for traced
    # values).  All validation runs before any data is written.
    if not isinstance(bounds, BoundsI):
        raise TypeError("bounds must be a galsim.BoundsI instance")

    if not hermitian:
        return self._wrap(bounds, False, False, None)

    if hermitian == "x":
        if not has_tracers(self.bounds.xmin) and self.bounds.xmin != 0:
            raise _galsim.GalSimIncompatibleValuesError(
                "hermitian == 'x' requires self.bounds.xmin == 0",
                hermitian=hermitian,
                bounds=self.bounds,
            )
        if not has_tracers(bounds.xmin) and bounds.xmin != 0:
            raise _galsim.GalSimIncompatibleValuesError(
                "hermitian == 'x' requires bounds.xmin == 0",
                hermitian=hermitian,
                bounds=bounds,
            )
        return self._wrap(bounds, True, False, 2 * bounds.xmax)

    if hermitian == "y":
        if not has_tracers(self.bounds.ymin) and self.bounds.ymin != 0:
            raise _galsim.GalSimIncompatibleValuesError(
                "hermitian == 'y' requires self.bounds.ymin == 0",
                hermitian=hermitian,
                bounds=self.bounds,
            )
        if not has_tracers(bounds.ymin) and bounds.ymin != 0:
            raise _galsim.GalSimIncompatibleValuesError(
                "hermitian == 'y' requires bounds.ymin == 0",
                hermitian=hermitian,
                bounds=bounds,
            )
        return self._wrap(bounds, False, True, 2 * bounds.ymax)

    raise _galsim.GalSimValueError(
        "Invalid value for hermitian", hermitian, (False, "x", "y")
    )
@implements(_galsim.Image._wrap)
def _wrap(self, bounds, hermx, hermy, hermitian_wrap_size):
    # Unchecked worker behind wrap(): replace this image's pixels with the
    # wrapped array computed by jax_galsim.core.wrap_image, then return the
    # sub-image covering `bounds`.
    #
    # hermx / hermy select the Hermitian-x / Hermitian-y variants;
    # hermitian_wrap_size is only forwarded to those variants (wrap() passes
    # None for the plain case).
    # NOTE(review): if hermx and hermy were both True no branch would run and
    # the pixels would be left unwrapped; wrap() never requests that.
    if not hermx and not hermy:
        from jax_galsim.core.wrap_image import wrap_nonhermitian

        self._array = self._array.at[...].set(
            wrap_nonhermitian(
                self._array,
                # zero indexed location of subimage
                bounds.xmin - self.xmin,
                bounds.ymin - self.ymin,
                bounds.deltax,
                bounds.deltay,
            )
        )
    elif hermx and not hermy:
        from jax_galsim.core.wrap_image import wrap_hermitian_x

        # Positional arguments as expected by wrap_hermitian_x; their precise
        # meaning is defined in jax_galsim.core.wrap_image.
        self._array = self._array.at[...].set(
            wrap_hermitian_x(
                self._array,
                -self.xmax,
                self.ymin,
                -bounds.xmax + 1,
                bounds.ymin,
                hermitian_wrap_size,
                bounds.deltay,
            )
        )
    elif not hermx and hermy:
        from jax_galsim.core.wrap_image import wrap_hermitian_y

        # Mirror image of the hermx branch with the roles of x and y swapped.
        self._array = self._array.at[...].set(
            wrap_hermitian_y(
                self._array,
                self.xmin,
                -self.ymax,
                bounds.xmin,
                -bounds.ymax + 1,
                bounds.deltax,
                hermitian_wrap_size,
            )
        )
    return self.subImage(bounds)
[docs]
@implements(
    _galsim.Image.calculate_fft,
    lax_description="JAX-GalSim does not support forward FFTs of complex dtypes.",
)
def calculate_fft(self):
    # Forward FFT of a real image that has a PixelScale wcs.  The image is
    # zero-padded to a square of side N = 2*No2 around (0, 0) and
    # transformed with rfft2.  The result covers kx in [0, No2] and
    # ky in [-No2, No2-1], with pixel scale dk = 2*pi / (N*dx).
    if not self.bounds.isDefined():
        raise _galsim.GalSimUndefinedBoundsError(
            "calculate_fft requires that the image have defined bounds."
        )
    if self.wcs is None:
        raise _galsim.GalSimError("calculate_fft requires that the scale be set.")
    if not self.wcs._isPixelScale:
        raise _galsim.GalSimError(
            "calculate_fft requires that the image has a PixelScale wcs."
        )
    if self.dtype in [np.complex64, np.complex128, complex]:
        raise _galsim.GalSimNotImplementedError(
            "JAX-GalSim does not support forward FFTs of complex dtypes."
        )

    # TODO: figure out how to do FFT at fixed size and then reconstruct
    # the result
    b = self.bounds
    No2 = max(-b.xmin, b.xmax + 1, -b.ymin, b.ymax + 1)

    full_bounds = BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2)
    if self.bounds == full_bounds:
        # Then the image is already in the shape we need.
        ximage = self
    else:
        # Then we pad out with zeros
        ximage = Image(full_bounds, dtype=self.dtype, init_value=0)
        ximage[self.bounds] = self[self.bounds]

    dx = self.scale
    # dk = 2pi / (N dx), with N = 2 * No2
    dk = jnp.pi / (No2 * dx)

    out = Image(
        BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2),
        dtype=np.complex128,
        scale=dk,
    )
    # we shift the image before and after the FFT to match the layout of the modes
    # used by GalSim
    modes = jnp.fft.fftshift(jnp.fft.rfft2(jnp.fft.fftshift(ximage.array)), axes=0)
    out._array = out._array.at[...].set(modes)
    out *= dx * dx
    out.setOrigin(0, -No2)
    return out
[docs]
@implements(_galsim.Image.calculate_inverse_fft)
def calculate_inverse_fft(self):
    # Inverse of calculate_fft.  Requirements on this k-space image:
    #   * defined bounds;
    #   * a PixelScale wcs;
    #   * bounds that include (0, 0).
    # Only its kx >= 0 part is used: it is zero-padded and Hermitian-wrapped
    # onto [0, No2] x [-No2, No2-1], then transformed with irfft2.
    # The result is a real image of size 2*No2 x 2*No2, recentred on (0, 0),
    # with pixel scale dx = 2*pi / (N*dk), N = 2*No2.
    if not self.bounds.isDefined():
        raise _galsim.GalSimUndefinedBoundsError(
            "calculate_inverse_fft requires that the image have defined bounds."
        )
    if self.wcs is None:
        raise _galsim.GalSimError(
            "calculate_inverse_fft requires that the scale be set."
        )
    if not self.wcs._isPixelScale:
        raise _galsim.GalSimError(
            "calculate_inverse_fft requires that the image has a PixelScale wcs."
        )
    if not self.bounds.includes(0, 0):
        raise _galsim.GalSimBoundsError(
            "calculate_inverse_fft requires that the image includes (0,0)",
            PositionI(0, 0),
            self.bounds,
        )

    No2 = max(
        max(self.bounds.xmax, -self.bounds.ymin),
        self.bounds.ymax,
    )

    target_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2)
    if self.bounds == target_bounds:
        # Then the image is already in the shape we need.
        kimage = self
    else:
        # Then we can pad out with zeros and wrap to get this in the form we need.
        full_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2 + 1)
        kimage = Image(full_bounds, dtype=self.dtype, init_value=0)
        posx_bounds = BoundsI(
            xmin=0,
            xmax=self.bounds.xmax,
            ymin=self.bounds.ymin,
            ymax=self.bounds.ymax,
        )
        kimage[posx_bounds] = self[posx_bounds]
        kimage = kimage._wrap(target_bounds, True, False, 2 * No2)

    dk = self.scale
    # dx = 2pi / (N dk)
    dx = jnp.pi / (No2 * dk)

    # In GalSim, they use inplace FFTW transforms which require the
    # array that holds the input/output to have extra padding on the
    # x dimension.
    # jax-galsim does not need the padding since it does not use an
    # inplace FFT. Thus we do not use the
    # padding.
    out = Image(
        bounds=BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2),
        dtype=float,
        scale=dx,
        # we shift the image before and after the FFT to match the layout used by galsim
        array=jnp.fft.fftshift(
            jnp.fft.irfft2(jnp.fft.fftshift(kimage.array, axes=0))
        )
        * (dk * No2 / jnp.pi) ** 2,
    )
    out.setCenter(0, 0)
    return out
[docs]
@classmethod
@implements(_galsim.Image.good_fft_size)
def good_fft_size(cls, input_size):
    # Return an FFT-friendly size for `input_size`: the smaller of the next
    # power of two and the next 3 * 2**k (k >= 1, so always even), up to a
    # 1e-5 relative tolerance, and never less than 2.
    # we use the math module here since this function should not be jitted.
    import math

    # Reference from GalSim C++
    # https://github.com/GalSim-developers/GalSim/blob/ece3bd32c1ae6ed771f2b489c5ab1b25729e0ea4/src/Image.cpp#L1009
    # Like the C++ reference, answer 2 for any input <= 2.  The formula below
    # already yields 2 on (0, 2], but math.log raises ValueError for
    # input_size <= 0, so short-circuit here.
    if input_size <= 2:
        return 2

    # Reduce slightly to eliminate potential rounding errors:
    insize = (1.0 - 1.0e-5) * input_size
    log2n = math.log(2.0) * math.ceil(math.log(insize) / math.log(2.0))
    log2n3 = math.log(3.0) + math.log(2.0) * math.ceil(
        (math.log(insize) - math.log(3.0)) / math.log(2.0)
    )
    log2n3 = max(log2n3, math.log(6.0))  # must be even number
    Nk = max(int(math.ceil(math.exp(min(log2n, log2n3)) - 1.0e-5)), 2)
    return Nk
[docs]
@implements(_galsim.Image.copyFrom)
def copyFrom(self, rhs):
    # Copy the pixel values of the Image `rhs` (same shape required) into
    # this image, keeping this image's bounds, dtype and wcs.
    if self.isconst:
        raise GalSimImmutableError("Cannot modify an immutable Image", self)
    if not isinstance(rhs, Image):
        raise TypeError("Trying to copyFrom a non-image")
    shapes_match = self.bounds.numpyShape() == rhs.bounds.numpyShape()
    if not shapes_match:
        raise _galsim.GalSimIncompatibleValuesError(
            "Trying to copy images that are not the same shape",
            self_image=self,
            rhs=rhs,
        )
    self._copyFrom(rhs)
def _copyFrom(self, rhs):
    """Same as copyFrom, but no sanity checks.

    Writes ``rhs``'s pixels into this image's array, converted via
    ``_safe_cast`` to this image's dtype.  No const, type or shape checks are
    made.
    """
    self._array = self._array.at[...].set(
        # jax-galsim's rounding of float-to-int is platform dependent
        # so we explicitly round to ints if needed
        _safe_cast(rhs._array, self.isinteger, self.array.dtype)
    )
[docs]
@implements(
    _galsim.Image.view,
    lax_description=(
        "JAX-GalSim does not support image views. This "
        "method will raise an error if called."
    ),
)
def view(
    self,
    scale=None,
    wcs=None,
    origin=None,
    center=None,
    dtype=None,
    make_const=False,
    contiguous=False,
):
    # Present only for GalSim API compatibility: jax arrays are immutable, so
    # shared-memory views cannot be offered.  copy() accepts the same
    # keyword arguments.  Always raises NotImplementedError.
    raise NotImplementedError(
        "JAX-GalSim does not support views of images! Use ``.copy`` instead."
    )
[docs]
@implements(_galsim.Image.shift)
def shift(self, *args, **kwargs):
    # Translate the image by an integer offset, parsed by parse_pos_args
    # from positional args or the dx=/dy= keywords.
    self._shift(parse_pos_args(args, kwargs, "dx", "dy", integer=True))
@implements(_galsim.Image._shift)
def _shift(self, delta):
    # Unchecked shift: move the bounds by `delta` and keep any wcs
    # consistent by shifting its origin by the same amount.
    self._bounds = self._bounds.shift(delta)
    if self.wcs is None:
        return
    self.wcs = self.wcs.shiftOrigin(delta)
[docs]
@implements(_galsim.Image.setCenter)
def setCenter(self, *args, **kwargs):
    # Shift the image so that its center lands on the requested integer
    # position (xcen, ycen).
    target = parse_pos_args(args, kwargs, "xcen", "ycen", integer=True)
    offset = target - self.center
    self._shift(offset)
[docs]
@implements(_galsim.Image.setOrigin)
def setOrigin(self, *args, **kwargs):
    # Shift the image so that its origin lands on the requested integer
    # position (x0, y0).
    target = parse_pos_args(args, kwargs, "x0", "y0", integer=True)
    offset = target - self.origin
    self._shift(offset)
@property
@implements(_galsim.Image.center)
def center(self):
    # Integer center of the image, delegated to BoundsI.center.
    return self.bounds.center
@property
@implements(_galsim.Image.true_center)
def true_center(self):
    # Center as reported by BoundsI.true_center (presumably the exact,
    # possibly half-integer, center as in GalSim — confirm in BoundsI).
    return self.bounds.true_center
@property
@implements(_galsim.Image.origin)
def origin(self):
    # Origin as reported by BoundsI.origin (presumably (xmin, ymin) — confirm
    # in BoundsI); used by setOrigin to compute the shift.
    return self.bounds.origin
def __call__(self, *args, **kwargs):
    """Return the value of the pixel at an integer position.

    The position may be passed as ``(x, y)``, as a PositionI, or as the
    ``x=``/``y=`` keyword arguments; the read is bounds-checked.
    """
    where = parse_pos_args(args, kwargs, "x", "y", integer=True)
    return self.getValue(where.x, where.y)
[docs]
@implements(_galsim.Image.getValue)
def getValue(self, x, y):
    # Bounds-checked read of the pixel at integer position (x, y).
    if not self.bounds.isDefined():
        raise _galsim.GalSimUndefinedBoundsError(
            "Attempt to access values of an undefined image"
        )
    if self.bounds.includes(x, y):
        return self._getValue(x, y)
    raise _galsim.GalSimBoundsError(
        "Attempt to access position not in bounds of image.",
        PositionI(x, y),
        self.bounds,
    )
@implements(_galsim.Image._getValue)
def _getValue(self, x, y):
    # Unchecked read (getValue validates first): rows are indexed by y and
    # columns by x, both relative to the image's lower corner.
    return self.array[y - self.ymin, x - self.xmin]
[docs]
@implements(_galsim.Image.setValue)
def setValue(self, *args, **kwargs):
    # Bounds-checked write of one pixel.  Accepts (x, y, value),
    # (position, value) or the x=, y=, value= keywords.
    if self.isconst:
        raise GalSimImmutableError("Cannot modify an immutable Image", self)
    if not self.bounds.isDefined():
        raise _galsim.GalSimUndefinedBoundsError(
            "Attempt to set value of an undefined image"
        )
    pos, value = parse_pos_args(
        args, kwargs, "x", "y", integer=True, others=["value"]
    )
    if self.bounds.includes(pos):
        self._setValue(pos.x, pos.y, value)
        return
    raise _galsim.GalSimBoundsError(
        "Attempt to set position not in bounds of image", pos, self.bounds
    )
@implements(_galsim.Image._setValue)
def _setValue(self, x, y, value):
    # Unchecked write of a single pixel (setValue validates first); unlike
    # the array setter, `value` is not passed through _safe_cast.
    self._array = self._array.at[y - self.ymin, x - self.xmin].set(value)
[docs]
@implements(_galsim.Image.addValue)
def addValue(self, *args, **kwargs):
    # Bounds-checked increment of one pixel.  Accepts (x, y, value),
    # (position, value) or the x=, y=, value= keywords.
    if self.isconst:
        raise GalSimImmutableError("Cannot modify an immutable Image", self)
    if not self.bounds.isDefined():
        raise _galsim.GalSimUndefinedBoundsError(
            "Attempt to set value of an undefined image"
        )
    pos, value = parse_pos_args(
        args, kwargs, "x", "y", integer=True, others=["value"]
    )
    if self.bounds.includes(pos):
        self._addValue(pos.x, pos.y, value)
        return
    raise _galsim.GalSimBoundsError(
        "Attempt to set position not in bounds of image", pos, self.bounds
    )
@implements(_galsim.Image._addValue)
def _addValue(self, x, y, value):
    # Unchecked increment of a single pixel (addValue validates first).
    self._array = self._array.at[y - self.ymin, x - self.xmin].add(value)
[docs]
@implements(_galsim.Image.fill)
def fill(self, value):
    # Set every pixel to `value`; requires a mutable image with defined bounds.
    if self.isconst:
        raise GalSimImmutableError("Cannot modify an immutable Image", self)
    if self.bounds.isDefined():
        self._fill(value)
        return
    raise _galsim.GalSimUndefinedBoundsError(
        "Attempt to set values of an undefined image"
    )
@implements(_galsim.Image._fill)
def _fill(self, value):
    # Unchecked: write `value` into every pixel via .at[...].set (so an
    # array-valued `value` must broadcast to the image shape).
    self._array = self._array.at[...].set(value)
[docs]
@implements(_galsim.Image.setZero)
def setZero(self):
    # Reset every pixel to 0.  Unlike fill(), no defined-bounds check is made.
    if not self.isconst:
        self._fill(0)
        return
    raise GalSimImmutableError("Cannot modify an immutable Image", self)
[docs]
@implements(_galsim.Image.invertSelf)
def invertSelf(self):
    # Replace each pixel by its reciprocal (zeros stay zero); requires a
    # mutable image with defined bounds.
    if self.isconst:
        raise GalSimImmutableError("Cannot modify an immutable Image", self)
    if self.bounds.isDefined():
        self._invertSelf()
        return
    raise _galsim.GalSimUndefinedBoundsError(
        "Attempt to set values of an undefined image"
    )
@implements(_galsim.Image._invertSelf)
def _invertSelf(self):
    # Unchecked reciprocal: every pixel becomes 1/value, except zero pixels,
    # which stay 0.  Zeros are swapped for 1 before dividing so the
    # division never divides by zero.
    pixels = self._array
    dtype = pixels.dtype
    is_zero = pixels == 0
    denominator = jnp.where(is_zero, 1.0, pixels)
    inverted = jnp.where(is_zero, 0.0, 1.0 / denominator)
    # jax-galsim's rounding of float-to-int is platform dependent
    # so we explicitly round to ints if needed
    self._array = pixels.at[...].set(
        _safe_cast(inverted, jnp.issubdtype(dtype, jnp.integer), dtype)
    )
[docs]
@implements(_galsim.Image.replaceNegative)
def replaceNegative(self, replace_value=0):
    """Replace every negative pixel with ``replace_value``.

    Raises ``GalSimImmutableError`` if the image is immutable.
    """
    if self.isconst:
        raise GalSimImmutableError("Cannot modify an immutable Image", self)
    arr = self.array
    # Use ``jnp.where`` instead of boolean-mask indexing (``.at[arr < 0]``).
    # A boolean mask has a data-dependent shape and raises
    # ``NonConcreteBooleanIndexError`` when traced under ``jax.jit``.
    # Cast back to the image dtype, as ``.at[...].set`` would implicitly do.
    self._array = jnp.where(arr < 0, replace_value, arr).astype(arr.dtype)
def __eq__(self, other):
    """Images are equal when bounds, WCS, pixel values and constness all match.

    Pixel values are only compared when the bounds are defined.
    """
    # Note that numpy.array_equal can return True if the dtypes of the two arrays involved are
    # different, as long as the contents of the two arrays are logically the same. For example:
    #
    # >>> double_array = np.arange(1024).reshape(32, 32)*np.pi
    # >>> int_array = np.arange(1024).reshape(32, 32)
    # >>> assert galsim.ImageD(int_array) == galsim.ImageF(int_array) # passes
    # >>> assert galsim.ImageD(double_array) == galsim.ImageF(double_array) # fails
    if self is other:
        return True
    return (
        isinstance(other, Image)
        and self.bounds == other.bounds
        and self.wcs == other.wcs
        and (not self.bounds.isDefined() or jnp.array_equal(self.array, other.array))
        and self.isconst == other.isconst
    )
def __ne__(self, other):
    """Return the logical negation of ``__eq__``."""
    equal = self.__eq__(other)
    return not equal
[docs]
@implements(_galsim.Image.transpose)
def transpose(self):
    """Return a new image with x and y (and the bounds) swapped; the WCS is dropped."""
    bounds = self.bounds
    swapped = bounds.__class__(
        xmin=self.ymin, deltax=bounds.deltay, ymin=self.xmin, deltay=bounds.deltax
    )
    return _Image(self.array.T, swapped, None)
[docs]
@implements(_galsim.Image.flip_lr)
def flip_lr(self):
    """Return a new image mirrored left-right (columns reversed); the WCS is dropped."""
    mirrored = self.array.at[:, ::-1].get()
    return _Image(mirrored, self._bounds, None)
[docs]
@implements(_galsim.Image.flip_ud)
def flip_ud(self):
    """Return a new image mirrored up-down (rows reversed); the WCS is dropped."""
    mirrored = self.array.at[::-1, :].get()
    return _Image(mirrored, self._bounds, None)
[docs]
@implements(_galsim.Image.rot_cw)
def rot_cw(self):
    """Return a new image rotated 90 degrees clockwise; the WCS is dropped."""
    bounds = self.bounds
    swapped = bounds.__class__(
        xmin=self.ymin, deltax=bounds.deltay, ymin=self.xmin, deltay=bounds.deltax
    )
    # transpose, then reverse the row order
    rotated = self.array.T.at[::-1, :].get()
    return _Image(rotated, swapped, None)
[docs]
@implements(_galsim.Image.rot_ccw)
def rot_ccw(self):
    """Return a new image rotated 90 degrees counter-clockwise; the WCS is dropped."""
    bounds = self.bounds
    swapped = bounds.__class__(
        xmin=self.ymin, deltax=bounds.deltay, ymin=self.xmin, deltay=bounds.deltax
    )
    # transpose, then reverse the column order
    rotated = self.array.T.at[:, ::-1].get()
    return _Image(rotated, swapped, None)
[docs]
@implements(_galsim.Image.rot_180)
def rot_180(self):
    """Return a new image rotated by 180 degrees; the WCS is dropped."""
    rotated = self.array.at[::-1, ::-1].get()
    return _Image(rotated, self._bounds, None)
[docs]
def tree_flatten(self):
    """Flatten the image into ``(children, aux_data)`` for JAX.

    Children are the array and WCS (plus the bounds when they are not static),
    followed by ``added_flux`` and then ``photons`` when those attributes exist.
    Static data (dtype, constness, static bounds, header) goes into
    ``aux_data``. Explicit flags record which optional children were emitted,
    so ``tree_unflatten`` can restore them unambiguously.
    """
    # Define the children nodes of the PyTree that need tracing
    if self.bounds.isStatic():
        children = (self.array, self.wcs)
        aux_data = {
            "dtype": self.dtype,
            "bounds": self.bounds,
            "isconst": self.isconst,
        }
    else:
        children = (self.array, self.wcs, self.bounds)
        aux_data = {"dtype": self.dtype, "isconst": self.isconst}
    # other routines may add these attributes to images on the fly
    # we have to include them here so that JAX knows how to handle them in jitting etc.
    # Either attribute may exist without the other. Counting children alone
    # would restore ``photons`` as ``added_flux`` when only ``photons`` is
    # set, so record presence explicitly.
    has_added_flux = hasattr(self, "added_flux")
    has_photons = hasattr(self, "photons")
    if has_added_flux:
        children += (self.added_flux,)
    if has_photons:
        children += (self.photons,)
    if hasattr(self, "header"):
        aux_data["header"] = self.header
    aux_data["has_added_flux"] = has_added_flux
    aux_data["has_photons"] = has_photons
    return (children, aux_data)

@classmethod
def tree_unflatten(cls, aux_data, children):
    """Recreates an instance of the class from flatten representation.

    Inverse of ``tree_flatten``. When ``aux_data`` lacks the presence flags
    (e.g. hand-built aux data), the positional convention is assumed:
    ``added_flux`` first, then ``photons``.
    """
    obj = object.__new__(cls)
    obj._array = children[0]
    obj.wcs = children[1]
    if "bounds" in aux_data:
        # static bounds were stored in aux_data
        obj._bounds = aux_data["bounds"]
        extras = list(children[2:])
    else:
        # non-static bounds were stored as the third child
        obj._bounds = children[2]
        extras = list(children[3:])
    obj._dtype = aux_data["dtype"]
    obj._is_const = aux_data["isconst"]
    has_added_flux = aux_data.get("has_added_flux", len(extras) > 0)
    has_photons = aux_data.get("has_photons", len(extras) > 1)
    if has_added_flux:
        obj.added_flux = extras.pop(0)
    if has_photons:
        obj.photons = extras.pop(0)
    if "header" in aux_data:
        obj.header = aux_data["header"]
    return obj
[docs]
@classmethod
def from_galsim(cls, galsim_image):
    """Create a `Image` from a `galsim.Image` instance.

    The pixel array is converted to native byte order and wrapped as a JAX
    array. WCS and bounds are converted with their ``from_galsim`` helpers,
    and a ``header`` attribute is copied over when the source has one.
    """
    source_wcs = galsim_image.wcs
    wcs = None if source_wcs is None else BaseWCS.from_galsim(source_wcs)
    native = cast_numpy_array_to_native_byte_order(galsim_image.array)
    bounds = Bounds.from_galsim(galsim_image.bounds)
    im = cls(array=jnp.asarray(native), wcs=wcs, bounds=bounds)
    try:
        header = galsim_image.header
    except AttributeError:
        pass
    else:
        im.header = header
    return im
[docs]
def to_galsim(self):
    """Create a galsim `Image` from a `jax_galsim.Image` object.

    The pixel array is converted with ``np.asarray``. Bounds and WCS are
    converted with their own ``to_galsim`` methods. A ``header`` attribute,
    if present, is carried over, mirroring ``from_galsim``, so that a round
    trip does not silently drop it.
    """
    wcs = self.wcs.to_galsim() if self.wcs is not None else None
    gs_image = _galsim.Image(
        np.asarray(self.array), bounds=self.bounds.to_galsim(), wcs=wcs
    )
    if hasattr(self, "header"):
        gs_image.header = self.header
    return gs_image
[docs]
@implements(
    _galsim.Image.FindAdaptiveMom,
    lax_description=(
        "This method converts the current `jax_galsim.Image` to a native "
        "`galsim.Image` and delegates the computation to "
        "`galsim.hsm.FindAdaptiveMom`. The returned object is GalSim's "
        "`ShapeData`."
    ),
)
def FindAdaptiveMom(self, *args, **kwargs):
    """Run GalSim's adaptive-moment fit on a galsim copy of this image.

    Every positional or keyword argument exposing ``to_galsim`` is converted
    first; other arguments are passed through unchanged.
    """

    def _as_galsim(obj):
        return obj.to_galsim() if hasattr(obj, "to_galsim") else obj

    gs_args = [_as_galsim(arg) for arg in args]
    gs_kwargs = {name: _as_galsim(val) for name, val in kwargs.items()}
    return self.to_galsim().FindAdaptiveMom(*gs_args, **gs_kwargs)
@implements(_galsim._Image, lax_description=IMAGE_LAX_DOCS)
def _Image(array, bounds, wcs):
    """Build an ``Image`` directly from ``array``, ``bounds`` and ``wcs``.

    ``Image.__init__`` is bypassed. The dtype is mapped through
    ``Image._alias_dtypes``, and the array is cast to it (rounding first for
    integer dtypes).
    """
    ret = Image.__new__(Image)
    ret.wcs = wcs
    dtype = array.dtype.type
    dtype = Image._alias_dtypes.get(dtype, dtype)
    ret._dtype = dtype
    # jax-galsim's rounding of float-to-int is platform dependent
    # so we explicitly round to ints if needed
    ret._array = _safe_cast(array, jnp.issubdtype(dtype, jnp.integer), dtype)
    ret._bounds = bounds
    return ret
# These are essentially aliases for the regular Image with the correct dtype
[docs]
@implements(_galsim.ImageUS, lax_description=IMAGE_LAX_DOCS)
def ImageUS(*args, **kwargs):
    """Alias for galsim.Image(..., dtype=numpy.uint16)"""
    # any caller-supplied dtype is overridden
    return Image(*args, **{**kwargs, "dtype": jnp.uint16})
[docs]
@implements(_galsim.ImageUI, lax_description=IMAGE_LAX_DOCS)
def ImageUI(*args, **kwargs):
    """Alias for galsim.Image(..., dtype=numpy.uint32)"""
    # any caller-supplied dtype is overridden
    return Image(*args, **{**kwargs, "dtype": jnp.uint32})
[docs]
@implements(_galsim.ImageS, lax_description=IMAGE_LAX_DOCS)
def ImageS(*args, **kwargs):
    """Alias for galsim.Image(..., dtype=numpy.int16)"""
    # any caller-supplied dtype is overridden
    return Image(*args, **{**kwargs, "dtype": jnp.int16})
[docs]
@implements(_galsim.ImageI, lax_description=IMAGE_LAX_DOCS)
def ImageI(*args, **kwargs):
    """Alias for galsim.Image(..., dtype=numpy.int32)"""
    # any caller-supplied dtype is overridden
    return Image(*args, **{**kwargs, "dtype": jnp.int32})
[docs]
@implements(_galsim.ImageF, lax_description=IMAGE_LAX_DOCS)
def ImageF(*args, **kwargs):
    """Alias for galsim.Image(..., dtype=numpy.float32)"""
    # any caller-supplied dtype is overridden
    return Image(*args, **{**kwargs, "dtype": jnp.float32})
[docs]
@implements(_galsim.ImageD, lax_description=IMAGE_LAX_DOCS)
def ImageD(*args, **kwargs):
    """Alias for galsim.Image(..., dtype=numpy.float64)"""
    # any caller-supplied dtype is overridden
    return Image(*args, **{**kwargs, "dtype": jnp.float64})
[docs]
@implements(_galsim.ImageCF, lax_description=IMAGE_LAX_DOCS)
def ImageCF(*args, **kwargs):
    """Alias for galsim.Image(..., dtype=numpy.complex64)"""
    # any caller-supplied dtype is overridden
    return Image(*args, **{**kwargs, "dtype": jnp.complex64})
[docs]
@implements(_galsim.ImageCD, lax_description=IMAGE_LAX_DOCS)
def ImageCD(*args, **kwargs):
    """Alias for galsim.Image(..., dtype=numpy.complex128)"""
    # any caller-supplied dtype is overridden
    return Image(*args, **{**kwargs, "dtype": jnp.complex128})
################################################################################################
#
# Arithmetic support for Image. The helpers and functions below implement Python's
# arithmetic and bitwise operators for Image objects. They are attached to the Image class
# as dunder methods further below, so that images behave intuitively in expressions.
#
def _safe_cast(array, target_isinteger, target_dtype):
# code snippet pulled from upstream GalSim and turned into a general purpose
# function
#
# Assign the given array to self.array, safely casting it to the required type.
# Most important is to make sure integer types round first before casting, since
# numpy's astype doesn't do any rounding.
if target_isinteger:
array = jnp.around(array)
return array.astype(target_dtype)
# Define a utility function to be used by the arithmetic functions below
def check_image_consistency(im1, im2, integer=False):
    """Validate the operands of an Image arithmetic operation.

    ``im1`` must be integer-valued when ``integer`` is true. If ``im2`` is also
    an Image, its array shape must match ``im1``'s, and it too must be
    integer-valued when ``integer`` is true. Non-Image ``im2`` values are not
    checked.
    """
    if integer and not im1.isinteger:
        raise _galsim.GalSimValueError("Image must have integer values.", im1)
    if not isinstance(im2, Image):
        return
    if im1.array.shape != im2.array.shape:
        raise _galsim.GalSimIncompatibleValuesError(
            "Image shapes are inconsistent", im1=im1, im2=im2
        )
    if integer and not im2.isinteger:
        raise _galsim.GalSimValueError("Image must have integer values.", im2)
def Image_add(self, other):
    """Return ``self + other`` as a new Image (``other``: Image or array/scalar)."""
    check_image_consistency(self, other)
    operand = getattr(other, "array", other)
    return Image(self.array + operand, bounds=self.bounds, wcs=self.wcs)
def Image_iadd(self, other):
    """In-place ``self += other``; the image keeps its own dtype."""
    check_image_consistency(self, other)
    try:
        operand = other.array
        operand_dtype = operand.dtype
    except AttributeError:
        operand, operand_dtype = other, type(other)
    if operand_dtype == self.array.dtype:
        updated = self.array.at[...].add(operand)
    else:
        # jax-galsim's rounding of float-to-int is platform dependent
        # so we explicitly round to ints if needed
        updated = self.array.at[...].set(
            _safe_cast(self.array + operand, self.isinteger, self.array.dtype)
        )
    self._array = updated
    return self
def Image_sub(self, other):
    """Return ``self - other`` as a new Image (``other``: Image or array/scalar)."""
    check_image_consistency(self, other)
    operand = getattr(other, "array", other)
    return Image(self.array - operand, bounds=self.bounds, wcs=self.wcs)
def Image_rsub(self, other):
    """Return ``other - self`` as a new Image (reflected subtraction)."""
    difference = other - self.array
    return Image(difference, bounds=self.bounds, wcs=self.wcs)
def Image_isub(self, other):
    """In-place ``self -= other``; the image keeps its own dtype."""
    check_image_consistency(self, other)
    try:
        operand = other.array
        operand_dtype = operand.dtype
    except AttributeError:
        operand, operand_dtype = other, type(other)
    if operand_dtype == self.array.dtype:
        updated = self.array.at[...].subtract(operand)
    else:
        # jax-galsim's rounding of float-to-int is platform dependent
        # so we explicitly round to ints if needed
        updated = self.array.at[...].set(
            _safe_cast(self.array - operand, self.isinteger, self.array.dtype)
        )
    self._array = updated
    return self
def Image_mul(self, other):
    """Return ``self * other`` as a new Image (``other``: Image or array/scalar)."""
    check_image_consistency(self, other)
    operand = getattr(other, "array", other)
    return Image(self.array * operand, bounds=self.bounds, wcs=self.wcs)
def Image_imul(self, other):
    """In-place ``self *= other``; the image keeps its own dtype."""
    check_image_consistency(self, other)
    try:
        operand = other.array
        operand_dtype = operand.dtype
    except AttributeError:
        operand, operand_dtype = other, type(other)
    if operand_dtype == self.array.dtype:
        updated = self.array.at[...].multiply(operand)
    else:
        # jax-galsim's rounding of float-to-int is platform dependent
        # so we explicitly round to ints if needed
        updated = self.array.at[...].set(
            _safe_cast(self.array * operand, self.isinteger, self.array.dtype)
        )
    self._array = updated
    return self
def Image_div(self, other):
    """Return ``self / other`` (true division) as a new Image."""
    check_image_consistency(self, other)
    operand = getattr(other, "array", other)
    return Image(self.array / operand, bounds=self.bounds, wcs=self.wcs)
def Image_rdiv(self, other):
    """Return ``other / self`` as a new Image (reflected true division)."""
    quotient = other / self.array
    return Image(quotient, bounds=self.bounds, wcs=self.wcs)
def Image_idiv(self, other):
    """In-place ``self /= other``; the image keeps its own dtype."""
    check_image_consistency(self, other)
    try:
        divisor = other.array
        divisor_dtype = divisor.dtype
    except AttributeError:
        divisor, divisor_dtype = other, type(other)
    if divisor_dtype == self.array.dtype and not self.isinteger:
        # true division of an integer array cannot be assigned back directly,
        # so only matching non-integer dtypes take the direct divide path
        updated = self.array.at[...].divide(divisor)
    else:
        # jax-galsim's rounding of float-to-int is platform dependent
        # so we explicitly round to ints if needed
        updated = self.array.at[...].set(
            _safe_cast(self.array / divisor, self.isinteger, self.array.dtype)
        )
    self._array = updated
    return self
def Image_floordiv(self, other):
    """Return ``self // other`` as a new Image; integer-valued images only."""
    check_image_consistency(self, other, integer=True)
    operand = getattr(other, "array", other)
    return Image(self.array // operand, bounds=self.bounds, wcs=self.wcs)
def Image_rfloordiv(self, other):
    """Return ``other // self`` as a new Image; integer-valued images only."""
    check_image_consistency(self, other, integer=True)
    quotient = other // self.array
    return Image(quotient, bounds=self.bounds, wcs=self.wcs)
def Image_ifloordiv(self, other):
    """In-place ``self //= other``; integer-valued images only."""
    check_image_consistency(self, other, integer=True)
    try:
        divisor = other.array
        divisor_dtype = divisor.dtype
    except AttributeError:
        divisor, divisor_dtype = other, type(other)
    if divisor_dtype == self.array.dtype:
        updated = self.array.at[...].set(self.array // divisor)
    else:
        # jax-galsim's rounding of float-to-int is platform dependent
        # so we explicitly round to ints if needed
        updated = self.array.at[...].set(
            _safe_cast(self.array // divisor, self.isinteger, self.array.dtype)
        )
    self._array = updated
    return self
def Image_mod(self, other):
    """Return ``self % other`` as a new Image; integer-valued images only."""
    check_image_consistency(self, other, integer=True)
    operand = getattr(other, "array", other)
    return Image(self.array % operand, bounds=self.bounds, wcs=self.wcs)
def Image_rmod(self, other):
    """Return ``other % self`` as a new Image; integer-valued images only."""
    check_image_consistency(self, other, integer=True)
    remainder = other % self.array
    return Image(remainder, bounds=self.bounds, wcs=self.wcs)
def Image_imod(self, other):
    """In-place ``self %= other``; integer-valued images only."""
    check_image_consistency(self, other, integer=True)
    try:
        modulus = other.array
        modulus_dtype = modulus.dtype
    except AttributeError:
        modulus, modulus_dtype = other, type(other)
    if modulus_dtype == self.array.dtype:
        updated = self.array.at[...].set(self.array % modulus)
    else:
        # jax-galsim's rounding of float-to-int is platform dependent
        # so we explicitly round to ints if needed
        updated = self.array.at[...].set(
            _safe_cast(self.array % modulus, self.isinteger, self.array.dtype)
        )
    self._array = updated
    return self
def Image_pow(self, other):
    """Return ``self ** other`` as a new Image."""
    powered = self.array**other
    return Image(powered, bounds=self.bounds, wcs=self.wcs)
def Image_ipow(self, other):
    """In-place ``self **= other`` for an int or float exponent."""
    if not isinstance(other, (int, float)):
        raise TypeError("Can only raise an image to a float or int power!")
    if self.isinteger and not isinstance(other, int):
        # a float power of an integer image is computed in floating point and cast back;
        # jax-galsim's rounding of float-to-int is platform dependent
        # so we explicitly round to ints if needed
        self._array = self.array.at[...].set(
            _safe_cast(self.array**other, self.isinteger, self.array.dtype)
        )
    else:
        self._array = self.array.at[...].power(other)
    return self
def Image_neg(self):
    """Return a negated copy of the image (implemented via in-place ``*= -1``)."""
    negated = self.copy()
    negated *= -1
    return negated
# Define &, ^ and | only for integer-type images
def Image_and(self, other):
    """Return bitwise ``self & other`` as a new Image; integer-valued images only."""
    check_image_consistency(self, other, integer=True)
    operand = getattr(other, "array", other)
    return Image(self.array & operand, bounds=self.bounds, wcs=self.wcs)
def Image_iand(self, other):
    """In-place bitwise ``self &= other``; integer-valued images only."""
    check_image_consistency(self, other, integer=True)
    operand = getattr(other, "array", other)
    self._array = self.array.at[...].set(self.array & operand)
    return self
def Image_xor(self, other):
    """Return bitwise ``self ^ other`` as a new Image; integer-valued images only."""
    check_image_consistency(self, other, integer=True)
    operand = getattr(other, "array", other)
    return Image(self.array ^ operand, bounds=self.bounds, wcs=self.wcs)
def Image_ixor(self, other):
    """In-place bitwise ``self ^= other``; integer-valued images only."""
    check_image_consistency(self, other, integer=True)
    operand = getattr(other, "array", other)
    self._array = self.array.at[...].set(self.array ^ operand)
    return self
def Image_or(self, other):
    """Return bitwise ``self | other`` as a new Image; integer-valued images only."""
    check_image_consistency(self, other, integer=True)
    operand = getattr(other, "array", other)
    return Image(self.array | operand, bounds=self.bounds, wcs=self.wcs)
def Image_ior(self, other):
    """In-place bitwise ``self |= other``; integer-valued images only."""
    check_image_consistency(self, other, integer=True)
    operand = getattr(other, "array", other)
    self._array = self.array.at[...].set(self.array | operand)
    return self
# inject the arithmetic operators as methods of the Image class:
# (the __div__/__rdiv__/__idiv__ names are Python-2 spellings kept for parity with GalSim)
_IMAGE_OPERATORS = (
    ("__add__", Image_add),
    ("__radd__", Image_add),
    ("__iadd__", Image_iadd),
    ("__sub__", Image_sub),
    ("__rsub__", Image_rsub),
    ("__isub__", Image_isub),
    ("__mul__", Image_mul),
    ("__rmul__", Image_mul),
    ("__imul__", Image_imul),
    ("__div__", Image_div),
    ("__rdiv__", Image_rdiv),
    ("__truediv__", Image_div),
    ("__rtruediv__", Image_rdiv),
    ("__idiv__", Image_idiv),
    ("__itruediv__", Image_idiv),
    ("__mod__", Image_mod),
    ("__rmod__", Image_rmod),
    ("__imod__", Image_imod),
    ("__floordiv__", Image_floordiv),
    ("__rfloordiv__", Image_rfloordiv),
    ("__ifloordiv__", Image_ifloordiv),
    ("__ipow__", Image_ipow),
    ("__pow__", Image_pow),
    ("__neg__", Image_neg),
    ("__and__", Image_and),
    ("__xor__", Image_xor),
    ("__or__", Image_or),
    ("__rand__", Image_and),
    ("__rxor__", Image_xor),
    ("__ror__", Image_or),
    ("__iand__", Image_iand),
    ("__ixor__", Image_ixor),
    ("__ior__", Image_ior),
)
for _op_name, _op_impl in _IMAGE_OPERATORS:
    setattr(Image, _op_name, _op_impl)
# keep the module namespace unchanged by removing the temporaries
del _IMAGE_OPERATORS, _op_name, _op_impl