# Source code for jax_galsim.image

import galsim as _galsim
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class

from jax_galsim.bounds import Bounds, BoundsD, BoundsI
from jax_galsim.core.utils import (
    cast_numpy_array_to_native_byte_order,
    ensure_hashable,
    has_tracers,
    implements,
)
from jax_galsim.errors import GalSimImmutableError
from jax_galsim.position import PositionI
from jax_galsim.utilities import parse_pos_args
from jax_galsim.wcs import BaseWCS, PixelScale

IMAGE_LAX_DOCS = """\
Contrary to GalSim native Image, this implementation does not support
sharing of the underlying array between different Images or views.
This is due to the fact that in JAX numpy arrays are immutable, so any
operation applied to this Image will create a new ``jnp.ndarray``. Making
a view via ``.view()`` will raise an error. Instead, use the ``.copy()``
method. The ``Image.subImage()`` method will return a copy.
"""


@implements(
    _galsim.Image,
    lax_description=IMAGE_LAX_DOCS,
)
@register_pytree_node_class
class Image(object):
    # Maps user-facing dtype spellings (Python builtins, numpy scalar types)
    # onto the JAX dtypes an Image stores.
    _alias_dtypes = {
        int: jnp.int32,  # So that user gets what they would expect
        float: jnp.float64,  # if using dtype=int or float or complex
        complex: jnp.complex128,
        jnp.int64: jnp.int32,  # Not equivalent, but will convert
        np.uint16: jnp.uint16,
        np.uint32: jnp.uint32,
        np.int16: jnp.int16,
        np.int32: jnp.int32,
        np.float32: jnp.float32,
        np.float64: jnp.float64,
        np.complex64: jnp.complex64,
        np.complex128: jnp.complex128,
    }
    # dtypes accepted after alias resolution.
    _valid_dtypes = [
        jnp.int32,
        jnp.float64,
        jnp.uint16,
        jnp.uint32,
        jnp.int16,
        jnp.float32,
        jnp.complex64,
        jnp.complex128,
    ]

    valid_dtypes = _valid_dtypes

    def __init__(self, *args, **kwargs):
        # Accepted call forms: Image(ncol, nrow), Image(array), Image(bounds),
        # Image(image), or the equivalent keywords (ncol/nrow, array, bounds,
        # image), plus dtype, init_value, scale, wcs, make_const, xmin, ymin.

        # this one is specific to jax-galsim and is used to disable bounds checking
        # we use an underscore to denote that it is a private argument
        _check_bounds = kwargs.pop("_check_bounds", True)

        # Parse the args, kwargs
        ncol = None
        nrow = None
        bounds = None
        array = None
        image = None

        if len(args) > 2:
            raise TypeError("Error, too many unnamed arguments to Image constructor")
        elif len(args) == 2:
            ncol = args[0]
            nrow = args[1]
            xmin = kwargs.pop("xmin", 1)
            ymin = kwargs.pop("ymin", 1)
        elif len(args) == 1:
            if isinstance(args[0], np.ndarray):
                array = jnp.array(cast_numpy_array_to_native_byte_order(args[0]))
                array, xmin, ymin = self._get_xmin_ymin(
                    array, kwargs, check_bounds=_check_bounds
                )
            elif isinstance(args[0], jnp.ndarray):
                array = args[0]
                array, xmin, ymin = self._get_xmin_ymin(
                    array, kwargs, check_bounds=_check_bounds
                )
            elif isinstance(args[0], BoundsI):
                bounds = args[0]
            elif isinstance(args[0], (list, tuple)):
                array = jnp.array(args[0])
                array, xmin, ymin = self._get_xmin_ymin(
                    array, kwargs, check_bounds=_check_bounds
                )
            elif isinstance(args[0], Image):
                image = args[0]
            else:
                raise TypeError(
                    "Unable to parse %s as an array, bounds, or image." % args[0]
                )
        else:
            if "array" in kwargs:
                array = kwargs.pop("array")
                array, xmin, ymin = self._get_xmin_ymin(
                    array, kwargs, check_bounds=_check_bounds
                )
            elif "bounds" in kwargs:
                bounds = kwargs.pop("bounds")
            elif "image" in kwargs:
                image = kwargs.pop("image")
            else:
                ncol = kwargs.pop("ncol", None)
                nrow = kwargs.pop("nrow", None)
                xmin = kwargs.pop("xmin", 1)
                ymin = kwargs.pop("ymin", 1)

        # Pop off the other valid kwargs:
        dtype = kwargs.pop("dtype", None)
        init_value = kwargs.pop("init_value", None)
        scale = kwargs.pop("scale", None)
        wcs = kwargs.pop("wcs", None)
        self._is_const = kwargs.pop("make_const", False)

        # Check that we got them all
        if kwargs:
            if "copy" in kwargs.keys() and not kwargs["copy"]:
                raise TypeError(
                    "'copy=False' is not a valid keyword argument for the JAX-GalSim version of the Image constructor"
                )
            else:
                # remove it since we used it
                kwargs.pop("copy", None)
            if kwargs:
                # Format the message here; passing kwargs as a separate exception
                # argument would leave a literal "%s" in the message.
                raise TypeError(
                    "Image constructor got unexpected keyword arguments: %s"
                    % (kwargs,)
                )

        # Figure out what dtype we want:
        dtype = self._alias_dtypes.get(dtype, dtype)
        if dtype is not None and dtype not in self._valid_dtypes:
            raise _galsim.GalSimValueError("Invalid dtype.", dtype, self._valid_dtypes)
        if array is not None:
            if dtype is None:
                # No dtype requested: take it from the array, mapping aliases.
                dtype = array.dtype.type
                if dtype in self._alias_dtypes:
                    dtype = self._alias_dtypes[dtype]
                    # jax-galsim's rounding of float-to-int is platform dependent
                    # so we explicitly round to ints if needed
                    array = _safe_cast(
                        array, jnp.issubdtype(dtype, jnp.integer), dtype
                    )
                elif dtype not in self._valid_dtypes:
                    raise _galsim.GalSimValueError(
                        "Invalid dtype of provided array.",
                        array.dtype,
                        self._valid_dtypes,
                    )
            else:
                # An explicit dtype always forces a cast of the provided array.
                # jax-galsim's rounding of float-to-int is platform dependent
                # so we explicitly round to ints if needed
                array = _safe_cast(array, jnp.issubdtype(dtype, jnp.integer), dtype)
            # Be careful here: we have to watch out for little-endian / big-endian issues.
            # The path of least resistance is to check whether the array.dtype is equal to the
            # native one (using the dtype.isnative flag), and if not, make a new array that has a
            # type equal to the same one but with the appropriate endian-ness.
            if not array.dtype.isnative:
                array = array.astype(array.dtype.newbyteorder("="))
            self._dtype = array.dtype.type
        elif image is not None:
            if not isinstance(image, Image):
                raise TypeError("image must be an Image")
            # we do less checking here since we already have a valid image
            if dtype is None:
                self._dtype = image.dtype
            else:
                self._dtype = dtype
        elif dtype is not None:
            self._dtype = dtype
        else:
            self._dtype = jnp.float32

        # Construct the image attribute
        if ncol is not None or nrow is not None:
            if ncol is None or nrow is None:
                raise _galsim.GalSimIncompatibleValuesError(
                    "Both nrow and ncol must be provided", ncol=ncol, nrow=nrow
                )
            if ncol != int(ncol) or nrow != int(nrow):
                raise TypeError("nrow, ncol must be integers")
            ncol = int(ncol)
            nrow = int(nrow)
            self._array = self._make_empty(shape=(nrow, ncol), dtype=self._dtype)
            # Concrete origins allow static bounds; traced ones cannot be static.
            if not has_tracers(xmin) and not has_tracers(ymin):
                self._bounds = BoundsI(
                    xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow, static=True
                )
            else:
                self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow)
            # ``is not None`` (not truthiness) so a traced or array-valued
            # init_value does not have to be concretized; adding 0 is a no-op.
            if init_value is not None:
                self._array = self._array.at[...].add(init_value)
        elif bounds is not None:
            if not isinstance(bounds, BoundsI):
                raise TypeError("bounds must be a galsim.BoundsI instance")
            self._array = self._make_empty(bounds.numpyShape(), dtype=self._dtype)
            self._bounds = bounds
            if init_value is not None:
                self._array = self._array.at[...].add(init_value)
        elif array is not None:
            self._array = array.view()
            nrow, ncol = array.shape
            if not has_tracers(xmin) and not has_tracers(ymin):
                self._bounds = BoundsI(
                    xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow, static=True
                )
            else:
                self._bounds = BoundsI(xmin=xmin, deltax=ncol, ymin=ymin, deltay=nrow)
            if init_value is not None:
                raise _galsim.GalSimIncompatibleValuesError(
                    "Cannot specify init_value with array",
                    init_value=init_value,
                    array=array,
                )
        elif image is not None:
            if not isinstance(image, Image):
                raise TypeError("image must be an Image")
            if init_value is not None:
                raise _galsim.GalSimIncompatibleValuesError(
                    "Cannot specify init_value with image",
                    init_value=init_value,
                    image=image,
                )
            if wcs is None and scale is None:
                wcs = image.wcs
            self._bounds = image.bounds
            if dtype is None:
                self._dtype = image.dtype
            else:
                # Allow dtype to force a retyping of the provided image
                # e.g. im = ImageF(...)
                #      im2 = ImageD(im)
                self._dtype = dtype
            # jax-galsim's rounding of float-to-int is platform dependent
            # so we explicitly round to ints if needed
            self._array = _safe_cast(
                image.array, jnp.issubdtype(self._dtype, jnp.integer), self._dtype
            )
        else:
            # No size information at all: an undefined image backed by a 1x1 array.
            self._array = jnp.zeros(shape=(1, 1), dtype=self._dtype)
            self._bounds = BoundsI()
            if init_value is not None:
                raise _galsim.GalSimIncompatibleValuesError(
                    "Cannot specify init_value without setting an initial size",
                    init_value=init_value,
                    ncol=ncol,
                    nrow=nrow,
                    bounds=bounds,
                )

        # Construct the wcs attribute
        if scale is not None:
            if wcs is not None:
                raise _galsim.GalSimIncompatibleValuesError(
                    "Cannot provide both scale and wcs to Image constructor",
                    wcs=wcs,
                    scale=scale,
                )
            self.wcs = PixelScale(float(scale))
        else:
            if wcs is not None and not isinstance(wcs, BaseWCS):
                raise TypeError("wcs parameters must be a galsim.BaseWCS instance")
            self.wcs = wcs

    @staticmethod
    def _get_xmin_ymin(array, kwargs, check_bounds=True):
        """A helper function for parsing xmin, ymin, bounds options with a given array.

        Pops ``xmin``/``ymin``/``bounds`` from ``kwargs`` (mutating it) and
        returns ``(array, xmin, ymin)``. If ``bounds`` is undefined or the array
        has zero columns, ``(None, None, None)`` is returned and the array's
        dtype is stored in ``kwargs["dtype"]`` unless one was given.
        """
        if not isinstance(array, (np.ndarray, jnp.ndarray)):
            raise TypeError("array must be a ndarray instance")
        xmin = kwargs.pop("xmin", 1)
        ymin = kwargs.pop("ymin", 1)
        if "bounds" in kwargs:
            b = kwargs.pop("bounds")
            if not isinstance(b, BoundsI):
                raise TypeError("bounds must be a galsim.BoundsI instance")
            if (
                check_bounds
                and b.isDefined()
                and not has_tracers(b.xmin)
                and not has_tracers(b.ymin)
                and not has_tracers(b.xmax)
                and not has_tracers(b.ymax)
            ):
                # We need to disable this when jitting
                if b.xmax - b.xmin + 1 != array.shape[1]:
                    raise _galsim.GalSimIncompatibleValuesError(
                        "Shape of array is inconsistent with provided bounds",
                        array=array,
                        bounds=b,
                    )
                if b.ymax - b.ymin + 1 != array.shape[0]:
                    raise _galsim.GalSimIncompatibleValuesError(
                        "Shape of array is inconsistent with provided bounds",
                        array=array,
                        bounds=b,
                    )
            if b.isDefined():
                xmin = b.xmin
                ymin = b.ymin
            else:
                # Indication that array is formally undefined, even though provided.
                if "dtype" not in kwargs:
                    kwargs["dtype"] = array.dtype.type
                array = None
                xmin = None
                ymin = None
        elif array.shape[1] == 0:
            # Another way to indicate that we don't have a defined image.
            if "dtype" not in kwargs:
                kwargs["dtype"] = array.dtype.type
            array = None
            xmin = None
            ymin = None
        return array, xmin, ymin

    def __repr__(self):
        s = "galsim.Image(bounds=%r" % self.bounds
        if self.bounds.isDefined():
            s += ", array=\n%r" % (ensure_hashable(np.array(self.array)),)
        s += ", wcs=%r" % self.wcs
        if self.isconst:
            s += ", make_const=True"
        s += ")"
        return s

    def __str__(self):
        # Get the type name without the <type '...'> part.
        t = str(self.dtype).split("'")[1]
        if self.wcs is not None and self.wcs._isPixelScale:
            return "galsim.Image(bounds=%s, scale=%s, dtype=%s)" % (
                self.bounds,
                ensure_hashable(self.scale),
                t,
            )
        else:
            return "galsim.Image(bounds=%s, wcs=%s, dtype=%s)" % (
                self.bounds,
                self.wcs,
                t,
            )

    # Not immutable object.  So shouldn't be used as a hash.
    __hash__ = None

    # Read-only attributes:
    @property
    @implements(_galsim.Image.dtype)
    def dtype(self):
        return self._dtype

    @property
    @implements(_galsim.Image.bounds)
    def bounds(self):
        return self._bounds

    @property
    @implements(_galsim.Image.array)
    def array(self):
        return self._array

    @array.setter
    def array(self, other):
        # jax-galsim's rounding of float-to-int is platform dependent
        # so we explicitly round to ints if needed
        self._array = self._array.at[...].set(
            _safe_cast(other, self.isinteger, self.array.dtype)
        )

    @property
    @implements(_galsim.Image.nrow)
    def nrow(self):
        return self._array.shape[0]

    @property
    @implements(_galsim.Image.ncol)
    def ncol(self):
        return self._array.shape[1]

    @property
    @implements(_galsim.Image.isconst)
    def isconst(self):
        return self._is_const

    @property
    @implements(_galsim.Image.iscomplex)
    def iscomplex(self):
        return self._array.dtype.kind == "c"

    @property
    @implements(_galsim.Image.isinteger)
    def isinteger(self):
        return self._array.dtype.kind in ("i", "u")

    @property
    @implements(
        _galsim.Image.iscontiguous, lax_description="In JAX all arrays are contiguous."
    )
    def iscontiguous(self):
        return True  # In JAX all arrays are contiguous (almost)

    # Allow scale to work as a PixelScale wcs.
    @property
    @implements(_galsim.Image.scale)
    def scale(self):
        try:
            return self.wcs.scale
        except Exception:
            if self.wcs:
                # The wcs exists but has no simple scale; hide the internal
                # lookup error behind the user-facing one.
                raise _galsim.GalSimError(
                    "image.wcs is not a simple PixelScale; scale is undefined."
                ) from None
            else:
                return None

    @scale.setter
    def scale(self, value):
        if self.wcs is not None and not self.wcs._isPixelScale:
            raise _galsim.GalSimError(
                "image.wcs is not a simple PixelScale; scale is undefined."
            )
        else:
            self.wcs = PixelScale(value)

    # Convenience functions
    @property
    @implements(_galsim.Image.xmin)
    def xmin(self):
        return self._bounds.xmin

    @property
    @implements(_galsim.Image.xmax)
    def xmax(self):
        return self._bounds.xmax

    @property
    @implements(_galsim.Image.ymin)
    def ymin(self):
        return self._bounds.ymin

    @property
    @implements(_galsim.Image.ymax)
    def ymax(self):
        return self._bounds.ymax

    @property
    @implements(_galsim.Image.outer_bounds)
    def outer_bounds(self):
        # Pixel-edge bounds: half a pixel beyond the pixel centers on each side.
        return BoundsD(
            self.xmin - 0.5, self.xmax + 0.5, self.ymin - 0.5, self.ymax + 0.5
        )

    # real, imag for everything, even real images.
    @property
    @implements(_galsim.Image.real)
    def real(self):
        return self.__class__(
            self.array.real, bounds=self.bounds, wcs=self.wcs, make_const=self._is_const
        )

    @property
    @implements(_galsim.Image.imag)
    def imag(self):
        return self.__class__(
            self.array.imag,
            bounds=self.bounds,
            wcs=self.wcs,
            # for real images, the imaginary part is always zero and immutable
            make_const=self._is_const or (not self.iscomplex),
        )

    @property
    @implements(_galsim.Image.conjugate)
    def conjugate(self):
        return self.__class__(self.array.conjugate(), bounds=self.bounds, wcs=self.wcs)
[docs] @implements( _galsim.Image.copy, lax_description=( "JAX-GalSim supports extra keyword arguments to ``.copy`` so " "that users can make copies of images while also changing the image " "properties (e.g., the wcs). The extra keywords behave exactly like " "those of ``Image.view``." ), ) def copy( self, scale=None, wcs=None, origin=None, center=None, dtype=None, make_const=False, contiguous=False, ): if origin is not None and center is not None: raise _galsim.GalSimIncompatibleValuesError( "Cannot provide both center and origin", center=center, origin=origin ) if scale is not None: if wcs is not None: raise _galsim.GalSimIncompatibleValuesError( "Cannot provide both scale and wcs", scale=scale, wcs=wcs ) wcs = PixelScale(scale) elif wcs is not None: if not isinstance(wcs, BaseWCS): raise TypeError("wcs parameters must be a galsim.BaseWCS instance") else: wcs = self.wcs # Figure out the dtype for the return Image dtype = dtype if dtype else self.dtype # If currently empty, just return a new empty image. if not self.bounds.isDefined(): return Image(wcs=wcs, dtype=dtype, make_const=make_const) # Recast the array type if necessary array = self.array.copy() if dtype != array.dtype: # jax-galsim's rounding of float-to-int is platform dependent # so we explicitly round to ints if needed array = _safe_cast(array, jnp.issubdtype(dtype, jnp.integer), dtype) elif contiguous: # this is a noop since all jax arrays are contiguous pass else: # do nothing here since we made copy above pass # Make the return Image - already made copy above ret = self.__class__(array, bounds=self.bounds, wcs=wcs, make_const=make_const) # Update the origin if requested if origin is not None: ret.setOrigin(origin) elif center is not None: ret.setCenter(center) return ret
[docs] @implements(_galsim.Image.get_pixel_centers) def get_pixel_centers(self): x, y = jnp.meshgrid( jnp.arange(self.array.shape[1], dtype=float), jnp.arange(self.array.shape[0], dtype=float), ) x += self.bounds.xmin y += self.bounds.ymin return x, y
def _make_empty(self, shape, dtype): """Helper function to make an empty numpy array of the given shape.""" if np.prod(shape) == 0: # galsim forces degenerate images to have at least 1 pixel return jnp.zeros(shape=(1, 1), dtype=dtype) else: return jnp.zeros(shape=shape, dtype=dtype)
[docs] @implements(_galsim.Image.resize) def resize(self, bounds, wcs=None): if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") self._array = self._make_empty(shape=bounds.numpyShape(), dtype=self.dtype) self._bounds = bounds if wcs is not None: self.wcs = wcs
[docs] @implements(_galsim.Image.subImage) def subImage(self, bounds): if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access subImage of undefined image" ) if ( not has_tracers(self.bounds.xmin) and not has_tracers(self.bounds.xmax) and not has_tracers(self.bounds.ymin) and not has_tracers(self.bounds.ymax) and not has_tracers(bounds.xmin) and not has_tracers(bounds.xmax) and not has_tracers(bounds.ymin) and not has_tracers(bounds.ymax) and not self.bounds.includes(bounds) ): raise _galsim.GalSimBoundsError( "Attempt to access subImage not (fully) in image", bounds, self.bounds ) if self.bounds.isStatic() and bounds.isStatic(): i1 = bounds.ymin - self.ymin i2 = bounds.ymax - self.ymin + 1 j1 = bounds.xmin - self.xmin j2 = bounds.xmax - self.xmin + 1 subarray = self.array[i1:i2, j1:j2] else: start_inds = ( bounds.ymin - self.ymin, bounds.xmin - self.xmin, ) shape = bounds.numpyShape() subarray = jax.lax.dynamic_slice(self.array, start_inds, shape) # NB. The wcs is still accurate, since the sub-image uses the same (x,y) values # as the original image did for those pixels. It's only once you recenter or # reorigin that you need to update the wcs. So that's taken care of in im.shift. return self.__class__(subarray, bounds=bounds, wcs=self.wcs)
[docs] @implements(_galsim.Image.setSubImage) def setSubImage(self, bounds, rhs): if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" ) if ( not has_tracers(self.bounds.xmin) and not has_tracers(self.bounds.xmax) and not has_tracers(self.bounds.ymin) and not has_tracers(self.bounds.ymax) and not has_tracers(bounds.xmin) and not has_tracers(bounds.xmax) and not has_tracers(bounds.ymin) and not has_tracers(bounds.ymax) and not self.bounds.includes(bounds) ): raise _galsim.GalSimBoundsError( "Attempt to access subImage not (fully) in image", bounds, self.bounds ) if not isinstance(rhs, Image): raise TypeError("Trying to copyFrom a non-image") if bounds.numpyShape() != rhs.bounds.numpyShape(): raise _galsim.GalSimIncompatibleValuesError( "Trying to copy images that are not the same shape", self_image=self, rhs=rhs, ) if self.bounds.isStatic() and bounds.isStatic(): i1 = bounds.ymin - self.ymin i2 = bounds.ymax - self.ymin + 1 j1 = bounds.xmin - self.xmin j2 = bounds.xmax - self.xmin + 1 # jax-galsim's rounding of float-to-int is platform dependent # so we explicitly round to ints if needed self._array = self._array.at[i1:i2, j1:j2].set( _safe_cast( rhs.array, jnp.issubdtype(self.dtype, jnp.integer), self.dtype ) ) else: start_inds = ( bounds.ymin - self.ymin, bounds.xmin - self.xmin, ) self._array = jax.lax.dynamic_update_slice( self.array, # jax-galsim's rounding of float-to-int is platform dependent # so we explicitly round to ints if needed _safe_cast( rhs.array, jnp.issubdtype(self.dtype, jnp.integer), self.dtype ), start_inds, )
def __getitem__(self, *args): """Return either a subimage or a single pixel value. For example,:: >>> subimage = im[galsim.BoundsI(3,7,3,7)] >>> value = im[galsim.PositionI(5,5)] >>> value = im[5,5] """ if len(args) == 1: if isinstance(args[0], BoundsI): return self.subImage(*args) elif isinstance(args[0], PositionI): return self(*args) elif isinstance(args[0], tuple): return self.getValue(*args[0]) else: raise TypeError( "image[index] only accepts BoundsI or PositionI for the index" ) elif len(args) == 2: return self(*args) else: raise TypeError("image[..] requires either 1 or 2 args") def __setitem__(self, *args): """Set either a subimage or a single pixel to new values. For example,:: >>> im[galsim.BoundsI(3,7,3,7)] = im2 >>> im[galsim.PositionI(5,5)] = 17. >>> im[5,5] = 17. """ if len(args) == 2: if isinstance(args[0], BoundsI): self.setSubImage(*args) elif isinstance(args[0], PositionI): self.setValue(*args) elif isinstance(args[0], tuple): self.setValue(*args) else: raise TypeError( "image[index] only accepts BoundsI or PositionI for the index" ) elif len(args) == 3: return self.setValue(*args) else: raise TypeError("image[..] requires either 1 or 2 args")
[docs] @implements(_galsim.Image.wrap) def wrap(self, bounds, hermitian=False): if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") # Get this at the start to check for invalid bounds and raise the exception before # possibly writing data past the edge of the image. if not hermitian: return self._wrap(bounds, False, False, None) elif hermitian == "x": if not has_tracers(self.bounds.xmin) and self.bounds.xmin != 0: raise _galsim.GalSimIncompatibleValuesError( "hermitian == 'x' requires self.bounds.xmin == 0", hermitian=hermitian, bounds=self.bounds, ) if not has_tracers(bounds.xmin) and bounds.xmin != 0: raise _galsim.GalSimIncompatibleValuesError( "hermitian == 'x' requires bounds.xmin == 0", hermitian=hermitian, bounds=bounds, ) return self._wrap(bounds, True, False, 2 * bounds.xmax) elif hermitian == "y": if not has_tracers(self.bounds.ymin) and self.bounds.ymin != 0: raise _galsim.GalSimIncompatibleValuesError( "hermitian == 'y' requires self.bounds.ymin == 0", hermitian=hermitian, bounds=self.bounds, ) if not has_tracers(bounds.ymin) and bounds.ymin != 0: raise _galsim.GalSimIncompatibleValuesError( "hermitian == 'y' requires bounds.ymin == 0", hermitian=hermitian, bounds=bounds, ) return self._wrap(bounds, False, True, 2 * bounds.ymax) else: raise _galsim.GalSimValueError( "Invalid value for hermitian", hermitian, (False, "x", "y") )
@implements(_galsim.Image._wrap) def _wrap(self, bounds, hermx, hermy, hermitian_wrap_size): if not hermx and not hermy: from jax_galsim.core.wrap_image import wrap_nonhermitian self._array = self._array.at[...].set( wrap_nonhermitian( self._array, # zero indexed location of subimage bounds.xmin - self.xmin, bounds.ymin - self.ymin, bounds.deltax, bounds.deltay, ) ) elif hermx and not hermy: from jax_galsim.core.wrap_image import wrap_hermitian_x self._array = self._array.at[...].set( wrap_hermitian_x( self._array, -self.xmax, self.ymin, -bounds.xmax + 1, bounds.ymin, hermitian_wrap_size, bounds.deltay, ) ) elif not hermx and hermy: from jax_galsim.core.wrap_image import wrap_hermitian_y self._array = self._array.at[...].set( wrap_hermitian_y( self._array, self.xmin, -self.ymax, bounds.xmin, -bounds.ymax + 1, bounds.deltax, hermitian_wrap_size, ) ) return self.subImage(bounds)
[docs] @implements( _galsim.Image.calculate_fft, lax_description="JAX-GalSim does not support forward FFTs of complex dtypes.", ) def calculate_fft(self): if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "calculate_fft requires that the image have defined bounds." ) if self.wcs is None: raise _galsim.GalSimError("calculate_fft requires that the scale be set.") if not self.wcs._isPixelScale: raise _galsim.GalSimError( "calculate_fft requires that the image has a PixelScale wcs." ) if self.dtype in [np.complex64, np.complex128, complex]: raise _galsim.GalSimNotImplementedError( "JAX-GalSim does not support forward FFTs of complex dtypes." ) # TODO: figure out how to do FFT at fixed size and then reconstruct # the result No2 = max( max( -self.bounds.xmin, self.bounds.xmax + 1, ), max( -self.bounds.ymin, self.bounds.ymax + 1, ), ) full_bounds = BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2) if self.bounds == full_bounds: # Then the image is already in the shape we need. ximage = self else: # Then we pad out with zeros ximage = Image(full_bounds, dtype=self.dtype, init_value=0) ximage[self.bounds] = self[self.bounds] dx = self.scale # dk = 2pi / (N dk) dk = jnp.pi / (No2 * dx) out = Image( BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2), dtype=np.complex128, scale=dk, ) # we shift the image before and after the FFT to match the layout of the modes # used by GalSim out._array = out._array.at[...].set( jnp.fft.fftshift(jnp.fft.rfft2(jnp.fft.fftshift(ximage.array)), axes=0) ) out *= dx * dx out.setOrigin(0, -No2) return out
[docs] @implements(_galsim.Image.calculate_inverse_fft) def calculate_inverse_fft(self): if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "calculate_fft requires that the image have defined bounds." ) if self.wcs is None: raise _galsim.GalSimError( "calculate_inverse_fft requires that the scale be set." ) if not self.wcs._isPixelScale: raise _galsim.GalSimError( "calculate_inverse_fft requires that the image has a PixelScale wcs." ) if not self.bounds.includes(0, 0): raise _galsim.GalSimBoundsError( "calculate_inverse_fft requires that the image includes (0,0)", PositionI(0, 0), self.bounds, ) No2 = max( max(self.bounds.xmax, -self.bounds.ymin), self.bounds.ymax, ) target_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2) if self.bounds == target_bounds: # Then the image is already in the shape we need. kimage = self else: # Then we can pad out with zeros and wrap to get this in the form we need. full_bounds = BoundsI(xmin=0, deltax=No2 + 1, ymin=-No2, deltay=2 * No2 + 1) kimage = Image(full_bounds, dtype=self.dtype, init_value=0) posx_bounds = BoundsI( xmin=0, xmax=self.bounds.xmax, ymin=self.bounds.ymin, ymax=self.bounds.ymax, ) kimage[posx_bounds] = self[posx_bounds] kimage = kimage._wrap(target_bounds, True, False, 2 * No2) dk = self.scale # dx = 2pi / (N dk) dx = jnp.pi / (No2 * dk) # In GalSim, they use inplace FFTW transforms which require the # array that holds the input/output to have extra padding on the # x dimension. # jax-galsim does not need the padding since it does not use an # inplace FFT. Thus we do not use the # padding. out = Image( bounds=BoundsI(xmin=-No2, deltax=2 * No2, ymin=-No2, deltay=2 * No2), dtype=float, scale=dx, # we shift the image before and after the FFT to match the layout used by galsim array=jnp.fft.fftshift( jnp.fft.irfft2(jnp.fft.fftshift(kimage.array, axes=0)) ) * (dk * No2 / jnp.pi) ** 2, ) out.setCenter(0, 0) return out
[docs] @classmethod @implements(_galsim.Image.good_fft_size) def good_fft_size(cls, input_size): # we use the math module here since this function should not be jitted. import math # Reference from GalSim C++ # https://github.com/GalSim-developers/GalSim/blob/ece3bd32c1ae6ed771f2b489c5ab1b25729e0ea4/src/Image.cpp#L1009 # Reduce slightly to eliminate potential rounding errors: insize = (1.0 - 1.0e-5) * input_size log2n = math.log(2.0) * math.ceil(math.log(insize) / math.log(2.0)) log2n3 = math.log(3.0) + math.log(2.0) * math.ceil( (math.log(insize) - math.log(3.0)) / math.log(2.0) ) log2n3 = max(log2n3, math.log(6.0)) # must be even number Nk = max(int(math.ceil(math.exp(min(log2n, log2n3)) - 1.0e-5)), 2) return Nk
[docs] @implements(_galsim.Image.copyFrom) def copyFrom(self, rhs): if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(rhs, Image): raise TypeError("Trying to copyFrom a non-image") if self.bounds.numpyShape() != rhs.bounds.numpyShape(): raise _galsim.GalSimIncompatibleValuesError( "Trying to copy images that are not the same shape", self_image=self, rhs=rhs, ) self._copyFrom(rhs)
def _copyFrom(self, rhs): """Same as copyFrom, but no sanity checks.""" self._array = self._array.at[...].set( # jax-galsim's rounding of float-to-int is platform dependent # so we explicitly round to ints if needed _safe_cast(rhs._array, self.isinteger, self.array.dtype) )
[docs] @implements( _galsim.Image.view, lax_description=( "JAX-GalSim does not support image views. This " "method will raise an error if called." ), ) def view( self, scale=None, wcs=None, origin=None, center=None, dtype=None, make_const=False, contiguous=False, ): raise NotImplementedError( "JAX-GalSim does not support views of images! Use ``.copy`` instead." )
[docs] @implements(_galsim.Image.shift) def shift(self, *args, **kwargs): delta = parse_pos_args(args, kwargs, "dx", "dy", integer=True) self._shift(delta)
@implements(_galsim.Image._shift) def _shift(self, delta): self._bounds = self._bounds.shift(delta) if self.wcs is not None: self.wcs = self.wcs.shiftOrigin(delta)
[docs] @implements(_galsim.Image.setCenter) def setCenter(self, *args, **kwargs): cen = parse_pos_args(args, kwargs, "xcen", "ycen", integer=True) self._shift(cen - self.center)
[docs] @implements(_galsim.Image.setOrigin) def setOrigin(self, *args, **kwargs): origin = parse_pos_args(args, kwargs, "x0", "y0", integer=True) self._shift(origin - self.origin)
@property @implements(_galsim.Image.center) def center(self): return self.bounds.center @property @implements(_galsim.Image.true_center) def true_center(self): return self.bounds.true_center @property @implements(_galsim.Image.origin) def origin(self): return self.bounds.origin def __call__(self, *args, **kwargs): """Get the pixel value at given position The arguments here may be either (x, y) or a PositionI instance. Or you can provide x, y as named kwargs. """ pos = parse_pos_args(args, kwargs, "x", "y", integer=True) return self.getValue(pos.x, pos.y)
[docs] @implements(_galsim.Image.getValue) def getValue(self, x, y): if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" ) if not self.bounds.includes(x, y): raise _galsim.GalSimBoundsError( "Attempt to access position not in bounds of image.", PositionI(x, y), self.bounds, ) return self._getValue(x, y)
@implements(_galsim.Image._getValue) def _getValue(self, x, y): return self.array[y - self.ymin, x - self.xmin]
[docs] @implements(_galsim.Image.setValue) def setValue(self, *args, **kwargs): if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to set value of an undefined image" ) pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) if not self.bounds.includes(pos): raise _galsim.GalSimBoundsError( "Attempt to set position not in bounds of image", pos, self.bounds ) self._setValue(pos.x, pos.y, value)
@implements(_galsim.Image._setValue) def _setValue(self, x, y, value): self._array = self._array.at[y - self.ymin, x - self.xmin].set(value)
[docs] @implements(_galsim.Image.addValue) def addValue(self, *args, **kwargs): if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to set value of an undefined image" ) pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) if not self.bounds.includes(pos): raise _galsim.GalSimBoundsError( "Attempt to set position not in bounds of image", pos, self.bounds ) self._addValue(pos.x, pos.y, value)
@implements(_galsim.Image._addValue) def _addValue(self, x, y, value): self._array = self._array.at[y - self.ymin, x - self.xmin].add(value)
[docs] @implements(_galsim.Image.fill) def fill(self, value): if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to set values of an undefined image" ) self._fill(value)
@implements(_galsim.Image._fill) def _fill(self, value): self._array = self._array.at[...].set(value)
[docs] @implements(_galsim.Image.setZero) def setZero(self): if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) self._fill(0)
[docs] @implements(_galsim.Image.invertSelf) def invertSelf(self): if self.isconst: raise GalSimImmutableError("Cannot modify an immutable Image", self) if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to set values of an undefined image" ) self._invertSelf()
@implements(_galsim.Image._invertSelf)
def _invertSelf(self):
    # Pixelwise 1/x, with pixels equal to 0 left at 0.
    current = self._array
    is_zero = current == 0
    # Divide by 1 where the pixel is zero so no inf/nan is ever produced;
    # those entries are then overwritten with 0.
    denominator = jnp.where(is_zero, 1.0, current)
    inverted = jnp.where(is_zero, 0.0, 1.0 / denominator)
    target_dtype = current.dtype
    # jax-galsim's rounding of float-to-int is platform dependent
    # so we explicitly round to ints if needed
    as_target = _safe_cast(
        inverted, jnp.issubdtype(target_dtype, jnp.integer), target_dtype
    )
    self._array = current.at[...].set(as_target)
@implements(_galsim.Image.replaceNegative)
def replaceNegative(self, replace_value=0):
    # Replace every negative pixel with ``replace_value``.
    #
    # Raises GalSimImmutableError for a const image.
    #
    # Boolean-mask indexing (``arr.at[arr < 0].set(...)``) needs a concrete
    # mask and fails when the image is traced under ``jax.jit``; ``jnp.where``
    # is trace-safe and gives the same values.
    if self.isconst:
        raise GalSimImmutableError("Cannot modify an immutable Image", self)
    current = self.array
    replaced = jnp.where(current < 0, replace_value, current)
    # jnp.where may promote (e.g. an integer image with a float replace_value);
    # cast back so the image keeps its dtype, as the indexed update did.
    self._array = replaced.astype(current.dtype)
def __eq__(self, other):
    # Two images are equal when they have the same bounds, WCS, const-ness
    # and (for defined bounds) the same pixel values.
    #
    # Note that numpy.array_equal can return True if the dtypes of the two arrays involved are
    # different, as long as the contents of the two arrays are logically the same. For example:
    #
    # >>> double_array = np.arange(1024).reshape(32, 32)*np.pi
    # >>> int_array = np.arange(1024).reshape(32, 32)
    # >>> assert galsim.ImageD(int_array) == galsim.ImageF(int_array) # passes
    # >>> assert galsim.ImageD(double_array) == galsim.ImageF(double_array) # fails
    #
    # ``jnp.array_equal`` returns a JAX boolean array and an ``and``/``or``
    # chain returns the deciding operand unchanged, so without ``bool(...)``
    # a pixel mismatch would make ``==`` return a ``jax.Array`` instead of a
    # Python ``bool``.
    return bool(
        self is other
        or (
            isinstance(other, Image)
            and self.bounds == other.bounds
            and self.wcs == other.wcs
            and (
                not self.bounds.isDefined()
                or jnp.array_equal(self.array, other.array)
            )
            and self.isconst == other.isconst
        )
    )


def __ne__(self, other):
    # Logical negation of __eq__.
    return not self.__eq__(other)
@implements(_galsim.Image.transpose)
def transpose(self):
    # Swap the x and y axes of the pixel data. The bounds are swapped to
    # match, and the WCS is not carried over (None is passed).
    old = self.bounds
    swapped = old.__class__(
        xmin=self.ymin,
        deltax=old.deltay,
        ymin=self.xmin,
        deltay=old.deltax,
    )
    transposed = self.array.T
    return _Image(transposed, swapped, None)
@implements(_galsim.Image.flip_lr)
def flip_lr(self):
    # Reverse the column (x) axis; bounds are kept and the WCS is dropped.
    mirrored = jnp.flip(self.array, axis=1)
    return _Image(mirrored, self._bounds, None)
@implements(_galsim.Image.flip_ud)
def flip_ud(self):
    # Reverse the row (y) axis; bounds are kept and the WCS is dropped.
    mirrored = jnp.flip(self.array, axis=0)
    return _Image(mirrored, self._bounds, None)
@implements(_galsim.Image.rot_cw)
def rot_cw(self):
    # Transpose, then reverse the resulting row axis. Bounds swap x/y to
    # match the transposed shape; the WCS is dropped.
    old = self.bounds
    swapped = old.__class__(
        xmin=self.ymin,
        deltax=old.deltay,
        ymin=self.xmin,
        deltay=old.deltax,
    )
    rotated = jnp.flip(self.array.T, axis=0)
    return _Image(rotated, swapped, None)
@implements(_galsim.Image.rot_ccw)
def rot_ccw(self):
    # Transpose, then reverse the resulting column axis. Bounds swap x/y to
    # match the transposed shape; the WCS is dropped.
    old = self.bounds
    swapped = old.__class__(
        xmin=self.ymin,
        deltax=old.deltay,
        ymin=self.xmin,
        deltay=old.deltax,
    )
    rotated = jnp.flip(self.array.T, axis=1)
    return _Image(rotated, swapped, None)
@implements(_galsim.Image.rot_180)
def rot_180(self):
    # Reverse both axes; bounds are kept and the WCS is dropped.
    rotated = jnp.flip(self.array, axis=(0, 1))
    return _Image(rotated, self._bounds, None)
def tree_flatten(self):
    """Flatten the image into a list of values.

    Returns ``(children, aux_data)``. ``children`` is ``(array, wcs)``, plus
    ``bounds`` when the bounds are not static, followed by whichever of the
    optional attributes ``added_flux`` / ``photons`` are set. ``aux_data``
    holds dtype, const-ness, static bounds, an optional ``header``, and the
    names of the optional children (so tree_unflatten can restore them
    unambiguously).
    """
    # Define the children nodes of the PyTree that need tracing
    if self.bounds.isStatic():
        children = (self.array, self.wcs)
        aux_data = {
            "dtype": self.dtype,
            "bounds": self.bounds,
            "isconst": self.isconst,
        }
    else:
        children = (self.array, self.wcs, self.bounds)
        aux_data = {"dtype": self.dtype, "isconst": self.isconst}
    # other routines may add these attributes to images on the fly
    # we have to include them here so that JAX knows how to handle them in jitting etc.
    # Their names are recorded because either one may be present without the
    # other; decoding purely by position would confuse them.
    optional = tuple(
        name for name in ("added_flux", "photons") if hasattr(self, name)
    )
    children += tuple(getattr(self, name) for name in optional)
    aux_data["optional_children"] = optional
    if hasattr(self, "header"):
        aux_data["header"] = self.header
    return (children, aux_data)


@classmethod
def tree_unflatten(cls, aux_data, children):
    """Recreates an instance of the class from flatten representation.

    Inverse of ``tree_flatten``. The instance is created without calling
    ``__init__``.
    """
    obj = object.__new__(cls)
    obj._array = children[0]
    obj.wcs = children[1]
    if "bounds" in aux_data:
        # static bounds were stored in aux_data
        obj._bounds = aux_data["bounds"]
        rest = children[2:]
    else:
        # traced bounds were stored as the third child
        obj._bounds = children[2]
        rest = children[3:]
    obj._dtype = aux_data["dtype"]
    obj._is_const = aux_data["isconst"]
    if "header" in aux_data:
        obj.header = aux_data["header"]
    # Data flattened without recorded names falls back to the historical
    # positional order: added_flux first, then photons.
    names = aux_data.get(
        "optional_children", ("added_flux", "photons")[: len(rest)]
    )
    for name, value in zip(names, rest):
        setattr(obj, name, value)
    return obj
@classmethod
def from_galsim(cls, galsim_image):
    """Build a jax_galsim `Image` from an existing `galsim.Image`.

    The pixel array is converted to native byte order and wrapped as a JAX
    array. The WCS (when not None) and the bounds are converted with their
    ``from_galsim`` helpers, and a ``header`` attribute is copied if present.
    """
    source_wcs = galsim_image.wcs
    wcs = None if source_wcs is None else BaseWCS.from_galsim(source_wcs)
    native_pixels = cast_numpy_array_to_native_byte_order(galsim_image.array)
    converted = cls(
        array=jnp.asarray(native_pixels),
        wcs=wcs,
        bounds=Bounds.from_galsim(galsim_image.bounds),
    )
    if hasattr(galsim_image, "header"):
        converted.header = galsim_image.header
    return converted
def to_galsim(self):
    """Create a galsim `Image` from a `jax_galsim.Image` object.

    The pixel values are copied into a new, writeable NumPy array. The WCS
    (when not None) and the bounds are converted with their ``to_galsim``
    helpers, and a ``header`` attribute, if present, is carried over (the
    same attribute ``from_galsim`` copies in the other direction).
    """
    wcs = self.wcs.to_galsim() if self.wcs is not None else None
    # ``np.asarray`` on a ``jax.Array`` can hand back a read-only buffer;
    # an explicit copy lets GalSim modify the returned image in place.
    gs_image = _galsim.Image(
        np.array(self.array), bounds=self.bounds.to_galsim(), wcs=wcs
    )
    if hasattr(self, "header"):
        gs_image.header = self.header
    return gs_image
@implements(
    _galsim.Image.FindAdaptiveMom,
    lax_description=(
        "This method converts the current `jax_galsim.Image` to a native "
        "`galsim.Image` and delegates the computation to "
        "`galsim.hsm.FindAdaptiveMom`. The returned object is GalSim's "
        "`ShapeData`."
    ),
)
def FindAdaptiveMom(self, *args, **kwargs):
    # Any argument exposing ``to_galsim`` (e.g. a weight or badpix Image) is
    # converted to its GalSim counterpart; everything else passes through.
    def _as_galsim(obj):
        if hasattr(obj, "to_galsim"):
            return obj.to_galsim()
        return obj

    galsim_args = [_as_galsim(arg) for arg in args]
    galsim_kwargs = {name: _as_galsim(val) for name, val in kwargs.items()}
    native = self.to_galsim()
    return native.FindAdaptiveMom(*galsim_args, **galsim_kwargs)
@implements(
    _galsim._Image,
    lax_description=IMAGE_LAX_DOCS,
)
def _Image(array, bounds, wcs):
    # Fast constructor: bypasses Image.__init__ and its argument parsing and
    # sets the attributes directly.
    img = Image.__new__(Image)
    img.wcs = wcs
    # Map the array's scalar type to its canonical JAX dtype when an alias
    # exists; otherwise keep it as is.
    scalar_type = array.dtype.type
    scalar_type = Image._alias_dtypes.get(scalar_type, scalar_type)
    img._dtype = scalar_type
    # jax-galsim's rounding of float-to-int is platform dependent
    # so we explicitly round to ints if needed
    img._array = _safe_cast(
        array, jnp.issubdtype(scalar_type, jnp.integer), scalar_type
    )
    img._bounds = bounds
    return img


# These are essentially aliases for the regular Image with the correct dtype
@implements(
    _galsim.ImageUS,
    lax_description=IMAGE_LAX_DOCS,
)
def ImageUS(*args, **kwargs):
    """Shorthand for ``Image(..., dtype=jnp.uint16)``; any dtype passed is overridden."""
    return Image(*args, **{**kwargs, "dtype": jnp.uint16})
@implements(
    _galsim.ImageUI,
    lax_description=IMAGE_LAX_DOCS,
)
def ImageUI(*args, **kwargs):
    """Shorthand for ``Image(..., dtype=jnp.uint32)``; any dtype passed is overridden."""
    return Image(*args, **{**kwargs, "dtype": jnp.uint32})
@implements(
    _galsim.ImageS,
    lax_description=IMAGE_LAX_DOCS,
)
def ImageS(*args, **kwargs):
    """Shorthand for ``Image(..., dtype=jnp.int16)``; any dtype passed is overridden."""
    return Image(*args, **{**kwargs, "dtype": jnp.int16})
@implements(
    _galsim.ImageI,
    lax_description=IMAGE_LAX_DOCS,
)
def ImageI(*args, **kwargs):
    """Shorthand for ``Image(..., dtype=jnp.int32)``; any dtype passed is overridden."""
    return Image(*args, **{**kwargs, "dtype": jnp.int32})
@implements(
    _galsim.ImageF,
    lax_description=IMAGE_LAX_DOCS,
)
def ImageF(*args, **kwargs):
    """Shorthand for ``Image(..., dtype=jnp.float32)``; any dtype passed is overridden."""
    return Image(*args, **{**kwargs, "dtype": jnp.float32})
@implements(
    _galsim.ImageD,
    lax_description=IMAGE_LAX_DOCS,
)
def ImageD(*args, **kwargs):
    """Shorthand for ``Image(..., dtype=jnp.float64)``; any dtype passed is overridden."""
    return Image(*args, **{**kwargs, "dtype": jnp.float64})
@implements(
    _galsim.ImageCF,
    lax_description=IMAGE_LAX_DOCS,
)
def ImageCF(*args, **kwargs):
    """Shorthand for ``Image(..., dtype=jnp.complex64)``; any dtype passed is overridden."""
    return Image(*args, **{**kwargs, "dtype": jnp.complex64})
@implements(
    _galsim.ImageCD,
    lax_description=IMAGE_LAX_DOCS,
)
def ImageCD(*args, **kwargs):
    """Shorthand for ``Image(..., dtype=jnp.complex128)``; any dtype passed is overridden."""
    return Image(*args, **{**kwargs, "dtype": jnp.complex128})
################################################################################################
#
# Arithmetic support for Image. The operators are written as module-level
# functions and attached to the Image class at the end of this section.
#


def _safe_cast(array, target_isinteger, target_dtype):
    """Return ``array`` converted to ``target_dtype``.

    Adapted from upstream GalSim. ``astype`` does no rounding, and
    jax-galsim's float-to-int conversion is platform dependent, so values are
    rounded to the nearest integer first whenever ``target_isinteger`` is
    true.
    """
    if not target_isinteger:
        return array.astype(target_dtype)
    return jnp.around(array).astype(target_dtype)


def check_image_consistency(im1, im2, integer=False):
    """Validate the operands of an Image arithmetic operator.

    Raises ``GalSimValueError`` when ``integer`` is requested and either
    image is not integer-typed, and ``GalSimIncompatibleValuesError`` when
    ``im2`` is an Image whose array shape differs from ``im1``'s. A non-Image
    ``im2`` (scalar, array) is not inspected.
    """
    if integer and not im1.isinteger:
        raise _galsim.GalSimValueError("Image must have integer values.", im1)
    if not isinstance(im2, Image):
        return
    if im1.array.shape != im2.array.shape:
        raise _galsim.GalSimIncompatibleValuesError(
            "Image shapes are inconsistent", im1=im1, im2=im2
        )
    if integer and not im2.isinteger:
        raise _galsim.GalSimValueError("Image must have integer values.", im2)


def _image_like(template, values):
    """Wrap ``values`` in a new Image with ``template``'s bounds and WCS."""
    return Image(values, bounds=template.bounds, wcs=template.wcs)


def _operand_and_dtype(other):
    """Split an in-place operand into ``(values, dtype)``.

    Objects exposing ``.array`` (whose value has a ``.dtype``) contribute
    that array and its dtype; anything else is used as-is and paired with its
    Python type.
    """
    try:
        values = other.array
        return values, values.dtype
    except AttributeError:
        return other, type(other)


def _update_inplace(image, other, same_dtype_update, binop, allow_fast_path=True):
    """Rebind ``image``'s pixel array to ``binop(array, operand)``; return ``image``.

    When the operand's dtype equals the image's (and ``allow_fast_path`` is
    true) the new array is ``same_dtype_update(array, operand)``; otherwise
    ``binop`` is applied and the result is cast back to the image's dtype
    with ``_safe_cast``.
    """
    operand, operand_dtype = _operand_and_dtype(other)
    current = image.array
    if operand_dtype == current.dtype and allow_fast_path:
        image._array = same_dtype_update(current, operand)
        return image
    # jax-galsim's rounding of float-to-int is platform dependent
    # so we explicitly round to ints if needed
    image._array = current.at[...].set(
        _safe_cast(binop(current, operand), image.isinteger, current.dtype)
    )
    return image


def Image_add(self, other):
    check_image_consistency(self, other)
    return _image_like(self, self.array + getattr(other, "array", other))


def Image_iadd(self, other):
    check_image_consistency(self, other)
    return _update_inplace(
        self,
        other,
        lambda arr, val: arr.at[...].add(val),
        lambda arr, val: arr + val,
    )


def Image_sub(self, other):
    check_image_consistency(self, other)
    return _image_like(self, self.array - getattr(other, "array", other))


def Image_rsub(self, other):
    return _image_like(self, other - self.array)


def Image_isub(self, other):
    check_image_consistency(self, other)
    return _update_inplace(
        self,
        other,
        lambda arr, val: arr.at[...].subtract(val),
        lambda arr, val: arr - val,
    )


def Image_mul(self, other):
    check_image_consistency(self, other)
    return _image_like(self, self.array * getattr(other, "array", other))


def Image_imul(self, other):
    check_image_consistency(self, other)
    return _update_inplace(
        self,
        other,
        lambda arr, val: arr.at[...].multiply(val),
        lambda arr, val: arr * val,
    )


def Image_div(self, other):
    check_image_consistency(self, other)
    return _image_like(self, self.array / getattr(other, "array", other))


def Image_rdiv(self, other):
    return _image_like(self, other / self.array)


def Image_idiv(self, other):
    check_image_consistency(self, other)
    # True division of an integer array cannot be stored back into that
    # array directly, so integer images always take the cast-back path.
    return _update_inplace(
        self,
        other,
        lambda arr, val: arr.at[...].divide(val),
        lambda arr, val: arr / val,
        allow_fast_path=not self.isinteger,
    )


def Image_floordiv(self, other):
    check_image_consistency(self, other, integer=True)
    return _image_like(self, self.array // getattr(other, "array", other))


def Image_rfloordiv(self, other):
    check_image_consistency(self, other, integer=True)
    return _image_like(self, other // self.array)


def Image_ifloordiv(self, other):
    check_image_consistency(self, other, integer=True)
    return _update_inplace(
        self,
        other,
        lambda arr, val: arr.at[...].set(arr // val),
        lambda arr, val: arr // val,
    )


def Image_mod(self, other):
    check_image_consistency(self, other, integer=True)
    return _image_like(self, self.array % getattr(other, "array", other))


def Image_rmod(self, other):
    check_image_consistency(self, other, integer=True)
    return _image_like(self, other % self.array)


def Image_imod(self, other):
    check_image_consistency(self, other, integer=True)
    return _update_inplace(
        self,
        other,
        lambda arr, val: arr.at[...].set(arr % val),
        lambda arr, val: arr % val,
    )


def Image_pow(self, other):
    return _image_like(self, self.array**other)


def Image_ipow(self, other):
    if not isinstance(other, (int, float)):
        raise TypeError("Can only raise an image to a float or int power!")
    if self.isinteger and not isinstance(other, int):
        # jax-galsim's rounding of float-to-int is platform dependent
        # so we explicitly round to ints if needed
        self._array = self.array.at[...].set(
            _safe_cast(self.array**other, self.isinteger, self.array.dtype)
        )
    else:
        self._array = self.array.at[...].power(other)
    return self


def Image_neg(self):
    negated = self.copy()
    negated *= -1
    return negated


# Define &, ^ and | only for integer-type images
def Image_and(self, other):
    check_image_consistency(self, other, integer=True)
    return _image_like(self, self.array & getattr(other, "array", other))


def Image_iand(self, other):
    check_image_consistency(self, other, integer=True)
    rhs = getattr(other, "array", other)
    self._array = self.array.at[...].set(self.array & rhs)
    return self


def Image_xor(self, other):
    check_image_consistency(self, other, integer=True)
    return _image_like(self, self.array ^ getattr(other, "array", other))


def Image_ixor(self, other):
    check_image_consistency(self, other, integer=True)
    rhs = getattr(other, "array", other)
    self._array = self.array.at[...].set(self.array ^ rhs)
    return self


def Image_or(self, other):
    check_image_consistency(self, other, integer=True)
    return _image_like(self, self.array | getattr(other, "array", other))


def Image_ior(self, other):
    check_image_consistency(self, other, integer=True)
    rhs = getattr(other, "array", other)
    self._array = self.array.at[...].set(self.array | rhs)
    return self


# inject the arithmetic operators as methods of the Image class:
for _dunder, _impl in (
    ("__add__", Image_add),
    ("__radd__", Image_add),
    ("__iadd__", Image_iadd),
    ("__sub__", Image_sub),
    ("__rsub__", Image_rsub),
    ("__isub__", Image_isub),
    ("__mul__", Image_mul),
    ("__rmul__", Image_mul),
    ("__imul__", Image_imul),
    ("__div__", Image_div),
    ("__rdiv__", Image_rdiv),
    ("__truediv__", Image_div),
    ("__rtruediv__", Image_rdiv),
    ("__idiv__", Image_idiv),
    ("__itruediv__", Image_idiv),
    ("__mod__", Image_mod),
    ("__rmod__", Image_rmod),
    ("__imod__", Image_imod),
    ("__floordiv__", Image_floordiv),
    ("__rfloordiv__", Image_rfloordiv),
    ("__ifloordiv__", Image_ifloordiv),
    ("__ipow__", Image_ipow),
    ("__pow__", Image_pow),
    ("__neg__", Image_neg),
    ("__and__", Image_and),
    ("__xor__", Image_xor),
    ("__or__", Image_or),
    ("__rand__", Image_and),
    ("__rxor__", Image_xor),
    ("__ror__", Image_or),
    ("__iand__", Image_iand),
    ("__ixor__", Image_ixor),
    ("__ior__", Image_ior),
):
    setattr(Image, _dunder, _impl)
del _dunder, _impl