Source code for jax_galsim.position

import galsim as _galsim
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import (
    cast_to_float,
    cast_to_int,
    ensure_hashable,
    implements,
)


[docs] @implements(_galsim.Position) class Position(object): def __init__(self): raise NotImplementedError( "Cannot instantiate the base class. Use either PositionD or PositionI." ) def _parse_args(self, *args, **kwargs): if len(kwargs) == 0: if len(args) == 2: self.x, self.y = args elif len(args) == 0: self.x = self.y = 0 elif len(args) == 1: if isinstance(args[0], (Position,)): self.x = args[0].x self.y = args[0].y else: try: self.x, self.y = args[0] except (TypeError, ValueError): raise TypeError( "Single argument to %s must be either a Position " "or a tuple." % self.__class__ ) else: raise TypeError( "%s takes at most 2 arguments (%d given)" % (self.__class__, len(args)) ) elif len(args) != 0: raise TypeError( "%s takes x and y as either named or unnamed arguments (given %s, %s)" % (self.__class__, args, kwargs) ) else: try: self.x = kwargs.pop("x") self.y = kwargs.pop("y") except KeyError: raise TypeError( "Keyword arguments x,y are required for %s" % self.__class__ ) if kwargs: raise TypeError("Got unexpected keyword arguments %s" % kwargs.keys()) @property def _array(self): return jnp.array([self.x, self.y]) def __mul__(self, other): self._check_scalar(other, "multiply") return self.__class__(self.x * other, self.y * other) def __rmul__(self, other): return self.__mul__(other) def __div__(self, other): self._check_scalar(other, "divide") return self.__class__(self.x / other, self.y / other) __truediv__ = __div__ def __neg__(self): return self.__class__(-self.x, -self.y) def __add__(self, other): from jax_galsim.bounds import Bounds if isinstance(other, Bounds): return other + self if not isinstance(other, Position): raise TypeError("Can only add a Position to a %s" % self.__class__.__name__) elif isinstance(other, self.__class__): return self.__class__(self.x + other.x, self.y + other.y) else: return PositionD(self.x + other.x, self.y + other.y) def __sub__(self, other): if not isinstance(other, Position): raise TypeError( "Can only subtract a Position from a %s" % self.__class__.__name__ ) elif isinstance(other, self.__class__): return self.__class__(self.x - other.x, self.y - other.y) else: return PositionD(self.x - other.x, self.y - other.y) def __repr__(self): return "galsim.%s(x=%r, y=%r)" % ( self.__class__.__name__, ensure_hashable(self.x), ensure_hashable(self.y), ) def __str__(self): return "galsim.%s(%s,%s)" % ( self.__class__.__name__, ensure_hashable(self.x), ensure_hashable(self.y), ) def __eq__(self, other): return self is other or ( isinstance(other, self.__class__) and self.x == other.x and self.y == other.y ) def __ne__(self, other): return not self.__eq__(other) def __hash__(self): return hash( (self.__class__.__name__, ensure_hashable(self.x), ensure_hashable(self.y)) )
[docs] @implements(_galsim.Position.shear) def shear(self, shear): shear_mat = shear.getMatrix() shear_pos = jnp.dot(shear_mat, self._array) return PositionD(shear_pos[0], shear_pos[1])
[docs] def round(self): """Return the rounded-off PositionI version of this position.""" return PositionI(jnp.round(self.x), jnp.round(self.y))
[docs] def tree_flatten(self): """This function flattens the GSObject into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing children = (self.x, self.y) return (children, None)
[docs] @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" del aux_data obj = object.__new__(cls) obj.x = children[0] obj.y = children[1] return obj
[docs] @classmethod def from_galsim(cls, galsim_position): """Create a jax_galsim `PositionD/I` from a `galsim.PositionD/I` object.""" if isinstance(galsim_position, _galsim.PositionD): _cls = PositionD elif isinstance(galsim_position, _galsim.PositionI): _cls = PositionI else: raise TypeError( "galsim_position must be either a %s or a %s" % (_galsim.PositionD.__name__, _galsim.PositionI.__name__) ) return _cls(galsim_position.x, galsim_position.y)
[docs] def to_galsim(self): """Create a galsim `PositionD/I` from a `jax_galsim.PositionD/I` object.""" if isinstance(self, PositionI): gs_class = _galsim.bounds.PositionI cast = int else: gs_class = _galsim.bounds.PositionD cast = float return gs_class( cast(self.x), cast(self.y), )
[docs] @implements(_galsim.PositionD) @register_pytree_node_class class PositionD(Position): def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) # Force conversion to float type in this case self.x = cast_to_float(self.x) self.y = cast_to_float(self.y) def _check_scalar(self, other, op): try: if ( isinstance(other, jax.Array) and other.shape == () and other.dtype.name in ["float32", "float64", "float"] ): return elif other == float(other): return except (TypeError, ValueError): pass raise TypeError("Can only %s a PositionD by float values" % op)
[docs] @implements(_galsim.PositionI) @register_pytree_node_class class PositionI(Position): def __init__(self, *args, **kwargs): self._parse_args(*args, **kwargs) # inputs must be ints self.x = cast_to_int(self.x) self.y = cast_to_int(self.y) def _check_scalar(self, other, op): try: if ( isinstance(other, jax.Array) and other.shape == () and other.dtype.name in ["int32", "int64", "int"] ): return elif other == int(other): return except (TypeError, ValueError): pass raise TypeError("Can only %s a PositionI by int values" % op)