# original source license:
#
# Copyright (c) 2013-2017 LSST Dark Energy Science Collaboration (DESC)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import warnings
from functools import partial
import coord as _coord
import galsim as _galsim
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from jax_galsim.angle import Angle, _Angle, arcsec, degrees, radians
from jax_galsim.core.utils import ensure_hashable, implements
# we have to copy this one since JAX sends in `t` as a traced array
# and the coord.Angle classes don't know how to handle that
@implements(_coord.util.ecliptic_obliquity)
def _ecliptic_obliquity(epoch):
# We need to figure out the time in Julian centuries from J2000 for this epoch.
t = (epoch - 2000.0) / 100.0
# Then we use the last (most recent) formula listed under
# http://en.wikipedia.org/wiki/Ecliptic#Obliquity_of_the_ecliptic, from
# JPL's 2010 calculations.
ep = Angle.from_dms("23:26:21.406")
ep -= Angle.from_dms("00:00:46.836769") * t
ep -= Angle.from_dms("00:00:0.0001831") * (t**2)
ep += Angle.from_dms("00:00:0.0020034") * (t**3)
# There are even higher order terms, but they are probably not important for any reasonable
# calculation someone would do with this package.
return ep
def _sun_position_ecliptic(date):
return _Angle(_coord.util.sun_position_ecliptic(date).rad)
[docs]
@implements(
_galsim.celestial.CelestialCoord,
lax_description=(
"The JAX version of this object does not check that the declination is between -90 and 90."
),
)
@register_pytree_node_class
class CelestialCoord(object):
def __init__(self, ra, dec=None):
if isinstance(ra, CelestialCoord) and dec is None:
# Copy constructor
self._ra = ra._ra
self._dec = ra._dec
elif ra is None or dec is None:
raise TypeError("ra and dec are both required")
elif not isinstance(ra, Angle):
raise TypeError("ra must be a galsim.Angle")
elif not isinstance(dec, Angle):
raise TypeError("dec must be a galsim.Angle")
else:
# Normal case
self._ra = ra
self._dec = dec
@property
@implements(_galsim.celestial.CelestialCoord.ra)
def ra(self):
return self._ra
@property
@implements(_galsim.celestial.CelestialCoord.dec)
def dec(self):
return self._dec
@property
@implements(_galsim.celestial.CelestialCoord.rad)
def rad(self):
return (self._ra.rad, self._dec.rad)
@jax.jit
def _get_aux(self):
_sindec, _cosdec = self._dec.sincos()
_sinra, _cosra = self._ra.sincos()
_x = _cosdec * _cosra
_y = _cosdec * _sinra
_z = _sindec
return _cosra, _sinra, _cosdec, _sindec, _x, _y, _z
# DO NOT ACUTALLY USE THIS, HERE FOR TESTING PURPOSES ONLY
def _set_aux(self):
aux = self._get_aux()
(
self._cosra,
self._sinra,
self._cosdec,
self._sindec,
self._x,
self._y,
self._z,
) = aux
[docs]
@implements(_galsim.celestial.CelestialCoord.get_xyz)
def get_xyz(self):
return self._get_aux()[4:]
[docs]
@staticmethod
@jax.jit
@implements(
_galsim.celestial.CelestialCoord.from_xyz,
lax_description=(
"The JAX version of this static method does not check that the norm of the input "
"vector is non-zero."
),
)
def from_xyz(x, y, z):
norm = jnp.sqrt(x * x + y * y + z * z)
ret = CelestialCoord.__new__(CelestialCoord)
ret._x = x / norm
ret._y = y / norm
ret._z = z / norm
ret._sindec = ret._z
ret._cosdec = jnp.sqrt(ret._x * ret._x + ret._y * ret._y)
ret._sinra = jnp.where(
ret._cosdec == 0,
0,
ret._y / ret._cosdec,
)
ret._cosra = jnp.where(
ret._cosdec == 0,
1.0,
ret._x / ret._cosdec,
)
ret._ra = (jnp.arctan2(ret._sinra, ret._cosra) * radians).wrap(_Angle(jnp.pi))
ret._dec = jnp.arctan2(ret._sindec, ret._cosdec) * radians
return ret
[docs]
@staticmethod
@jax.jit
@implements(_galsim.celestial.CelestialCoord.radec_to_xyz)
def radec_to_xyz(ra, dec, r=1.0):
cosdec = jnp.cos(dec)
x = cosdec * jnp.cos(ra) * r
y = cosdec * jnp.sin(ra) * r
z = jnp.sin(dec) * r
return x, y, z
[docs]
@staticmethod
@partial(jax.jit, static_argnames=("return_r",))
@implements(_galsim.celestial.CelestialCoord.xyz_to_radec)
def xyz_to_radec(x, y, z, return_r=False):
xy2 = x**2 + y**2
ra = jnp.arctan2(y, x)
# Note: We don't need arctan2, since always quadrant 1 or 4.
# Using plain arctan is slightly faster. About 10% for the whole function.
# However, if any points have x=y=0, then this will raise a numpy warning.
# It still gives the right answer, but we catch and ignore the warning here.
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning)
dec = jnp.arctan(z / jnp.sqrt(xy2))
if return_r:
return ra, dec, jnp.sqrt(xy2 + z**2)
else:
return ra, dec
[docs]
@implements(_galsim.celestial.CelestialCoord.normal)
def normal(self):
return _CelestialCoord(self.ra.wrap(_Angle(jnp.pi)), self.dec)
@staticmethod
@jax.jit
def _raw_dsq(auxc1, auxc2):
# Compute the raw dsq between two coordinates.
c1_x, c1_y, c1_z = auxc1[4:]
c2_x, c2_y, c2_z = auxc2[4:]
return (c1_x - c2_x) ** 2 + (c1_y - c2_y) ** 2 + (c1_z - c2_z) ** 2
@staticmethod
@jax.jit
def _raw_cross(auxc1, auxc2):
# Compute the raw cross product between two coordinates.
c1_x, c1_y, c1_z = auxc1[4:]
c2_x, c2_y, c2_z = auxc2[4:]
return (
c1_y * c2_z - c2_y * c1_z,
c1_z * c2_x - c2_z * c1_x,
c1_x * c2_y - c2_x * c1_y,
)
[docs]
@implements(_galsim.celestial.CelestialCoord.distanceTo)
@jax.jit
def distanceTo(self, coord2):
# The easiest way to do this in a way that is stable for small separations
# is to calculate the (x,y,z) position on the unit sphere corresponding to each
# coordinate position.
#
# x = cos(dec) cos(ra)
# y = cos(dec) sin(ra)
# z = sin(dec)
aux = self._get_aux()
auxc = coord2._get_aux()
# The direct distance between the two points is
#
# d^2 = (x1-x2)^2 + (y1-y2)^2 + (z1-z2)^2
dsq = self._raw_dsq(aux, auxc)
theta = jnp.where(
dsq < 3.99,
# (The usual case. This formula is perfectly stable here.)
# This direct distance can then be converted to a great circle distance via
#
# sin(theta/2) = d/2
2.0 * jnp.arcsin(0.5 * jnp.sqrt(dsq)),
# Points are nearly antipodes where the accuracy of this formula starts to break down.
# But in this case, the cross product provides an accurate distance.
jnp.pi
- jnp.arcsin(jnp.sqrt(jnp.sum(jnp.array(self._raw_cross(aux, auxc)) ** 2))),
)
return _Angle(theta)
[docs]
@implements(
_galsim.celestial.CelestialCoord.greatCirclePoint,
lax_description=(
"The JAX version of this method does not check that coord2 defines a unique great "
"circle with the current coord at angle theta."
),
)
@jax.jit
def greatCirclePoint(self, coord2, theta):
aux = self._get_aux()
auxc = coord2._get_aux()
# Define u = self
# v = coord2
# w = (u x v) x u
# The great circle through u and v is then
#
# R(t) = u cos(t) + w sin(t)
#
# Rather than directly calculate (u x v) x u, let's do some simplification first.
# u x v = ( uy vz - uz vy )
# ( uz vx - ux vz )
# ( ux vy - uy vx )
# wx = (u x v)_y uz - (u x v)_z uy
# = (uz vx - ux vz) uz - (ux vy - uy vx) uy
# = vx uz^2 - vz ux uz - vy ux uy + vx uy^2
# = vx (1 - ux^2) - ux (uz vz + uy vy)
# = vx - ux (u . v)
# = vx - ux (1 - d^2/2)
# = vx - ux + ux d^2/2
# wy = vy - uy + uy d^2/2
# wz = vz - uz + uz d^2/2
dsq = self._raw_dsq(aux, auxc)
# These are unnormalized yet.
_x, _y, _z = aux[4:]
c_x, c_y, c_z = auxc[4:]
wx = c_x - _x + _x * dsq / 2.0
wy = c_y - _y + _y * dsq / 2.0
wz = c_z - _z + _z * dsq / 2.0
# Normalize
wr = (wx**2 + wy**2 + wz**2) ** 0.5
# if wr == 0.:
# raise ValueError("coord2 does not define a unique great circle with self.")
wx /= wr
wy /= wr
wz /= wr
# R(theta)
s, c = theta.sincos()
rx = _x * c + wx * s
ry = _y * c + wy * s
rz = _z * c + wz * s
return CelestialCoord.from_xyz(rx, ry, rz)
@jax.jit
def _triple(self, aux, auxc2, auxc3):
"""Compute the scalar triple product of the three vectors:
(A x C). B = sina sinb sinC
where C = self, A = coord2, B = coord3. This is used by both angleBetween and area.
(Although note that the triple product is invariant to the ordering modulo a sign.)
"""
_x, _y, _z = aux[4:]
c2_x, c2_y, c2_z = auxc2[4:]
c3_x, c3_y, c3_z = auxc3[4:]
# Note, the scalar triple product, (AxC).B, is the determinant of the 3x3 matrix
# [ xA yA zA ]
# [ xC yC zC ]
# [ xB yB zB ]
# Furthermore, it is more stable to calculate it that way than computing the cross
# product by hand and then dotting it to the other vector.
# JAX has separate code path for 3x3 determinants that doesn't match the LU path in
# galsim/ numpy. The slogdet function uses the LU decomp by default, so we use that.
sign, logdet = jnp.linalg.slogdet(
jnp.array(
[[c2_x, c2_y, c2_z], [_x, _y, _z], [c3_x, c3_y, c3_z]],
dtype=float,
)
)
return sign * jnp.exp(logdet)
@jax.jit
def _alt_triple(self, aux, auxc2, auxc3):
"""Compute a different triple product of the three vectors:
(A x C). (B x C) = sina sinb cosC
where C = self, A = coord2, B = coord3. This is used by both angleBetween and area.
"""
# We can simplify (AxC).(BxC) as follows:
# (A x C) . (B x C)
# = (C x (BxC)) . A Rotation of triple product with (BxC) one of the vectors
# = ((C.C)B - (C.B)C) . A Vector triple product identity
# = A.B - (A.C) (B.C) C.C = 1
# Dot products for nearby coordinates are not very accurate. Better to use the distances
# between the points: A.B = 1 - d_AB^2/2
# = 1 - d_AB^2/2 - (1-d_AC^2/2) (1-d_BC^2/2)
# = d_AC^2 / 2 + d_BC^2 / 2 - d_AB^2 / 2 - d_AC^2 d_BC^2 / 4
dsq_AC = self._raw_dsq(aux, auxc2)
dsq_BC = self._raw_dsq(aux, auxc3)
dsq_AB = self._raw_dsq(auxc2, auxc3)
return 0.5 * (dsq_AC + dsq_BC - dsq_AB - 0.5 * dsq_AC * dsq_BC)
[docs]
@implements(_galsim.celestial.CelestialCoord.angleBetween)
@jax.jit
def angleBetween(self, coord2, coord3):
# Call A = coord2, B = coord3, C = self
# Then we are looking for the angle ACB.
# If we treat each coord as a (x,y,z) vector, then we can use the following spherical
# trig identities:
#
# (A x C) . B = sina sinb sinC
# (A x C) . (B x C) = sina sinb cosC
#
# Then we can just use atan2 to find C, and atan2 automatically gets the sign right.
# And we only need 1 trig call, assuming that x,y,z are already set up, which is often
# the case.
aux = self._get_aux()
auxc2 = coord2._get_aux()
auxc3 = coord3._get_aux()
sinC = self._triple(aux, auxc2, auxc3)
cosC = self._alt_triple(aux, auxc2, auxc3)
C = jnp.arctan2(sinC, cosC)
return _Angle(C)
[docs]
@implements(_galsim.celestial.CelestialCoord.area)
@jax.jit
def area(self, coord2, coord3):
# The area of a spherical triangle is defined by the "spherical excess", E.
# There are several formulae for E:
# (cf. http://en.wikipedia.org/wiki/Spherical_trigonometry#Area_and_spherical_excess)
#
# E = A + B + C - pi
# tan(E/4) = sqrt(tan(s/2) tan((s-a)/2) tan((s-b)/2) tan((s-c)/2)
# tan(E/2) = tan(a/2) tan(b/2) sin(C) / (1 + tan(a/2) tan(b/2) cos(C))
#
# We use the last formula, which is stable both for small triangles and ones that are
# nearly degenerate (which the middle formula may have trouble with).
#
# Furthermore, we can use some of the math for angleBetween and distanceTo to simplify
# this further:
#
# In angleBetween, we have formulae for sina sinb sinC and sina sinb cosC.
# In distanceTo, we have formulae for sin(a/2) and sin(b/2).
#
# Define: F = sina sinb sinC
# G = sina sinb cosC
# da = 2 sin(a/2)
# db = 2 sin(b/2)
#
# tan(E/2) = sin(a/2) sin(b/2) sin(C) / (cos(a/2) cos(b/2) + sin(a/2) sin(b/2) cos(C))
# = sin(a) sin(b) sin(C) / (4 cos(a/2)^2 cos(b/2)^2 + sin(a) sin(b) cos(C))
# = F / (4 (1-sin(a/2)^2) (1-sin(b/2)^2) + G)
# = F / (4-da^2) (4-db^2)/4 + G)
aux = self._get_aux()
auxc2 = coord2._get_aux()
auxc3 = coord3._get_aux()
F = self._triple(aux, auxc2, auxc3)
G = self._alt_triple(aux, auxc2, auxc3)
dasq = self._raw_dsq(aux, auxc2)
dbsq = self._raw_dsq(aux, auxc3)
tanEo2 = F / (0.25 * (4.0 - dasq) * (4.0 - dbsq) + G)
E = 2.0 * jnp.arctan(jnp.abs(tanEo2))
return E
_valid_projections = [None, "gnomonic", "stereographic", "lambert", "postel"]
[docs]
@implements(_galsim.celestial.CelestialCoord.project)
def project(self, coord2, projection=None):
if projection not in CelestialCoord._valid_projections:
raise ValueError("Unknown projection: %s" % projection)
# The core calculation is done in a helper function:
u, v = self._project(coord2._get_aux(), projection)
return u * radians, v * radians
[docs]
@implements(_galsim.celestial.CelestialCoord.project_rad)
def project_rad(self, ra, dec, projection=None):
if projection not in CelestialCoord._valid_projections:
raise ValueError("Unknown projection: %s" % projection)
cosra = jnp.cos(ra)
sinra = jnp.sin(ra)
cosdec = jnp.cos(dec)
sindec = jnp.sin(dec)
return self._project((cosra, sinra, cosdec, sindec), projection)
@partial(jax.jit, static_argnums=(2,))
def _project(self, auxc, projection):
cosra, sinra, cosdec, sindec = auxc[:4]
_cosra, _sinra, _cosdec, _sindec = self._get_aux()[:4]
# The equations are given at the above mathworld websites. They are the same except
# for the definition of k:
#
# x = k cos(dec) sin(ra-ra0)
# y = k ( cos(dec0) sin(dec) - sin(dec0) cos(dec) cos(ra-ra0) )
#
# Lambert:
# k = sqrt( 2 / ( 1 + cos(c) ) )
# Stereographic:
# k = 2 / ( 1 + cos(c) )
# Gnomonic:
# k = 1 / cos(c)
# Postel:
# k = c / sin(c)
# where cos(c) = sin(dec0) sin(dec) + cos(dec0) cos(dec) cos(ra-ra0)
# cos(dra) = cos(ra-ra0) = cos(ra0) cos(ra) + sin(ra0) sin(ra)
cosdra = _cosra * cosra
cosdra += _sinra * sinra
# sin(dra) = -sin(ra - ra0)
# Note: - sign here is to make +x correspond to -ra,
# so x increases for decreasing ra.
# East is to the left on the sky!
# sin(dra) = -cos(ra0) sin(ra) + sin(ra0) cos(ra)
sindra = _sinra * cosra
sindra -= _cosra * sinra
# Calculate k according to which projection we are using
cosc = cosdec * cosdra
cosc *= _cosdec
cosc += _sindec * sindec
if projection is None or projection[0] == "g":
k = 1.0 / cosc
elif projection[0] == "s":
k = 2.0 / (1.0 + cosc)
elif projection[0] == "l":
k = jnp.sqrt(2.0 / (1.0 + cosc))
else:
c = jnp.arccos(cosc)
# k = c / np.sin(c)
# np.sinc is defined as sin(pi x) / (pi x)
# So need to divide by pi first.
k = 1.0 / jnp.sinc(c / jnp.pi)
# u = k * cosdec * sindra
# v = k * ( self._cosdec * sindec - self._sindec * cosdec * cosdra )
u = cosdec * sindra
v = cosdec * cosdra
v *= -_sindec
v += _cosdec * sindec
u *= k
v *= k
return u, v
[docs]
@implements(_galsim.celestial.CelestialCoord.deproject)
def deproject(self, u, v, projection=None):
if projection not in CelestialCoord._valid_projections:
raise ValueError("Unknown projection: %s" % projection)
# Again, do the core calculations in a helper function
ra, dec = self._deproject(u / radians, v / radians, projection)
return CelestialCoord(_Angle(ra), _Angle(dec))
[docs]
@implements(_galsim.celestial.CelestialCoord.deproject_rad)
def deproject_rad(self, u, v, projection=None):
if projection not in CelestialCoord._valid_projections:
raise ValueError("Unknown projection: %s" % projection)
return self._deproject(u, v, projection)
@partial(jax.jit, static_argnums=(3,))
def _deproject(self, u, v, projection):
# The inverse equations are also given at the same web sites:
#
# sin(dec) = cos(c) sin(dec0) + v sin(c) cos(dec0) / r
# tan(ra-ra0) = u sin(c) / (r cos(dec0) cos(c) - v sin(dec0) sin(c))
#
# where
#
# r = sqrt(u^2+v^2)
# c = tan^(-1)(r) for gnomonic
# c = 2 tan^(-1)(r/2) for stereographic
# c = 2 sin^(-1)(r/2) for lambert
# c = r for postel
# Note that we can rewrite the formulae as:
#
# sin(dec) = cos(c) sin(dec0) + v (sin(c)/r) cos(dec0)
# tan(ra-ra0) = u (sin(c)/r) / (cos(dec0) cos(c) - v sin(dec0) (sin(c)/r))
#
# which means we only need cos(c) and sin(c)/r. For most of the projections,
# this saves us from having to take sqrt(rsq).
rsq = u * u
rsq += v * v
if projection is None or projection[0] == "g":
# c = arctan(r)
# cos(c) = 1 / sqrt(1+r^2)
# sin(c) = r / sqrt(1+r^2)
cosc = sinc_over_r = 1.0 / jnp.sqrt(1.0 + rsq)
elif projection[0] == "s":
# c = 2 * arctan(r/2)
# Some trig manipulations reveal:
# cos(c) = (4-r^2) / (4+r^2)
# sin(c) = 4r / (4+r^2)
cosc = (4.0 - rsq) / (4.0 + rsq)
sinc_over_r = 4.0 / (4.0 + rsq)
elif projection[0] == "l":
# c = 2 * arcsin(r/2)
# Some trig manipulations reveal:
# cos(c) = 1 - r^2/2
# sin(c) = r sqrt(4-r^2) / 2
cosc = 1.0 - rsq / 2.0
sinc_over_r = jnp.sqrt(4.0 - rsq) / 2.0
else:
r = jnp.sqrt(rsq)
cosc = jnp.cos(r)
sinc_over_r = jnp.sinc(r / jnp.pi)
# Compute sindec, tandra
# Note: more efficient to use numpy op= as much as possible to avoid temporary arrays.
_cosra, _sinra, _cosdec, _sindec = self._get_aux()[:4]
# sindec = cosc * self._sindec + v * sinc_over_r * self._cosdec
sindec = v * sinc_over_r
sindec *= _cosdec
sindec += cosc * _sindec
# Remember the - sign so +dra is -u. East is left.
tandra_num = u * sinc_over_r
tandra_num *= -1.0
# tandra_denom = cosc * self._cosdec - v * sinc_over_r * self._sindec
tandra_denom = v * sinc_over_r
tandra_denom *= -_sindec
tandra_denom += cosc * _cosdec
dec = jnp.arcsin(sindec)
ra = self.ra.rad + jnp.arctan2(tandra_num, tandra_denom)
return ra, dec
[docs]
@implements(_galsim.celestial.CelestialCoord.jac_deproject)
def jac_deproject(self, u, v, projection=None):
if projection not in CelestialCoord._valid_projections:
raise ValueError("Unknown projection: %s" % projection)
return self._jac_deproject(u.rad, v.rad, projection)
[docs]
@implements(_galsim.celestial.CelestialCoord.jac_deproject_rad)
def jac_deproject_rad(self, u, v, projection=None):
if projection not in CelestialCoord._valid_projections:
raise ValueError("Unknown projection: %s" % projection)
return self._jac_deproject(u, v, projection)
@partial(jax.jit, static_argnums=(3,))
def _jac_deproject(self, u, v, projection):
# sin(dec) = cos(c) sin(dec0) + v sin(c)/r cos(dec0)
# tan(ra-ra0) = u sin(c)/r / (cos(dec0) cos(c) - v sin(dec0) sin(c)/r)
#
# d(sin(dec)) = cos(dec) ddec = s0 dc + (v ds + s dv) c0
# dtan(ra-ra0) = sec^2(ra-ra0) dra
# = ( (u ds + s du) A - u s (dc c0 - (v ds + s dv) s0 ) )/A^2
# where s = sin(c) / r
# c = cos(c)
# s0 = sin(dec0)
# c0 = cos(dec0)
# A = c c0 - v s s0
rsq = u * u + v * v
# rsq1 = (u + 1.e-4)**2 + v**2
# rsq2 = u**2 + (v + 1.e-4)**2
if projection is None or projection[0] == "g":
c = s = 1.0 / jnp.sqrt(1.0 + rsq)
s3 = s * s * s
dcdu = dsdu = -u * s3
dcdv = dsdv = -v * s3
elif projection[0] == "s":
s = 4.0 / (4.0 + rsq)
c = 2.0 * s - 1.0
ssq = s * s
dcdu = -u * ssq
dcdv = -v * ssq
dsdu = 0.5 * dcdu
dsdv = 0.5 * dcdv
elif projection[0] == "l":
c = 1.0 - rsq / 2.0
s = jnp.sqrt(4.0 - rsq) / 2.0
dcdu = -u
dcdv = -v
dsdu = -u / (4.0 * s)
dsdv = -v / (4.0 * s)
else:
r = jnp.sqrt(rsq)
# original code for reference
# if r == 0.:
# c = s = 1
# dcdu = -u
# dcdv = -v
# dsdu = dsdv = 0
# else:
# c = np.cos(r)
# s = np.sin(r)/r
# dcdu = -s*u
# dcdv = -s*v
# dsdu = (c-s)*u/rsq
# dsdv = (c-s)*v/rsq
c = jnp.where(
r == 0.0,
1.0,
jnp.cos(r),
)
s = jnp.where(
r == 0.0,
1.0,
jnp.sin(r) / r,
)
dcdu = jnp.where(
r == 0.0,
-u,
-s * u,
)
dcdv = jnp.where(
r == 0.0,
-v,
-s * v,
)
dsdu = jnp.where(
r == 0.0,
0.0,
(c - s) * u / rsq,
)
dsdv = jnp.where(
r == 0.0,
0.0,
(c - s) * v / rsq,
)
_cosra, _sinra, _cosdec, _sindec = self._get_aux()[:4]
s0 = _sindec
c0 = _cosdec
sindec = c * s0 + v * s * c0
cosdec = jnp.sqrt(1.0 - sindec * sindec)
dddu = (s0 * dcdu + v * dsdu * c0) / cosdec
dddv = (s0 * dcdv + (v * dsdv + s) * c0) / cosdec
tandra_num = u * s
tandra_denom = c * c0 - v * s * s0
# Note: A^2 sec^2(dra) = denom^2 (1 + tan^2(dra) = denom^2 + num^2
A2sec2dra = tandra_denom**2 + tandra_num**2
drdu = (
(u * dsdu + s) * tandra_denom - u * s * (dcdu * c0 - v * dsdu * s0)
) / A2sec2dra
drdv = (
u * dsdv * tandra_denom - u * s * (dcdv * c0 - (v * dsdv + s) * s0)
) / A2sec2dra
drdu *= cosdec
drdv *= cosdec
return jnp.array([[drdu, drdv], [dddu, dddv]])
[docs]
@implements(_galsim.celestial.CelestialCoord.precess)
def precess(self, from_epoch, to_epoch):
return CelestialCoord._precess(
from_epoch, to_epoch, self._ra.rad, self._dec.rad
)
[docs]
@implements(_galsim.celestial.CelestialCoord.galactic)
def galactic(self, epoch=2000.0):
# cf. Lang, Astrophysical Formulae, page 13
# cos(b) cos(el-33) = cos(dec) cos(ra-282.25)
# cos(b) sin(el-33) = sin(dec) sin(62.6) + cos(dec) sin(ra-282.25) cos(62.6)
# sin(b) = sin(dec) cos(62.6) - cos(dec) sin(ra-282.25) sin(62.6)
#
# Those formulae were for the 1950 epoch. The corresponding numbers for J2000 are:
# (cf. https://arxiv.org/pdf/1010.3773.pdf)
el0 = 32.93191857 * degrees
r0 = 282.859481208 * degrees
d0 = 62.8717488056 * degrees
sind0, cosd0 = d0.sincos()
sind, cosd = self.dec.sincos()
sinr, cosr = (self.ra - r0).sincos()
cbcl = cosd * cosr
cbsl = sind * sind0 + cosd * sinr * cosd0
sb = sind * cosd0 - cosd * sinr * sind0
b = _Angle(jnp.arcsin(sb))
el = (_Angle(jnp.arctan2(cbsl, cbcl)) + el0).wrap(_Angle(jnp.pi))
return (el, b)
[docs]
@staticmethod
@implements(_galsim.celestial.CelestialCoord.from_galactic)
def from_galactic(el, b, epoch=2000.0):
el0 = 32.93191857 * degrees
r0 = 282.859481208 * degrees
d0 = 62.8717488056 * degrees
sind0, cosd0 = d0.sincos()
sinb, cosb = b.sincos()
sinl, cosl = (el - el0).sincos()
x1 = cosb * cosl
y1 = cosb * sinl
z1 = sinb
x2 = x1
y2 = y1 * cosd0 - z1 * sind0
z2 = y1 * sind0 + z1 * cosd0
temp = CelestialCoord.from_xyz(x2, y2, z2)
return CelestialCoord(temp.ra + r0, temp.dec).normal()
[docs]
@partial(jax.jit, static_argnames=("date",))
@implements(_galsim.celestial.CelestialCoord.ecliptic)
def ecliptic(self, epoch=2000.0, date=None):
# We are going to work in terms of the (x, y, z) projections.
_x, _y, _z = self._get_aux()[4:]
# Get the obliquity of the ecliptic.
if date is not None:
epoch = date.year
ep = _ecliptic_obliquity(epoch)
sin_ep, cos_ep = ep.sincos()
# Coordinate transformation here, from celestial to ecliptic:
x_ecl = _x
y_ecl = cos_ep * _y + sin_ep * _z
z_ecl = -sin_ep * _y + cos_ep * _z
beta = _Angle(jnp.arcsin(z_ecl))
lam = _Angle(jnp.arctan2(y_ecl, x_ecl))
if date is not None:
# Find the sun position in ecliptic coordinates on this date. We have to convert to
# Julian day in order to use our helper routine to find the Sun position in ecliptic
# coordinates.
lam_sun = _sun_position_ecliptic(date)
# Subtract it off, to get ecliptic coordinates relative to the sun.
lam -= lam_sun
return (lam.wrap(), beta)
[docs]
@staticmethod
@partial(jax.jit, static_argnames=("date",))
@implements(_galsim.celestial.CelestialCoord.from_ecliptic)
def from_ecliptic(lam, beta, epoch=2000.0, date=None):
if date is not None:
lam += _sun_position_ecliptic(date)
# Get the (x, y, z)_ecliptic from (lam, beta).
sinbeta, cosbeta = beta.sincos()
sinlam, coslam = lam.sincos()
x_ecl = cosbeta * coslam
y_ecl = cosbeta * sinlam
z_ecl = sinbeta
# Get the obliquity of the ecliptic.
if date is not None:
epoch = date.year
ep = _ecliptic_obliquity(epoch)
# Transform to (x, y, z)_equatorial.
sin_ep, cos_ep = ep.sincos()
x_eq = x_ecl
y_eq = cos_ep * y_ecl - sin_ep * z_ecl
z_eq = sin_ep * y_ecl + cos_ep * z_ecl
return CelestialCoord.from_xyz(x_eq, y_eq, z_eq)
def __repr__(self):
return "galsim.CelestialCoord(%r, %r)" % (
ensure_hashable(self._ra),
ensure_hashable(self._dec),
)
def __str__(self):
return "galsim.CelestialCoord(%s, %s)" % (
ensure_hashable(self._ra),
ensure_hashable(self._dec),
)
def __hash__(self):
return hash(repr(self))
def __eq__(self, other):
return (
isinstance(other, CelestialCoord)
and jnp.array_equal(self._ra.rad, other._ra.rad)
and jnp.array_equal(self._dec.rad, other._dec.rad)
)
def __ne__(self, other):
return not self.__eq__(other)
[docs]
def tree_flatten(self):
"""This function flattens the CelestialCoord into a list of children
nodes that will be traced by JAX and auxiliary static data."""
# Define the children nodes of the PyTree that need tracing
children = (self._ra, self._dec)
# Define auxiliary static data that doesn’t need to be traced
aux_data = None
return (children, aux_data)
[docs]
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Recreates an instance of the class from flatten representation"""
return _CelestialCoord(*children)
@jax.jit
def _precess_kern(from_epoch, to_epoch, _x, _y, _z, _ra, _dec):
# t0, t below correspond to Lieske's big T and little T
t0 = (from_epoch - 2000.0) / 100.0
t = (to_epoch - from_epoch) / 100.0
t02 = t0 * t0
t2 = t * t
t3 = t2 * t
# a,b,c below correspond to Lieske's zeta_A, z_A and theta_A
a = (
(2306.2181 + 1.39656 * t0 - 0.000139 * t02) * t
+ (0.30188 - 0.000344 * t0) * t2
+ 0.017998 * t3
) * arcsec
b = (
(2306.2181 + 1.39656 * t0 - 0.000139 * t02) * t
+ (1.09468 + 0.000066 * t0) * t2
+ 0.018203 * t3
) * arcsec
c = (
(2004.3109 - 0.85330 * t0 - 0.000217 * t02) * t
+ (-0.42665 - 0.000217 * t0) * t2
- 0.041833 * t3
) * arcsec
sina, cosa = a.sincos()
sinb, cosb = b.sincos()
sinc, cosc = c.sincos()
# This is the precession rotation matrix:
xx = cosa * cosc * cosb - sina * sinb
yx = -sina * cosc * cosb - cosa * sinb
zx = -sinc * cosb
xy = cosa * cosc * sinb + sina * cosb
yy = -sina * cosc * sinb + cosa * cosb
zy = -sinc * sinb
xz = cosa * sinc
yz = -sina * sinc
zz = cosc
# Perform the rotation:
x2 = xx * _x + yx * _y + zx * _z
y2 = xy * _x + yy * _y + zy * _z
z2 = xz * _x + yz * _y + zz * _z
return CelestialCoord.from_xyz(x2, y2, z2).normal()
@jax.jit
def _precess(from_epoch, to_epoch, _ra, _dec):
_sindec, _cosdec = jnp.sin(_dec), jnp.cos(_dec)
_sinra, _cosra = jnp.sin(_ra), jnp.cos(_ra)
_x = _cosdec * _cosra
_y = _cosdec * _sinra
_z = _sindec
return jax.lax.cond(
jnp.array_equal(from_epoch, to_epoch),
lambda *args: _CelestialCoord(_Angle(args[-2]), _Angle(args[-1])),
CelestialCoord._precess_kern,
from_epoch,
to_epoch,
_x,
_y,
_z,
_ra,
_dec,
)
[docs]
@staticmethod
def from_galsim(gcoord):
"""Create a jax_galsim `CelestialCoord` from a `galsim.CelestialCoord` object."""
return _CelestialCoord(_Angle(gcoord.ra.rad), _Angle(gcoord.dec.rad))
[docs]
def to_galsim(self):
"""Create a galsim `CelestialCoord` from a `jax_galsim.CelestialCoord` object."""
return _galsim.celestial.CelestialCoord(
self.ra.to_galsim(), self.dec.to_galsim()
)
@implements(_coord._CelestialCoord)
def _CelestialCoord(ra, dec):
ret = CelestialCoord.__new__(CelestialCoord)
ret._ra = ra
ret._dec = dec
return ret