import equinox
import galsim as _galsim
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class
from jax_galsim.core.utils import (
cast_to_float,
cast_to_int,
check_is_int_then_cast,
ensure_hashable,
implements,
)
from jax_galsim.position import Position, PositionD, PositionI
# Extra documentation appended (via ``implements``) to the GalSim docstrings of
# the Bounds classes below, describing where JAX-GalSim behaves differently.
BOUNDS_LAX_DESCR = """\
The JAX-GalSim implementation of the ``BoundsI/D`` classes has some key differences
from GalSim.

- ``BoundsI`` instances must have statically known shapes, but may have non-static
  start locations (i.e., ``xmin`` and ``ymin`` may be JAX arrays, traced in JIT operations, etc.).
  This restriction mirrors the JAX restriction that arrays have fixed shapes when traced
  for function transformations like ``jax.vmap``, ``jax.jit``, etc.
- Upon initialization, if a ``BoundsI`` object has a non-static shape, JAX-GalSim will attempt to convert
  it to a static shape by extracting the dimensions from the array via ``.item()``. This operation will
  cause JAX to raise an error if the code is being traced. JAX-GalSim performs the same conversion operation
  when the ``deltax`` or ``deltay`` properties are set to non-static values via assignment.
- If a ``BoundsI`` object is declared with static ``xmin`` and ``ymin`` values, and then one attempts to
  convert them to non-static values via assignment, JAX-GalSim will attempt to convert the assigned values
  back to static values. This operation will raise an error if the code is being traced.
- ``Bounds`` classes in JAX-GalSim have an extra method, ``isStatic``, that returns ``True`` if the object
  was instantiated with static ``xmin`` and ``ymin`` values. This method always returns ``False`` for
  ``BoundsD`` objects.
- JAX-GalSim does not support the use of the ``&`` and ``+`` operators (i.e., the dunder methods ``__and__``
  and ``__add__``) with ``BoundsI`` objects when tracing code.
- JAX-GalSim supports an additional initialization signature ``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)``
  to help users specify the widths ``deltax`` and ``deltay`` statically at initialization.
- When calling ``jax.vmap``, ``jax.jit``, etc. with ``BoundsI`` objects, ``xmin`` and ``ymin`` are
  traced by JAX. The combination of this feature with statically known shapes allows for code that renders
  objects in fixed-size stamps with variable locations, a common operation.
- For ``BoundsD``, all ``x(y)min(max)`` values are traced as arrays.
- ``Bounds`` objects return JAX boolean values for various method calls, except for
  ``BoundsI.isDefined()``, which always returns a Python boolean value, and for ``==``/``!=``
  comparisons between two static ``BoundsI`` objects, which also return Python boolean values.
"""
@implements(_galsim.Bounds, lax_description=BOUNDS_LAX_DESCR)
@register_pytree_node_class
class Bounds:
    # Shared base for ``BoundsD`` and ``BoundsI``. Subclasses supply
    # ``_pos_class``, ``_check_scalar``, ``_area``, ``_center`` and the storage
    # behind ``xmin``/``deltax``/``ymin``/``deltay`` (plus ``xmax``/``ymax``).

    def __init__(self):
        raise NotImplementedError(
            "Cannot instantiate the base class. Use either BoundsD or BoundsI."
        )

    def _parse_args(self, *args, **kwargs):
        """Populate the corner/width attributes from a supported call signature.

        Accepted forms: ``()`` (undefined bounds), ``(xmin, xmax, ymin, ymax)``,
        ``(bounds)`` (copy, converting between ``BoundsI``/``BoundsD``),
        ``(pos)``, ``(pos1, pos2)``, and the keyword forms ``xmin=, ymin=``
        together with either ``xmax=, ymax=`` or ``deltax=, deltay=``.

        Returns:
            bool: ``True`` when the caller still has to derive ``_isdefined``
            from the widths; ``False`` when this method already set it.

        Raises:
            TypeError: for any unsupported combination of arguments.
        """
        do_isdefined = True
        if len(kwargs) == 0:
            if len(args) == 4:
                self.xmin, self.xmax, self.ymin, self.ymax = args
            elif len(args) == 0:
                do_isdefined = False
                self._isdefined = False
                self.xmin = 0
                self.ymin = 0
                self.deltax = 0
                self.deltay = 0
            elif len(args) == 1:
                if isinstance(args[0], Bounds):
                    # BoundsI widths count pixels (xmax - xmin + 1) while
                    # BoundsD widths are xmax - xmin, so shift by one when
                    # converting between the two.
                    if isinstance(self, BoundsI) and isinstance(args[0], BoundsD):
                        offset = 1
                    elif isinstance(self, BoundsD) and isinstance(args[0], BoundsI):
                        offset = -1
                    else:
                        offset = 0
                    self._isdefined = args[0]._isdefined
                    self.xmin = args[0].xmin
                    self.deltax = args[0].deltax + offset
                    self.ymin = args[0].ymin
                    self.deltay = args[0].deltay + offset
                    do_isdefined = False
                elif isinstance(args[0], Position):
                    self.xmin = self.xmax = args[0].x
                    self.ymin = self.ymax = args[0].y
                else:
                    raise TypeError(
                        "Single argument to %s must be either a Bounds or a Position"
                        % (self.__class__.__name__)
                    )
            elif len(args) == 2:
                if isinstance(args[0], Position) and isinstance(args[1], Position):
                    self.xmin = jnp.minimum(args[0].x, args[1].x)
                    self.xmax = jnp.maximum(args[0].x, args[1].x)
                    self.ymin = jnp.minimum(args[0].y, args[1].y)
                    self.ymax = jnp.maximum(args[0].y, args[1].y)
                else:
                    raise TypeError(
                        "Two arguments to %s must be Positions"
                        % (self.__class__.__name__)
                    )
            else:
                raise TypeError(
                    "%s takes either 1, 2, or 4 arguments (%d given)"
                    % (self.__class__.__name__, len(args))
                )
        elif len(args) != 0:
            raise TypeError(
                "Cannot provide both keyword and non-keyword arguments to %s"
                % (self.__class__.__name__)
            )
        else:
            try:
                self.xmin = kwargs.pop("xmin")
                self.ymin = kwargs.pop("ymin")
            except KeyError:
                raise TypeError(
                    "Keyword arguments, xmin, ymin are required for %s"
                    % (self.__class__.__name__)
                )
            if "xmax" in kwargs and "ymax" in kwargs:
                self.xmax = kwargs.pop("xmax")
                self.ymax = kwargs.pop("ymax")
            elif "deltax" in kwargs and "deltay" in kwargs:
                self.deltax = kwargs.pop("deltax")
                self.deltay = kwargs.pop("deltay")
            else:
                raise TypeError(
                    "Keyword arguments, either (xmax, ymax) "
                    "or (deltax, deltay) are required for %s"
                    % (self.__class__.__name__)
                )
            if kwargs:
                raise TypeError("Got unexpected keyword arguments %s" % kwargs.keys())
        return do_isdefined

    @implements(_galsim.Bounds.area)
    def area(self):
        return self._area()

    @implements(_galsim.Bounds.withBorder)
    def withBorder(self, dx, dy=None):
        self._check_scalar(dx, "dx")
        if dy is None:
            dy = dx
        else:
            self._check_scalar(dy, "dy")
        # Moving each edge outwards by dx (dy) widens the span by 2*dx (2*dy).
        return self.__class__(
            xmin=self.xmin - dx,
            deltax=self.deltax + 2 * dx,
            ymin=self.ymin - dy,
            deltay=self.deltay + 2 * dy,
        )

    @property
    @implements(_galsim.Bounds.origin)
    def origin(self):
        return self._pos_class(self.xmin, self.ymin)

    def _guard_defined_position(self, pos, msg):
        """Return ``pos``, failing with ``msg`` if these bounds are undefined.

        When ``_isdefined`` is a Python bool (as set by ``BoundsI``), raise
        ``GalSimUndefinedBoundsError`` eagerly. Otherwise (array-valued, as
        set by ``BoundsD``) attach an ``equinox.error_if`` check to ``pos.x``
        and rebuild the position from it. Threading the check through the
        returned value, instead of storing it back on ``self``, keeps the check
        from being dead-code eliminated under ``jax.jit``. It also avoids
        writing a tracer onto this (possibly closed-over) object.
        """
        if isinstance(self._isdefined, bool):
            if not self._isdefined:
                raise _galsim.GalSimUndefinedBoundsError(msg)
            return pos
        checked_x = equinox.error_if(pos.x, jnp.any(~self._isdefined), msg)
        return type(pos)(checked_x, pos.y)

    @property
    @implements(_galsim.Bounds.center)
    def center(self):
        return self._guard_defined_position(
            self._center, "center is invalid for an undefined Bounds"
        )

    @property
    @implements(_galsim.Bounds.true_center)
    def true_center(self):
        return self._guard_defined_position(
            PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0),
            "true_center is invalid for an undefined Bounds",
        )

    @implements(_galsim.Bounds.includes)
    def includes(self, *args):
        if len(args) == 1:
            if isinstance(args[0], Bounds):
                b = args[0]
                return (
                    jnp.array(self.isDefined())
                    & jnp.array(b.isDefined())
                    & jnp.array(self.xmin <= b.xmin)
                    & jnp.array(self.xmax >= b.xmax)
                    & jnp.array(self.ymin <= b.ymin)
                    & jnp.array(self.ymax >= b.ymax)
                )
            elif isinstance(args[0], Position):
                p = args[0]
                return (
                    jnp.array(self.isDefined())
                    & jnp.array(self.xmin <= p.x)
                    & jnp.array(self.ymin <= p.y)
                    & jnp.array(p.x <= self.xmax)
                    & jnp.array(p.y <= self.ymax)
                )
            else:
                raise TypeError("Invalid argument %s" % args[0])
        elif len(args) == 2:
            x, y = args
            x = cast_to_float(x)
            y = cast_to_float(y)
            return (
                jnp.array(self.isDefined())
                & jnp.array(self.xmin <= x)
                & jnp.array(self.ymin <= y)
                & jnp.array(x <= self.xmax)
                & jnp.array(y <= self.ymax)
            )
        elif len(args) == 0:
            raise TypeError("include takes at least 1 argument (0 given)")
        else:
            raise TypeError("include takes at most 2 arguments (%d given)" % len(args))

    @implements(_galsim.Bounds.expand)
    def expand(self, factor_x, factor_y=None):
        if factor_y is None:
            factor_y = factor_x
        dx = (self.xmax - self.xmin) * 0.5 * (factor_x - 1.0)
        dy = (self.ymax - self.ymin) * 0.5 * (factor_y - 1.0)
        if isinstance(self, BoundsI):
            # integer bounds must grow by whole pixels
            dx = jnp.ceil(dx)
            dy = jnp.ceil(dy)
        return self.withBorder(dx, dy)

    @implements(_galsim.Bounds.isDefined)
    def isDefined(self):
        return self._isdefined

    @implements(_galsim.Bounds.getXMin)
    def getXMin(self):
        return self.xmin

    @implements(_galsim.Bounds.getXMax)
    def getXMax(self):
        return self.xmax

    @implements(_galsim.Bounds.getYMin)
    def getYMin(self):
        return self.ymin

    @implements(_galsim.Bounds.getYMax)
    def getYMax(self):
        return self.ymax

    @implements(_galsim.Bounds.shift)
    def shift(self, delta):
        if not isinstance(delta, self._pos_class):
            raise TypeError("delta must be a %s instance" % self._pos_class)
        # shifting keeps the widths and moves only the lower corner
        return self.__class__(
            xmin=self.xmin + delta.x,
            deltax=self.deltax,
            ymin=self.ymin + delta.y,
            deltay=self.deltay,
        )

    def __and__(self, other):
        if not isinstance(other, self.__class__):
            raise TypeError("other must be a %s instance" % self.__class__.__name__)
        return _bounds_and_op_dynamic(self, other)

    def __add__(self, other):
        if isinstance(other, self.__class__):
            return _bounds_bounds_add_op_dynamic(self, other)
        elif isinstance(other, self._pos_class):
            return _bounds_pos_add_op_dynamic(self, other)
        else:
            raise TypeError(
                "other must be either a %s or a %s"
                % (self.__class__.__name__, self._pos_class.__name__)
            )

    def __eq__(self, other):
        raise NotImplementedError(
            "The `__eq__` magic method must be implemented by subclasses of `Bounds`."
        )

    def __ne__(self, other):
        raise NotImplementedError(
            "The `__ne__` magic method must be implemented by subclasses of `Bounds`."
        )

    def __hash__(self):
        return hash(
            (
                self.__class__.__name__,
                ensure_hashable(self.xmin),
                ensure_hashable(self.deltax),
                ensure_hashable(self.ymin),
                ensure_hashable(self.deltay),
            )
        )

    def tree_flatten(self):
        """This function flattens the Bounds into a list of children
        nodes that will be traced by JAX and auxiliary static data."""
        # Define the children nodes of the PyTree that need tracing
        children = (self.xmin, self.deltax, self.ymin, self.deltay, self._isdefined)
        # Define auxiliary static data that doesn't need to be traced
        aux_data = {"isstatic": self._isstatic}
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Recreates an instance of the class from flatten representation"""
        ret = cls.__new__(cls)
        ret.xmin = children[0]
        ret.deltax = children[1]
        ret.ymin = children[2]
        ret.deltay = children[3]
        ret._isdefined = children[4]
        ret._isstatic = aux_data["isstatic"]
        return ret

    @classmethod
    def from_galsim(cls, galsim_bounds):
        """Create a jax_galsim `BoundsD/I` from a `galsim.BoundsD/I` object."""
        if isinstance(galsim_bounds, _galsim.BoundsD):
            _cls = BoundsD
        elif isinstance(galsim_bounds, _galsim.BoundsI):
            _cls = BoundsI
        else:
            raise TypeError(
                "galsim_bounds must be either a %s or a %s"
                % (_galsim.BoundsD.__name__, _galsim.BoundsI.__name__)
            )
        if galsim_bounds.isDefined():
            return _cls(
                galsim_bounds.xmin,
                galsim_bounds.xmax,
                galsim_bounds.ymin,
                galsim_bounds.ymax,
            )
        else:
            return _cls()

    def to_galsim(self):
        """Create a galsim `BoundsD/I` from a `jax_galsim.BoundsD/I` object."""
        if isinstance(self, BoundsI):
            gs_class = _galsim.bounds.BoundsI
            cast = int
        else:
            gs_class = _galsim.bounds.BoundsD
            cast = float
        if self.isDefined():
            return gs_class(
                cast(self.xmin),
                cast(self.xmax),
                cast(self.ymin),
                cast(self.ymax),
            )
        else:
            return gs_class()

    def isStatic(self):
        """Returns ``True`` if the ``BoundsI`` instance
        has static, known dimensions and location. Always returns
        ``False`` for ``BoundsD``."""
        return self._isstatic
def _bounds_and_op_dynamic(self, other):
    """Intersection (``self & other``) of two bounds of the same class.

    The result is defined only when both inputs are defined and the overlap is
    non-empty; when undefined, all four corners are collapsed to zero.
    ``BoundsI`` results are built through the constructor so that their widths
    become static Python ints, and their definedness flag is converted to a
    Python bool via ``.item()`` (which requires concrete, untraced values).
    ``BoundsD`` results keep array-valued fields.
    """
    lo_x = jnp.maximum(self.xmin, other.xmin)
    hi_x = jnp.minimum(self.xmax, other.xmax)
    lo_y = jnp.maximum(self.ymin, other.ymin)
    hi_y = jnp.minimum(self.ymax, other.ymax)

    overlap = (
        jnp.array(self.isDefined())
        & jnp.array(other.isDefined())
        & jnp.array(lo_y <= hi_y)
        & jnp.array(lo_x <= hi_x)
    )

    # Zero every corner when the intersection is empty / undefined.
    lo_x, hi_x, lo_y, hi_y = (
        jnp.where(overlap, corner, 0.0) for corner in (lo_x, hi_x, lo_y, hi_y)
    )

    cls = self.__class__
    if not isinstance(self, BoundsI):
        out = cls.__new__(cls)
        out.xmin = lo_x
        out.deltax = hi_x - lo_x
        out.ymin = lo_y
        out.deltay = hi_y - lo_y
        out._isdefined = overlap
        out._isstatic = False
        return out

    # Going through the constructor converts the widths to static ints.
    out = cls(
        xmin=lo_x,
        deltax=hi_x - lo_x + 1,
        ymin=lo_y,
        deltay=hi_y - lo_y + 1,
    )
    # The BoundsI definedness flag must be a static Python bool as well.
    with jax.ensure_compile_time_eval():
        out._isdefined = bool(overlap.item())
    return out
def _bounds_bounds_add_op_dynamic(self, other):
    """Union (``self + other``) of two bounds of the same class.

    An undefined operand is ignored: if ``other`` is undefined the corners of
    ``self`` are kept, if ``self`` is undefined those of ``other`` are taken,
    and otherwise the corners are combined with min/max. ``BoundsI`` results
    are built via the constructor (static integer widths) and get a Python
    bool definedness flag via ``.item()``.
    """
    self_def = jnp.array(self._isdefined)
    other_def = jnp.array(other._isdefined)

    def _merge(mine, theirs, combine):
        # other undefined -> mine; self undefined -> theirs; else combine both.
        return jnp.where(
            other_def,
            jnp.where(self_def, combine(mine, theirs), theirs),
            mine,
        )

    xmin = _merge(self.xmin, other.xmin, jnp.minimum)
    xmax = _merge(self.xmax, other.xmax, jnp.maximum)
    ymin = _merge(self.ymin, other.ymin, jnp.minimum)
    ymax = _merge(self.ymax, other.ymax, jnp.maximum)

    def _union_isdefined(both_ok):
        # Same selection rule as ``_merge``, applied to the definedness flags.
        return jnp.where(other_def, jnp.where(self_def, both_ok, other_def), self_def)

    cls = self.__class__
    if not isinstance(self, BoundsI):
        out = cls.__new__(cls)
        out.xmin = xmin
        out.deltax = xmax - xmin
        out.ymin = ymin
        out.deltay = ymax - ymin
        out._isdefined = _union_isdefined(
            jnp.array(out.deltax >= 0) & jnp.array(out.deltay >= 0)
        )
        out._isstatic = False
        return out

    # Going through the constructor converts the widths to static ints.
    out = cls(
        xmin=xmin,
        deltax=xmax - xmin + 1,
        ymin=ymin,
        deltay=ymax - ymin + 1,
    )
    flag = _union_isdefined(jnp.array(out.deltax >= 1) & jnp.array(out.deltay >= 1))
    # The BoundsI definedness flag must be a static Python bool as well.
    with jax.ensure_compile_time_eval():
        out._isdefined = bool(flag.item())
    return out
def _bounds_pos_add_op_dynamic(self, other):
    """Grow ``self`` (``self + pos``) so that it contains the position ``other``.

    If ``self`` is undefined, the result collapses onto the single point
    ``other``. ``BoundsI`` results are built via the constructor (static
    integer widths) and get a Python bool definedness flag via ``.item()``.
    """
    isdef = self._isdefined

    def _extend(current, coord, widen):
        # Undefined bounds are replaced by the point coordinate itself.
        return jnp.where(isdef, widen(current, coord), coord)

    xmin = _extend(self.xmin, other.x, jnp.minimum)
    xmax = _extend(self.xmax, other.x, jnp.maximum)
    ymin = _extend(self.ymin, other.y, jnp.minimum)
    ymax = _extend(self.ymax, other.y, jnp.maximum)

    cls = self.__class__
    if not isinstance(self, BoundsI):
        out = cls.__new__(cls)
        out.xmin = xmin
        out.deltax = xmax - xmin
        out.ymin = ymin
        out.deltay = ymax - ymin
        out._isdefined = jnp.where(
            isdef,
            jnp.array(out.deltax >= 0) & jnp.array(out.deltay >= 0),
            jnp.array(True),
        )
        out._isstatic = False
        return out

    # Going through the constructor converts the widths to static ints.
    out = cls(
        xmin=xmin,
        deltax=xmax - xmin + 1,
        ymin=ymin,
        deltay=ymax - ymin + 1,
    )
    flag = jnp.where(
        jnp.array(isdef),
        jnp.array(out.deltax >= 1) & jnp.array(out.deltay >= 1),
        jnp.array(True),
    )
    # The BoundsI definedness flag must be a static Python bool as well.
    with jax.ensure_compile_time_eval():
        out._isdefined = bool(flag.item())
    return out
@implements(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR)
@register_pytree_node_class
class BoundsD(Bounds):
    # Floating-point bounds. ``xmin``, ``deltax``, ``ymin``, ``deltay`` and
    # ``_isdefined`` are stored as JAX arrays (with ``xmax = xmin + deltax``),
    # and all of them are PyTree children, i.e. traced by JAX transformations.
    _pos_class = PositionD

    def __init__(self, *args, **kwargs):
        needs_isdefined = self._parse_args(*args, **kwargs)
        # Normalise every stored field to a float JAX array.
        for attr in ("xmin", "deltax", "ymin", "deltay"):
            setattr(self, attr, cast_to_float(jnp.array(getattr(self, attr))))
        if needs_isdefined:
            # zero-width BoundsD (a point or a line) still counts as defined
            self._isdefined = (self.deltax >= 0) & (self.deltay >= 0)
        self._isdefined = jnp.array(self._isdefined)
        self._isstatic = False

    def _check_scalar(self, x, name):
        # Accept 0-d floating arrays or anything that round-trips via float().
        try:
            is_float_scalar_array = (
                isinstance(x, (jax.Array, jnp.ndarray, np.ndarray))
                and x.shape == ()
                and jnp.issubdtype(jnp.array(x).dtype, jnp.floating)
            )
            if is_float_scalar_array or x == float(x):
                return
        except (TypeError, ValueError):
            pass
        raise TypeError("%s must be a float value" % name)

    @property
    def xmax(self):
        return self.xmin + self.deltax

    @xmax.setter
    def xmax(self, value):
        self.deltax = value - self.xmin

    @property
    def ymax(self):
        return self.ymin + self.deltay

    @ymax.setter
    def ymax(self, value):
        self.deltay = value - self.ymin

    def _area(self):
        return self.deltax * self.deltay

    @property
    def _center(self):
        mid_x = (self.xmax + self.xmin) / 2.0
        mid_y = (self.ymax + self.ymin) / 2.0
        return PositionD(mid_x, mid_y)

    def _shows_corners(self):
        # Whether repr/str should print the corners. A tracer cannot be turned
        # into a Python bool; if that conversion fails we print the corners.
        try:
            return bool(jnp.any(self.isDefined()))
        except Exception:
            return True

    def _display_corners(self):
        # (xmin, xmax, ymin, ymax) converted to hashable/printable values.
        return tuple(
            ensure_hashable(corner)
            for corner in (self.xmin, self.xmax, self.ymin, self.ymax)
        )

    def __repr__(self):
        if not self._shows_corners():
            return "galsim.%s()" % (self.__class__.__name__)
        return "galsim.%s(%r, %r, %r, %r)" % (
            (self.__class__.__name__,) + self._display_corners()
        )

    def __str__(self):
        if not self._shows_corners():
            return "galsim.%s()" % (self.__class__.__name__)
        return "galsim.%s(%s,%s,%s,%s)" % (
            (self.__class__.__name__,) + self._display_corners()
        )

    def __eq__(self, other):
        if self is other:
            return jnp.array(True)
        if not isinstance(other, self.__class__):
            return jnp.array(False)
        mine = jnp.array(self.isDefined())
        theirs = jnp.array(other.isDefined())
        same_corners = (
            jnp.array(self.xmin == other.xmin)
            & jnp.array(self.ymin == other.ymin)
            & jnp.array(self.xmax == other.xmax)
            & jnp.array(self.ymax == other.ymax)
        )
        # equal if both defined with identical corners, or both undefined
        return (mine & theirs & same_corners) | ((~mine) & (~theirs))

    def __ne__(self, other):
        return ~self.__eq__(other)

    # Defining ``__eq__`` implicitly sets ``__hash__`` to None, so the base
    # class hash (class name + xmin/deltax/ymin/deltay) is restated here.
    __hash__ = Bounds.__hash__

    def _getinitargs(self):
        # defined only for galsim test suite
        return (self.xmin, self.xmax, self.ymin, self.ymax)
@implements(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR)
@register_pytree_node_class
class BoundsI(Bounds):
    # Integer (pixel) bounds. Widths are static Python ints with
    # ``deltax = xmax - xmin + 1``. ``xmin``/``ymin`` are Python ints for
    # static bounds and float JAX arrays otherwise (see the setters below).
    _pos_class = PositionI

    def __init__(self, *args, **kwargs):
        # we set these variables to disable array to python int
        # or python int to array conversions for xmin/ymin while we
        # initialize the object.
        # the setter methods below validate that the inputs are ints,
        # so we skip that in the init.
        # the class always converts deltax/deltay to python ints and
        # an error will be raised if that cannot be done.
        self._isstatic = True
        self._dotypeconversion = False
        self._parse_args(*args, **kwargs)
        # now we compute these properties correctly and turn on type conversion
        self._isdefined = self.deltax >= 1 and self.deltay >= 1
        self._isstatic = isinstance(self._xmin, int) and isinstance(self._ymin, int)
        self._dotypeconversion = True

    def _check_scalar(self, x, name):
        # Accept 0-d integer arrays or anything that round-trips via int().
        try:
            if (
                isinstance(x, (jax.Array, jnp.ndarray, np.ndarray))
                and x.shape == ()
                and jnp.issubdtype(jnp.array(x).dtype, jnp.integer)
            ):
                return
            elif x == int(x):
                return
        except (TypeError, ValueError):
            pass
        raise TypeError("%s must be an integer value" % name)

    def numpyShape(self):
        "A simple utility function to get the numpy shape that corresponds to this `Bounds` object."
        if self._isdefined:
            return self.deltay, self.deltax
        else:
            return 0, 0

    @property
    def xmin(self):
        # for non-static bounds we store xmin internally as a float even
        # though it is an int so that autodiff works properly (needs floats in general).
        # thus we cast here.
        return cast_to_int(self._xmin)

    @xmin.setter
    def xmin(self, value):
        value = check_is_int_then_cast(value, "BoundsI xmin values must be integers")
        if self._isstatic:
            if self._dotypeconversion:
                # attempt to convert to static values
                # this will raise if values are being traced
                # we let that error propagate instead of reraising
                # our own.
                with jax.ensure_compile_time_eval():
                    if not isinstance(value, int):
                        value = int(value.item())
            self._xmin = value
        else:
            self._xmin = jnp.astype(value, float)

    @property
    def deltax(self):
        return self._deltax

    @deltax.setter
    def deltax(self, value):
        value = check_is_int_then_cast(
            value, "BoundsI deltax must be set to an integer value"
        )
        # attempt to convert widths to static values
        # this will raise if values are being traced
        # we let that error propagate instead of reraising
        # our own.
        with jax.ensure_compile_time_eval():
            if not isinstance(value, int):
                value = int(value.item())
        self._deltax = value

    @property
    def xmax(self):
        return cast_to_int(self.xmin + self.deltax - 1)

    @xmax.setter
    def xmax(self, value):
        value = check_is_int_then_cast(
            value, "BoundsI xmax must be set to an integer value"
        )
        self.deltax = value - self.xmin + 1

    @property
    def ymin(self):
        # for non-static bounds we store ymin internally as a float even
        # though it is an int so that autodiff works properly (needs floats in general).
        # thus we cast here.
        return cast_to_int(self._ymin)

    @ymin.setter
    def ymin(self, value):
        value = check_is_int_then_cast(value, "BoundsI ymin values must be integers")
        if self._isstatic:
            if self._dotypeconversion:
                # attempt to convert to static values
                # this will raise if values are being traced
                # we let that error propagate instead of reraising
                # our own.
                with jax.ensure_compile_time_eval():
                    if not isinstance(value, int):
                        value = int(value.item())
            self._ymin = value
        else:
            self._ymin = jnp.astype(value, float)

    @property
    def deltay(self):
        return self._deltay

    @deltay.setter
    def deltay(self, value):
        value = check_is_int_then_cast(
            value, "BoundsI deltay must be set to an integer value"
        )
        # attempt to convert widths to static values
        # this will raise if values are being traced
        # we let that error propagate instead of reraising
        # our own.
        with jax.ensure_compile_time_eval():
            if not isinstance(value, int):
                value = int(value.item())
        self._deltay = value

    @property
    def ymax(self):
        return cast_to_int(self.ymin + self.deltay - 1)

    @ymax.setter
    def ymax(self, value):
        value = check_is_int_then_cast(
            value, "BoundsI ymax must be set to an integer value"
        )
        self.deltay = value - self.ymin + 1

    def _area(self):
        # deltax/deltay already count the pixels on both edges
        # (deltax = xmax - xmin + 1), so no extra + 1 is needed here.
        if not self._isdefined:
            return 0
        else:
            return self.deltax * self.deltay

    @property
    def _center(self):
        # Write it this way to make sure the integer rounding goes the same way regardless
        # of whether the values are positive or negative.
        # e.g. (1,10,1,10) -> (6,6)
        #      (-10,-1,-10,-1) -> (-5,-5)
        # Just up and to the right of the true center in both cases.
        return PositionI(
            self.xmin + self.deltax // 2,
            self.ymin + self.deltay // 2,
        )

    def __repr__(self):
        if self._isdefined:
            return "galsim.%s(xmin=%r, deltax=%r, ymin=%r, deltay=%r)" % (
                self.__class__.__name__,
                ensure_hashable(self.xmin),
                ensure_hashable(self.deltax),
                ensure_hashable(self.ymin),
                ensure_hashable(self.deltay),
            )
        else:
            return "galsim.%s()" % (self.__class__.__name__)

    def __str__(self):
        if self._isdefined:
            return "galsim.%s(xmin=%s, deltax=%s, ymin=%s, deltay=%s)" % (
                self.__class__.__name__,
                ensure_hashable(self.xmin),
                ensure_hashable(self.deltax),
                ensure_hashable(self.ymin),
                ensure_hashable(self.deltay),
            )
        else:
            return "galsim.%s()" % (self.__class__.__name__)

    def __eq__(self, other):
        # Static operands give Python bools; otherwise JAX boolean arrays.
        if self is other:
            if self._isstatic:
                return True
            else:
                return jnp.array(True)
        elif isinstance(other, self.__class__):
            if self._isstatic and other._isstatic:
                return (
                    self._isdefined
                    and other._isdefined
                    and self.xmin == other.xmin
                    and self.ymin == other.ymin
                    and self.deltax == other.deltax
                    and self.deltay == other.deltay
                ) or ((not self._isdefined) and (not other._isdefined))
            else:
                self_isdef = jnp.array(self.isDefined())
                other_isdef = jnp.array(other.isDefined())
                return (
                    self_isdef
                    & other_isdef
                    & jnp.array(self.xmin == other.xmin)
                    & jnp.array(self.ymin == other.ymin)
                    & jnp.array(self.deltax == other.deltax)
                    & jnp.array(self.deltay == other.deltay)
                ) | ((~self_isdef) & (~other_isdef))
        else:
            if self._isstatic:
                return False
            else:
                return jnp.array(False)

    def __ne__(self, other):
        eqval = self.__eq__(other)
        if isinstance(eqval, bool):
            return not eqval
        else:
            return ~eqval

    def __hash__(self):
        # ``__eq__`` treats every pair of undefined bounds as equal regardless
        # of their stored corners, so all undefined bounds must share one hash
        # to honour the ``a == b  =>  hash(a) == hash(b)`` contract.
        if not self._isdefined:
            return hash((self.__class__.__name__,))
        return hash(
            (
                self.__class__.__name__,
                ensure_hashable(self.xmin),
                ensure_hashable(self.deltax),
                ensure_hashable(self.ymin),
                ensure_hashable(self.deltay),
            )
        )

    def tree_flatten(self):
        """This function flattens the Bounds into a list of children
        nodes that will be traced by JAX and auxiliary static data."""
        aux_data = {
            "isstatic": self._isstatic,
            "dotypeconversion": self._dotypeconversion,
        }
        # Define the children nodes of the PyTree that need tracing
        if self._isstatic:
            children = tuple()
            aux_data["xmin"] = self._xmin
            aux_data["ymin"] = self._ymin
        else:
            children = (self._xmin, self._ymin)
        # untraced aux data
        aux_data["deltax"] = self.deltax
        aux_data["deltay"] = self.deltay
        aux_data["isdefined"] = self._isdefined
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Recreates an instance of the class from flatten representation"""
        ret = cls.__new__(cls)
        if aux_data["isstatic"]:
            ret._xmin = aux_data["xmin"]
            ret._ymin = aux_data["ymin"]
        else:
            ret._xmin = children[0]
            ret._ymin = children[1]
        ret.deltax = aux_data["deltax"]
        ret.deltay = aux_data["deltay"]
        ret._isdefined = aux_data["isdefined"]
        ret._isstatic = aux_data["isstatic"]
        ret._dotypeconversion = aux_data["dotypeconversion"]
        return ret
@implements(
    _galsim._BoundsD,
    lax_description="JAX-GalSim doesn't skip sanity checks for ``_BoundsD``.",
)
def _BoundsD(xmin, xmax, ymin, ymax):
    # Plain delegation to the full (validating) BoundsD constructor.
    corners = (xmin, xmax, ymin, ymax)
    return BoundsD(*corners)
@implements(
    _galsim._BoundsI,
    lax_description="JAX-GalSim doesn't skip sanity checks for ``_BoundsI``.",
)
def _BoundsI(xmin, xmax, ymin, ymax):
    # Plain delegation to the full (validating) BoundsI constructor.
    corners = (xmin, xmax, ymin, ymax)
    return BoundsI(*corners)