Source code for jax_galsim.bounds

import galsim as _galsim
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import (
    cast_to_float,
    cast_to_int,
    ensure_hashable,
    has_tracers,
    implements,
)
from jax_galsim.position import Position, PositionD, PositionI

CONST_TYPES = (float, int, np.ndarray, np.int32, np.int64, np.float32, np.float64)
CONST_TYPES_WITH_JAX = CONST_TYPES + (
    jax.Array,
    jnp.array,
    jnp.int32,
    jnp.int64,
    jnp.float32,
    jnp.float64,
)

# TODO: write extra docs for JAX changes
BOUNDS_LAX_DESCR = """\
The JAX implementation

- will not always test whether the bounds are valid
- will not always test whether BoundsI is initialized with integers

Further, the JAX implementation adds a new method, ``isStatic`` to the
``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance
has been instantiated with static, known values, ``isStatic()`` will
return ``True``. You can indicate to JAX-GalSim that a ``BoundsI``
instance should be static via initializing it with the ``static``
keyword set to the ``True``. If the object detects that it is being
initialized with non-static data, an error will be raised.

``BoundsI`` objects in JAX-Galsim support an additional initialization
call ``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)``. In this case,
the values for ``deltax/y`` indicate the width of the bounds and must be
static constants.

When calling ``jax.vmap`` over ``BoundsI`` objects, only ``x/ymin``
are vectorized over. This restriction allows for code that renders
objects in fixed sized stamps with variable locations, a common
operation. ``BoundsI`` objects which are static (i.e., ``isStatic()``
returns ``True``) are treated as constants with respect to ``vmap``,
``jit``, and other JAX transforms.
"""


[docs] @implements(_galsim.Bounds, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class class Bounds: def __init__(self): raise NotImplementedError( "Cannot instantiate the base class. Use either BoundsD or BoundsI." ) def _parse_args(self, *args, **kwargs): if len(kwargs) == 0: if len(args) == 4: self._isdefined = True self.xmin, self.xmax, self.ymin, self.ymax = args elif len(args) == 0: self._isdefined = False self.xmin = 0 self.ymin = 0 self.deltax = 0 self.deltay = 0 elif len(args) == 1: if isinstance(args[0], Bounds): if isinstance(self, BoundsI) and isinstance(args[0], BoundsD): offset = 1 elif isinstance(self, BoundsD) and isinstance(args[0], BoundsI): offset = -1 else: offset = 0 self._isdefined = args[0]._isdefined self.xmin = args[0].xmin self.deltax = args[0].deltax + offset self.ymin = args[0].ymin self.deltay = args[0].deltay + offset elif isinstance(args[0], Position): self._isdefined = True self.xmin = self.xmax = args[0].x self.ymin = self.ymax = args[0].y else: raise TypeError( "Single argument to %s must be either a Bounds or a Position" % (self.__class__.__name__) ) self._isdefined = True elif len(args) == 2: if isinstance(args[0], Position) and isinstance(args[1], Position): self._isdefined = True self.xmin = min(args[0].x, args[1].x) self.xmax = max(args[0].x, args[1].x) self.ymin = min(args[0].y, args[1].y) self.ymax = max(args[0].y, args[1].y) else: raise TypeError( "Two arguments to %s must be Positions" % (self.__class__.__name__) ) else: raise TypeError( "%s takes either 1, 2, or 4 arguments (%d given)" % (self.__class__.__name__, len(args)) ) elif len(args) != 0: raise TypeError( "Cannot provide both keyword and non-keyword arguments to %s" % (self.__class__.__name__) ) else: try: self._isdefined = True self.xmin = kwargs.pop("xmin") self.ymin = kwargs.pop("ymin") except KeyError: raise TypeError( "Keyword arguments, xmin, ymin are required for %s" % (self.__class__.__name__) ) if "xmax" in kwargs and "ymax" in kwargs: self.xmax = kwargs.pop("xmax") self.ymax = kwargs.pop("ymax") elif "deltax" in kwargs and "deltay" in kwargs: self.deltax = kwargs.pop("deltax") self.deltay = kwargs.pop("deltay") else: raise TypeError( "Keyword arguments, either (xmax, ymax) " "or (deltax, deltay) are required for %s" % (self.__class__.__name__) ) if kwargs: raise TypeError("Got unexpected keyword arguments %s" % kwargs.keys()) # for simple inputs, we can check if the bounds are valid if isinstance(self, BoundsD): max_delta = 0 else: max_delta = 1 if ( isinstance(self.deltax, CONST_TYPES) and isinstance(self.deltay, CONST_TYPES) and (self.deltax < max_delta or self.deltay < max_delta) ): self._isdefined = False
[docs] @implements(_galsim.Bounds.area) def area(self): return self._area()
[docs] @implements(_galsim.Bounds.withBorder) def withBorder(self, dx, dy=None): self._check_scalar(dx, "dx") if dy is None: dy = dx else: self._check_scalar(dy, "dy") return self.__class__( xmin=self.xmin - dx, deltax=self.deltax + 2 * dx, ymin=self.ymin - dy, deltay=self.deltay + 2 * dy, )
@property @implements(_galsim.Bounds.origin) def origin(self): return self._pos_class(self.xmin, self.ymin) @property @implements(_galsim.Bounds.center) def center(self): if not self.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "center is invalid for an undefined Bounds" ) return self._center @property @implements(_galsim.Bounds.true_center) def true_center(self): if not self.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "true_center is invalid for an undefined Bounds" ) return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0)
[docs] @implements(_galsim.Bounds.includes) def includes(self, *args): if len(args) == 1: if isinstance(args[0], Bounds): b = args[0] return ( self.isDefined() & b.isDefined() & (self.xmin <= b.xmin) & (self.xmax >= b.xmax) & (self.ymin <= b.ymin) & (self.ymax >= b.ymax) ) elif isinstance(args[0], Position): p = args[0] return ( self.isDefined() & (self.xmin <= p.x) & (self.ymin <= p.y) & (p.x <= self.xmax) & (p.y <= self.ymax) ) else: raise TypeError("Invalid argument %s" % args[0]) elif len(args) == 2: x, y = args return ( self.isDefined() & (self.xmin <= float(x)) & (self.ymin <= float(y)) & (float(x) <= self.xmax) & (float(y) <= self.ymax) ) elif len(args) == 0: raise TypeError("include takes at least 1 argument (0 given)") else: raise TypeError("include takes at most 2 arguments (%d given)" % len(args))
[docs] @implements(_galsim.Bounds.expand) def expand(self, factor_x, factor_y=None): if factor_y is None: factor_y = factor_x dx = (self.xmax - self.xmin) * 0.5 * (factor_x - 1.0) dy = (self.ymax - self.ymin) * 0.5 * (factor_y - 1.0) if isinstance(self, BoundsI): dx = jnp.ceil(dx) dy = jnp.ceil(dy) return self.withBorder(dx, dy)
[docs] @implements(_galsim.Bounds.isDefined) def isDefined(self): return self._isdefined
[docs] @implements(_galsim.Bounds.getXMin) def getXMin(self): return self.xmin
[docs] @implements(_galsim.Bounds.getXMax) def getXMax(self): return self.xmax
[docs] @implements(_galsim.Bounds.getYMin) def getYMin(self): return self.ymin
[docs] @implements(_galsim.Bounds.getYMax) def getYMax(self): return self.ymax
[docs] @implements(_galsim.Bounds.shift) def shift(self, delta): if not isinstance(delta, self._pos_class): raise TypeError("delta must be a %s instance" % self._pos_class) return self.__class__( xmin=self.xmin + delta.x, deltax=self.deltax, ymin=self.ymin + delta.y, deltay=self.deltay, )
def __and__(self, other): if not isinstance(other, self.__class__): raise TypeError("other must be a %s instance" % self.__class__.__name__) if not self.isDefined() or not other.isDefined(): return self.__class__() else: xmin = jnp.maximum(self.xmin, other.xmin) xmax = jnp.minimum(self.xmax, other.xmax) ymin = jnp.maximum(self.ymin, other.ymin) ymax = jnp.minimum(self.ymax, other.ymax) if xmin > xmax or ymin > ymax: return self.__class__() else: return self.__class__(xmin, xmax, ymin, ymax) def __add__(self, other): if isinstance(other, self.__class__): if not other.isDefined(): return self elif self.isDefined(): xmin = jnp.minimum(self.xmin, other.xmin) xmax = jnp.maximum(self.xmax, other.xmax) ymin = jnp.minimum(self.ymin, other.ymin) ymax = jnp.maximum(self.ymax, other.ymax) return self.__class__(xmin, xmax, ymin, ymax) else: return other elif isinstance(other, self._pos_class): if self.isDefined(): xmin = jnp.minimum(self.xmin, other.x) xmax = jnp.maximum(self.xmax, other.x) ymin = jnp.minimum(self.ymin, other.y) ymax = jnp.maximum(self.ymax, other.y) return self.__class__(xmin, xmax, ymin, ymax) else: return self.__class__(other) else: raise TypeError( "other must be either a %s or a %s" % (self.__class__.__name__, self._pos_class.__name__) ) def _getinitargs(self): if self.isDefined(): return (self.xmin, self.xmax, self.ymin, self.ymax) else: return () def __eq__(self, other): return self is other or ( isinstance(other, self.__class__) and self._getinitargs() == other._getinitargs() ) def __ne__(self, other): return not self.__eq__(other) def __hash__(self): return hash( ( self.__class__.__name__, ensure_hashable(self.xmin), ensure_hashable(self.deltax), ensure_hashable(self.ymin), ensure_hashable(self.deltay), ) )
[docs] def tree_flatten(self): """This function flattens the Bounds into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing if self.isDefined(): children = (self.xmin, self.deltax, self.ymin, self.deltay) else: children = tuple() # Define auxiliary static data that doesn’t need to be traced aux_data = None return (children, aux_data)
[docs] @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" if children: return cls( xmin=children[0], deltax=children[1], ymin=children[2], deltay=children[3], ) else: return cls()
[docs] @classmethod def from_galsim(cls, galsim_bounds): """Create a jax_galsim `BoundsD/I` from a `galsim.BoundsD/I` object.""" if isinstance(galsim_bounds, _galsim.BoundsD): _cls = BoundsD kwargs = {} elif isinstance(galsim_bounds, _galsim.BoundsI): _cls = BoundsI kwargs = {"static": True} else: raise TypeError( "galsim_bounds must be either a %s or a %s" % (_galsim.BoundsD.__name__, _galsim.BoundsI.__name__) ) if galsim_bounds.isDefined(): return _cls( galsim_bounds.xmin, galsim_bounds.xmax, galsim_bounds.ymin, galsim_bounds.ymax, **kwargs, ) else: return _cls()
[docs] def to_galsim(self): """Create a galsim `BoundsD/I` from a `jax_galsim.BoundsD/I` object.""" if isinstance(self, BoundsI): gs_class = _galsim.bounds.BoundsI cast = int else: gs_class = _galsim.bounds.BoundsD cast = float if self.isDefined(): return gs_class( cast(self.xmin), cast(self.xmax), cast(self.ymin), cast(self.ymax), ) else: return gs_class()
[docs] def isStatic(self): """Returns ``True`` if the ``BoundsI`` instance has static, known dimensions and location. Always returns ``False`` for ``BoundsD``.""" return self._isstatic
[docs] @implements(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class class BoundsD(Bounds): _pos_class = PositionD def __init__(self, *args, **kwargs): self._isstatic = False self._parse_args(*args, **kwargs) self.xmin = cast_to_float(self.xmin) self.deltax = cast_to_float(self.deltax) self.ymin = cast_to_float(self.ymin) self.deltay = cast_to_float(self.deltay) def _check_scalar(self, x, name): try: if ( isinstance(x, jax.Array) and x.shape == () and x.dtype.name in ["float32", "float64", "float"] ): return elif x == float(x): return except (TypeError, ValueError): pass raise TypeError("%s must be a float value" % name) @property def xmax(self): return self.xmin + self.deltax @xmax.setter def xmax(self, value): self.deltax = value - self.xmin @property def ymax(self): return self.ymin + self.deltay @ymax.setter def ymax(self, value): self.deltay = value - self.ymin def _area(self): return self.deltax * self.deltay @property def _center(self): return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) def __repr__(self): if self.isDefined(): return "galsim.%s(%r, %r, %r, %r)" % ( self.__class__.__name__, ensure_hashable(self.xmin), ensure_hashable(self.xmax), ensure_hashable(self.ymin), ensure_hashable(self.ymax), ) else: return "galsim.%s()" % (self.__class__.__name__) def __str__(self): if self.isDefined(): return "galsim.%s(%s,%s,%s,%s)" % ( self.__class__.__name__, ensure_hashable(self.xmin), ensure_hashable(self.xmax), ensure_hashable(self.ymin), ensure_hashable(self.ymax), ) else: return "galsim.%s()" % (self.__class__.__name__) def __hash__(self): return hash( ( self.__class__.__name__, ensure_hashable(self.xmin), ensure_hashable(self.deltax), ensure_hashable(self.ymin), ensure_hashable(self.deltay), ) )
[docs] @implements(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class class BoundsI(Bounds): _pos_class = PositionI def __init__(self, *args, **kwargs): # initial setting to let stuff pass through freely self._isstatic = True force_static = kwargs.pop("static", False) self._parse_args(*args, **kwargs) if has_tracers(self.deltax) or has_tracers(self.deltay): raise RuntimeError( "Jax-GalSim BoundsI instances must have a fixed width! " f"Got deltax,deltay = {self.deltax!r},{self.deltay!r}." ) self.deltax = int(cast_to_int(self.deltax)) self.deltay = int(cast_to_int(self.deltay)) if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): raise TypeError("BoundsI must be initialized with integer values") if self.deltax < 1 and self.deltay < 1: self._isdefined = False # for simple inputs, we can check if the bounds are valid ints if isinstance(self._xmin, CONST_TYPES) and self._xmin != int(self._xmin): raise TypeError("BoundsI must be initialized with integer values") if isinstance(self._ymin, CONST_TYPES) and self._ymin != int(self._ymin): raise TypeError("BoundsI must be initialized with integer values") if not has_tracers(self._xmin) and not has_tracers(self._ymin): self._isstatic = True self._xmin = int(np.trunc(self._xmin)) self._ymin = int(np.trunc(self._ymin)) else: self._isstatic = False self._xmin = cast_to_float(jnp.trunc(self._xmin)) self._ymin = cast_to_float(jnp.trunc(self._ymin)) if force_static and not self._isstatic: raise RuntimeError( "BoundsI initialized with non-static " f"data (xmin,ymin = {self._xmin},{self._yminb}) " "when static data was explicitly requested." ) def _check_scalar(self, x, name): try: if ( isinstance(x, jax.Array) and x.shape == () and x.dtype.name in ["int32", "int64", "int"] ): return elif x == int(x): return except (TypeError, ValueError): pass raise TypeError("%s must be an integer value" % name)
[docs] def numpyShape(self): "A simple utility function to get the numpy shape that corresponds to this `Bounds` object." if self.isDefined(): return self.deltay, self.deltax else: return 0, 0
@property def xmin(self): if self._isstatic: return self._xmin else: return jnp.astype(self._xmin, jnp.int_) @xmin.setter def xmin(self, value): if self._isstatic: self._xmin = value else: self._xmin = jnp.astype(value, jnp.float_) @property def xmax(self): return self.xmin + self.deltax - 1 @xmax.setter def xmax(self, value): self.deltax = value - self.xmin + 1 @property def ymin(self): if self._isstatic: return self._ymin else: return jnp.astype(self._ymin, jnp.int_) @ymin.setter def ymin(self, value): if self._isstatic: self._ymin = value else: self._ymin = jnp.astype(value, jnp.float_) @property def ymax(self): return self.ymin + self.deltay - 1 @ymax.setter def ymax(self, value): self.deltay = value - self.ymin + 1 def _area(self): # Remember the + 1 this time to include the pixels on both edges of the bounds. if not self.isDefined(): return 0 else: return self.deltax * self.deltay @property def _center(self): # Write it this way to make sure the integer rounding goes the same way regardless # of whether the values are positive or negative. # e.g. (1,10,1,10) -> (6,6) # (-10,-1,-10,-1) -> (-5,-5) # Just up and to the right of the true center in both cases. return PositionI( self.xmin + self.deltax // 2, self.ymin + self.deltay // 2, )
[docs] def tree_flatten(self): """This function flattens the Bounds into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing if self.isDefined(): if self._isstatic: # Define the children nodes of the PyTree that need tracing children = tuple() # Define auxiliary static data that doesn’t need to be traced aux_data = { "xmin": self._xmin, "ymin": self._ymin, "deltax": self.deltax, "deltay": self.deltay, } else: children = (self._xmin, self._ymin) # Define auxiliary static data that doesn’t need to be traced aux_data = {"deltax": self.deltax, "deltay": self.deltay} else: children = tuple() aux_data = None return (children, aux_data)
[docs] @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" if aux_data is not None: ret = cls.__new__(cls) if "xmin" in aux_data and "ymin" in aux_data: ret._isstatic = True ret._xmin = aux_data["xmin"] ret._ymin = aux_data["ymin"] else: ret._isstatic = False ret._xmin = children[0] ret._ymin = children[1] ret.deltax = aux_data["deltax"] ret.deltay = aux_data["deltay"] if ret.deltax < 1 and ret.deltay < 1: ret._isdefined = False else: ret._isdefined = True else: ret = cls() return ret
def __repr__(self): if self.isDefined(): return "galsim.%s(xmin=%r, deltax=%r, ymin=%r, deltay=%r)" % ( self.__class__.__name__, ensure_hashable(self.xmin), ensure_hashable(self.deltax), ensure_hashable(self.ymin), ensure_hashable(self.deltay), ) else: return "galsim.%s()" % (self.__class__.__name__) def __str__(self): if self.isDefined(): return "galsim.%s(xmin=%s, deltax=%s, ymin=%s, deltay=%s)" % ( self.__class__.__name__, ensure_hashable(self.xmin), ensure_hashable(self.deltax), ensure_hashable(self.ymin), ensure_hashable(self.deltay), ) else: return "galsim.%s()" % (self.__class__.__name__) def _getinitargs(self): if self.isDefined(): return (self.xmin, self.deltax, self.ymin, self.deltay) else: return () def __eq__(self, other): return self is other or ( isinstance(other, BoundsI) and self._getinitargs() == other._getinitargs() ) def __hash__(self): return hash( ( self.__class__.__name__, ensure_hashable(self.xmin), ensure_hashable(self.deltax), ensure_hashable(self.ymin), ensure_hashable(self.deltay), ) )