# original source license:
#
# Copyright (c) 2013-2017 LSST Dark Energy Science Collaboration (DESC)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import galsim as _galsim
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from jax_galsim.core.utils import cast_to_float, ensure_hashable, implements
@implements(_galsim.AngleUnit)
@register_pytree_node_class
class AngleUnit(object):
    # Unit-name prefixes accepted by ``from_name``.
    valid_names = ["rad", "deg", "hr", "hour", "arcmin", "arcsec"]

    def __init__(self, value):
        # ``value`` is the size of one of this unit expressed in radians.
        # A unit cannot be built from another unit.
        if isinstance(value, AngleUnit):
            raise TypeError("Cannot construct AngleUnit from another AngleUnit")
        self._value = cast_to_float(value)

    @property
    @implements(_galsim.AngleUnit.value)
    def value(self):
        return self._value

    def __rmul__(self, theta):
        """float * AngleUnit returns an Angle"""
        return Angle(theta, self)

    def __div__(self, unit):
        """AngleUnit / AngleUnit returns a float giving the relative scaling.

        Note: At least to within machine precision, it is the case that
        (x * angle_unit1) / angle_unit2 == x * (angle_unit1 / angle_unit2)

        Raises:
            TypeError: if ``unit`` is not an AngleUnit.
        """
        if not isinstance(unit, AngleUnit):
            raise TypeError("Cannot divide AngleUnit by %s" % unit)
        return self.value / unit.value

    __truediv__ = __div__

    @staticmethod
    @implements(_galsim.AngleUnit.from_name)
    def from_name(unit):
        # Matching ignores surrounding whitespace and case, and only looks at
        # the prefix of the name (so "degrees", "Deg" and "deg" all match).
        unit = unit.strip().lower()
        if unit.startswith("rad"):
            return radians
        elif unit.startswith("deg"):
            return degrees
        elif unit.startswith("hour"):
            return hours
        elif unit.startswith("hr"):
            return hours
        elif unit.startswith("arcmin"):
            return arcmin
        elif unit.startswith("arcsec"):
            return arcsec
        else:
            raise ValueError("Unknown Angle unit: %s" % unit)

    def __repr__(self):
        # The built-in units print as their galsim names; any other unit
        # prints as a constructor call with its value in radians.
        if self == radians:
            return "galsim.radians"
        elif self == degrees:
            return "galsim.degrees"
        elif self == hours:
            return "galsim.hours"
        elif self == arcmin:
            return "galsim.arcmin"
        elif self == arcsec:
            return "galsim.arcsec"
        else:
            return "galsim.AngleUnit(%r)" % (ensure_hashable(self.value),)

    def __eq__(self, other):
        # Two units are equal when both are AngleUnits with equal values.
        return isinstance(other, AngleUnit) and jnp.array_equal(self.value, other.value)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash(("galsim.AngleUnit", ensure_hashable(self.value)))

    def tree_flatten(self):
        """This function flattens the AngleUnit into a list of children
        nodes that will be traced by JAX and auxiliary static data."""
        # Define the children nodes of the PyTree that need tracing
        children = (self._value,)
        # Define auxiliary static data that doesn't need to be traced
        aux_data = None
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Recreates an instance of the class from flatten representation.

        The leaf is stored as-is, without going through ``__init__``. JAX may
        unflatten with tracers or placeholder leaves, so re-running the type
        check and ``cast_to_float`` here would be wrong. This mirrors
        ``Angle.tree_unflatten``.
        """
        ret = cls.__new__(cls)
        ret._value = children[0]
        return ret
# Built-in units, each defined by its size in radians.
# (These are normally the only units needed.)
radians = AngleUnit(1.0)
hours = AngleUnit(jnp.pi / 12.0)  # a full turn is 24 hours
degrees = AngleUnit(jnp.pi / 180.0)
arcmin = AngleUnit(jnp.pi / (180.0 * 60.0))  # 1/60 of a degree
arcsec = AngleUnit(jnp.pi / (180.0 * 3600.0))  # 1/3600 of a degree
@implements(_galsim.Angle)
@register_pytree_node_class
class Angle(object):
    def __init__(self, theta, unit=None):
        # ``Angle(other_angle)`` is allowed as a copy constructor, in which
        # case no unit may be supplied.
        if isinstance(theta, Angle):
            if unit is not None:
                raise TypeError(
                    "Cannot provide unit if theta is already an Angle instance"
                )
            self._rad = theta._rad
        elif unit is None:
            raise TypeError("Must provide unit for Angle.__init__")
        elif not isinstance(unit, AngleUnit):
            raise TypeError("Invalid unit %s of type %s" % (unit, type(unit)))
        else:
            # Normal case: the angle is stored internally in radians.
            self._rad = cast_to_float(theta) * unit.value

    @property
    @implements(_galsim.Angle.rad)
    def rad(self):
        return self._rad

    @property
    @implements(_galsim.Angle.deg)
    def deg(self):
        return self / degrees

    def __neg__(self):
        return _Angle(-self._rad)

    def __pos__(self):
        return self

    def __abs__(self):
        return _Angle(jnp.abs(self._rad))

    def __add__(self, other):
        if not isinstance(other, Angle):
            raise TypeError(
                "Cannot add %s of type %s to an Angle" % (other, type(other))
            )
        return _Angle(self._rad + other._rad)

    def __sub__(self, other):
        if not isinstance(other, Angle):
            raise TypeError(
                "Cannot subtract %s of type %s from an Angle" % (other, type(other))
            )
        return _Angle(self._rad - other._rad)

    def __mul__(self, other):
        """Angle * scalar (and scalar * Angle) returns a scaled Angle.

        Raises:
            TypeError: if ``other`` is an Angle or an AngleUnit. Without this
                check those products silently produced an Angle whose internal
                value was itself an Angle.
        """
        if isinstance(other, (Angle, AngleUnit)):
            raise TypeError(
                "Cannot multiply Angle by %s of type %s" % (other, type(other))
            )
        return _Angle(self._rad * other)

    __rmul__ = __mul__

    def __div__(self, other):
        """Angle / AngleUnit returns the angle's value in that unit;
        Angle / scalar returns a scaled Angle.

        Raises:
            TypeError: if ``other`` is an Angle (compare ``a.rad / b.rad``
                explicitly if a ratio of angles is wanted).
        """
        if isinstance(other, AngleUnit):
            return self._rad / other.value
        if isinstance(other, Angle):
            raise TypeError(
                "Cannot divide Angle by %s of type %s" % (other, type(other))
            )
        return _Angle(self._rad / other)

    __truediv__ = __div__

    @implements(_galsim.Angle.wrap)
    def wrap(self, center=None):
        if center is None:
            center = _Angle(0.0)
        # Target interval is [center - pi, center + pi): remove as many whole
        # turns as needed to land in it.
        start = center._rad - jnp.pi
        offset = (self._rad - start) // (2.0 * jnp.pi)  # How many full cycles to subtract
        return _Angle(self._rad - offset * 2.0 * jnp.pi)

    @implements(_galsim.Angle.sin)
    def sin(self):
        return jnp.sin(self._rad)

    @implements(_galsim.Angle.cos)
    def cos(self):
        return jnp.cos(self._rad)

    @implements(_galsim.Angle.tan)
    def tan(self):
        return jnp.tan(self._rad)

    @implements(_galsim.Angle.sincos)
    def sincos(self):
        sin = jnp.sin(self._rad)
        cos = jnp.cos(self._rad)
        return sin, cos

    def __str__(self):
        return str(ensure_hashable(self._rad)) + " radians"

    def __repr__(self):
        return "galsim.Angle(%r, galsim.radians)" % (ensure_hashable(self.rad),)

    def __eq__(self, other):
        return isinstance(other, Angle) and jnp.array_equal(self.rad, other.rad)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __le__(self, other):
        if not isinstance(other, Angle):
            raise TypeError(
                "Cannot compare %s of type %s to an Angle" % (other, type(other))
            )
        return self._rad <= other._rad

    def __lt__(self, other):
        if not isinstance(other, Angle):
            raise TypeError(
                "Cannot compare %s of type %s to an Angle" % (other, type(other))
            )
        return self._rad < other._rad

    def __ge__(self, other):
        if not isinstance(other, Angle):
            raise TypeError(
                "Cannot compare %s of type %s to an Angle" % (other, type(other))
            )
        return self._rad >= other._rad

    def __gt__(self, other):
        if not isinstance(other, Angle):
            raise TypeError(
                "Cannot compare %s of type %s to an Angle" % (other, type(other))
            )
        return self._rad > other._rad

    def __hash__(self):
        return hash(("galsim.Angle", ensure_hashable(self._rad)))

    @staticmethod
    def _make_dms_string(decimal, sep, prec, pad, plus_sign):
        """Format a decimal value (degrees for dms, hours for hms) as a
        sexagesimal string of the form [sign]DD<sep1>MM<sep2>SS.fff[<sep3>].

        Parameters:
            decimal:    The value to format.
            sep:        A string or tuple of length <= 3. One item is used as
                        both separators; two items are the first and second
                        separators; a third item is appended at the end.
            prec:       Number of decimal places for the seconds field. None
                        rounds to 1e-8 and strips trailing zeros; 0 also drops
                        the fractional part.
            pad:        Zero-pad the integer fields to two digits.
            plus_sign:  Prefix non-negative values with "+".

        Returns:
            The formatted string.
        """
        # Do the digit arithmetic with a Python float/int rather than a
        # (possibly float32) JAX scalar, so the rounding below is not limited
        # by the array dtype.
        decimal = float(decimal)

        # Account for the sign properly
        if decimal < 0:
            sign = "-"
            decimal = -decimal
        elif plus_sign:
            sign = "+"
        else:
            sign = ""

        # Figure out the 3 sep tokens
        sep1 = sep2 = ""
        sep3 = None
        if len(sep) == 1:
            # Index rather than reuse ``sep`` so a 1-tuple behaves like a
            # 1-character string instead of being formatted as a tuple.
            sep1 = sep2 = sep[0]
        elif len(sep) == 2:
            sep1, sep2 = sep
        elif len(sep) == 3:
            sep1, sep2, sep3 = sep

        # Round to nearest 1.e-8 seconds (or 10**-prec if given), then split
        # the resulting integer count into whole units, minutes and seconds.
        round_prec = 8 if prec is None else prec
        digits = 10**round_prec
        decimal = int(3600 * digits * decimal + 0.5)
        d = decimal // (3600 * digits)
        decimal -= d * (3600 * digits)
        m = decimal // (60 * digits)
        decimal -= m * (60 * digits)
        s = decimal // digits
        decimal -= s * digits

        # Make the string
        if pad:
            d_str = "%02d" % d
            m_str = "%02d" % m
            s_str = "%02d" % s
        else:
            d_str = "%d" % d
            m_str = "%d" % m
            s_str = "%d" % s
        string = "%s%s%s%s%s%s.%0*d" % (
            sign,
            d_str,
            sep1,
            m_str,
            sep2,
            s_str,
            round_prec,
            decimal,
        )
        if not prec:
            string = string.rstrip("0")
            string = string.rstrip(".")
        if sep3:
            string = string + sep3
        return string

    @implements(_galsim.Angle.hms)
    def hms(self, sep=":", prec=None, pad=True, plus_sign=False):
        if not len(sep) <= 3:
            raise ValueError("sep must be a string or tuple of length <= 3")
        if prec is not None and not prec >= 0:
            raise ValueError("prec must be >= 0")
        return self._make_dms_string(self / hours, sep, prec, pad, plus_sign)

    @implements(_galsim.Angle.dms)
    def dms(self, sep=":", prec=None, pad=True, plus_sign=False):
        if not len(sep) <= 3:
            raise ValueError("sep must be a string or tuple of length <= 3")
        if prec is not None and not prec >= 0:
            raise ValueError("prec must be >= 0")
        return self._make_dms_string(self / degrees, sep, prec, pad, plus_sign)

    @staticmethod
    @implements(_galsim.Angle.from_hms)
    def from_hms(str):
        return Angle._parse_dms(str) * hours

    @staticmethod
    @implements(_galsim.Angle.from_dms)
    def from_dms(str):
        return Angle._parse_dms(str) * degrees

    @staticmethod
    def _parse_dms(dms):
        """Convert a string of the form dd:mm:ss.decimal into a decimal number.

        The result is in the units of the leading field (degrees for
        ``from_dms``, hours for ``from_hms``). Any non-numeric text is treated
        as a separator; a leading "-" negates the whole value.

        Raises:
            ValueError: if the string does not have the expected format.
        """
        import re

        # Split into alternating numeric / separator tokens, dropping empties.
        tokens = tuple(filter(None, re.split(r"([\.\d]+)", dms.strip())))
        if len(tokens) <= 1:
            raise ValueError("string is not of the expected format")
        sign = 1
        try:
            dd = float(tokens[0])
        except ValueError:
            # The first token is a non-numeric prefix (e.g. a sign); skip it.
            if tokens[0].strip() == "-":
                sign = -1
            tokens = tokens[1:]
            dd = float(tokens[0])
        if len(tokens) <= 1:
            raise ValueError("string is not of the expected format")
        if len(tokens) <= 2:
            return sign * dd
        mm = float(tokens[2])
        if len(tokens) <= 4:
            return sign * (dd + mm / 60)
        if len(tokens) >= 7:
            raise ValueError("string is not of the expected format")
        ss = float(tokens[4])
        return sign * (dd + mm / 60.0 + ss / 3600.0)

    def tree_flatten(self):
        """This function flattens the Angle into a list of children
        nodes that will be traced by JAX and auxiliary static data."""
        # Define the children nodes of the PyTree that need tracing
        children = (self._rad,)
        # Define auxiliary static data that doesn't need to be traced
        aux_data = None
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Recreates an instance of the class from flatten representation"""
        # Bypass __init__ so the leaf is stored untouched.
        ret = cls.__new__(cls)
        ret._rad = children[0]
        return ret

    @staticmethod
    def from_galsim(gs_angle):
        """Create a jax_galsim `Angle` from a `galsim.Angle` object."""
        return _Angle(gs_angle._rad)

    def to_galsim(self):
        """Create a galsim `Angle` from a `jax_galsim.Angle` object."""
        return _galsim.angle._Angle(float(self._rad))
@implements(_galsim._Angle)
def _Angle(theta):
    # Fast constructor: wrap a value that is already in radians as an Angle,
    # skipping Angle.__init__ (so no type checks and no cast_to_float).
    angle = object.__new__(Angle)
    angle._rad = theta
    return angle