import copy
import os
import galsim as _galsim
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class
from jax_galsim import fits
from jax_galsim.angle import AngleUnit, arcsec, degrees, radians
from jax_galsim.celestial import CelestialCoord
from jax_galsim.core.utils import (
cast_to_float,
cast_to_python_float,
ensure_hashable,
implements,
)
from jax_galsim.errors import (
GalSimError,
GalSimIncompatibleValuesError,
GalSimNotImplementedError,
GalSimValueError,
galsim_warn,
)
from jax_galsim.position import PositionD
from jax_galsim.utilities import horner2d
from jax_galsim.wcs import (
AffineTransform,
CelestialWCS,
JacobianWCS,
OffsetWCS,
PixelScale,
)
#########################################################################################
#
# We have the following WCS classes that know how to read the WCS from a FITS file:
#
# GSFitsWCS
#
# As for all CelestialWCS classes, they must define the following:
#
# _radec function returning (ra, dec) in _radians_ at position (x,y)
# _xy function returning (x, y) given (ra, dec) in _radians_.
# _writeHeader function that writes the WCS to a fits header.
# _readHeader static function that reads the WCS from a fits header.
# copy return a copy
# __eq__ check if this equals another WCS
#
#########################################################################################
[docs]
@implements(
_galsim.fitswcs.GSFitsWCS,
lax_description=(
"The JAX-GalSim version of this class does not raise errors if inverting the WCS to "
"map ra,dec to (x,y) fails. Instead it returns NaNs."
),
)
@register_pytree_node_class
class GSFitsWCS(CelestialWCS):
# Parameter specifications, keyed by constructor-argument name and mapping to
# the expected type. ``file_name`` is listed as required; the rest are optional.
# NOTE(review): presumably consumed by the config layer to validate and convert
# user-supplied values -- confirm against the config machinery.
_req_params = {"file_name": str}
_opt_params = {"dir": str, "hdu": int, "origin": PositionD, "compression": str}
def __init__(
self,
file_name=None,
dir=None,
hdu=None,
header=None,
compression="auto",
origin=None,
_data=None,
):
# Note: _data is not intended for end-user use. It enables the equivalent of a
# private constructor of GSFitsWCS by the function TanWCS. The details of its
# use are intentionally not documented above.
self._color = None
self._tag = None # Write something useful here (see below). This is just used for the str.
# If _data is given, copy the data and we're done.
if _data is not None:
self.wcs_type = _data[0]
self.crpix = _data[1]
self.cd = _data[2]
self.center = _data[3]
self.pv = _data[4]
self.ab = _data[5]
self.abp = _data[6]
if self.wcs_type in ("TAN", "TPV", "TNX", "TAN-SIP"):
self.projection = "gnomonic"
elif self.wcs_type in ("STG", "STG-SIP"):
self.projection = "stereographic"
elif self.wcs_type in ("ZEA", "ZEA-SIP"):
self.projection = "lambert"
elif self.wcs_type in ("ARC", "ARC-SIP"):
self.projection = "postel"
else:
raise ValueError("Invalid wcs_type in _data")
# set cdinv and convert to jax
self.cd = jnp.array(self.cd)
self.crpix = jnp.array(self.crpix)
if self.pv is not None:
self.pv = jnp.array(self.pv)
if self.ab is not None:
self.ab = jnp.array(self.ab)
if self.abp is not None:
self.abp = jnp.array(self.abp)
self.cdinv = jnp.linalg.inv(self.cd)
return
# Read the file if given.
if file_name is not None:
if dir is not None:
self._tag = repr(os.path.join(dir, file_name))
else:
self._tag = repr(file_name)
if hdu is not None:
self._tag += ", hdu=%r" % hdu
if compression != "auto":
self._tag += ", compression=%r" % compression
if header is not None:
raise GalSimIncompatibleValuesError(
"Cannot provide both file_name and pyfits header",
file_name=file_name,
header=header,
)
hdu, hdu_list, fin = fits.readFile(file_name, dir, hdu, compression)
try:
if file_name is not None:
header = hdu.header
if header is None:
raise GalSimIncompatibleValuesError(
"Must provide either file_name or header",
file_name=file_name,
header=header,
)
# Read the wcs information from the header.
self._read_header(header)
finally:
if file_name is not None:
fits.closeHDUList(hdu_list, fin)
if origin is not None:
self.crpix += jnp.array([origin.x, origin.y])
[docs]
def tree_flatten(self):
"""This function flattens the WCS into a list of children
nodes that will be traced by JAX and auxiliary static data."""
# Define the children nodes of the PyTree that need tracing
children = (
self.crpix,
self.cd,
self.center,
self.pv,
self.ab,
self.abp,
)
# Define auxiliary static data that doesn’t need to be traced
aux_data = (self.wcs_type,)
return (children, aux_data)
[docs]
@classmethod
def tree_unflatten(cls, aux_data, children):
"""Recreates an instance of the class from flatten representation"""
return cls(_data=aux_data + children)
# ``origin`` has to exist as an attribute/property because functions such as
# shiftOrigin read the current origin value. This class never uses it itself,
# so it is a dummy property that always reports (0, 0).
@property
@implements(_galsim.GSFitsWCS.origin)
def origin(self):
    zero = 0.0
    return PositionD(zero, zero)
def _read_header(self, header):
    """Populate this WCS from the keywords of a FITS header.

    Sets ``wcs_type``, ``projection``, ``crpix``, ``cd``, ``cdinv``,
    ``center``, ``pv``, ``ab`` and ``abp`` on ``self``.

    Parameters:
        header: mapping-like FITS header supporting ``get``, ``in`` and item
            access.

    Raises:
        GalSimError: if CTYPE1/CTYPE2 are not an RA/DEC pair.
        OSError: if CTYPE1 and CTYPE2 name different WCS types.
        GalSimValueError: if the WCS type is not one this class can read.
    """
    # Start by reading the basic WCS stuff that most types have.
    ctype1 = header.get("CTYPE1", "")
    ctype2 = header.get("CTYPE2", "")
    if ctype1.startswith("DEC--") and ctype2.startswith("RA---"):
        flip = True
    elif ctype1.startswith("RA---") and ctype2.startswith("DEC--"):
        flip = False
    else:
        # Fix: the two sentences used to run together ("systems.Expecting").
        raise GalSimError(
            "GSFitsWCS only supports celestial coordinate systems. "
            "Expecting CTYPE1,2 to start with RA--- and DEC--. Got %s, %s"
            % (ctype1, ctype2)
        )
    if ctype1[5:] != ctype2[5:]:  # pragma: no cover
        raise OSError("ctype1, ctype2 do not seem to agree on the WCS type")
    self.wcs_type = ctype1[5:]
    if self.wcs_type in ("TAN", "TPV", "TNX", "TAN-SIP"):
        self.projection = "gnomonic"
    elif self.wcs_type in ("STG", "STG-SIP"):
        self.projection = "stereographic"
    elif self.wcs_type in ("ZEA", "ZEA-SIP"):
        self.projection = "lambert"
    elif self.wcs_type in ("ARC", "ARC-SIP"):
        self.projection = "postel"
    else:
        raise GalSimValueError(
            "GSFitsWCS cannot read files using given wcs_type.",
            self.wcs_type,
            (
                "TAN",
                "TPV",
                "TNX",
                "TAN-SIP",
                "STG",
                "STG-SIP",
                "ZEA",
                "ZEA-SIP",
                "ARC",
                "ARC-SIP",
            ),
        )

    crval1 = float(header["CRVAL1"])
    crval2 = float(header["CRVAL2"])
    crpix1 = float(header["CRPIX1"])
    crpix2 = float(header["CRPIX2"])

    # The linear part may be given as a CD matrix, as PC * CDELT, or as CDELT
    # alone; fall back to the identity when none of them is present.
    if "CD1_1" in header:
        cd11 = float(header["CD1_1"])
        cd12 = float(header["CD1_2"])
        cd21 = float(header["CD2_1"])
        cd22 = float(header["CD2_2"])
    elif "CDELT1" in header:
        if "PC1_1" in header:
            cd11 = float(header["PC1_1"]) * float(header["CDELT1"])
            cd12 = float(header["PC1_2"]) * float(header["CDELT1"])
            cd21 = float(header["PC2_1"]) * float(header["CDELT2"])
            cd22 = float(header["PC2_2"]) * float(header["CDELT2"])
        else:
            cd11 = float(header["CDELT1"])
            cd12 = 0.0
            cd21 = 0.0
            cd22 = float(header["CDELT2"])
    else:  # pragma: no cover (all our test files have either CD or CDELT)
        cd11 = 1.0
        cd12 = 0.0
        cd21 = 0.0
        cd22 = 1.0

    # Usually the units are degrees, but make sure
    if "CUNIT1" in header:
        cunit1 = header["CUNIT1"]
        cunit2 = header["CUNIT2"]
        ra_units = AngleUnit.from_name(cunit1)
        dec_units = AngleUnit.from_name(cunit2)
    else:
        ra_units = degrees
        dec_units = degrees

    if flip:
        # Axis 1 is DEC and axis 2 is RA: swap so that index 0 is always RA.
        crval1, crval2 = crval2, crval1
        ra_units, dec_units = dec_units, ra_units
        cd11, cd21 = cd21, cd11
        cd12, cd22 = cd22, cd12

    # Keep these as numpy arrays for now; they may still be edited in place.
    self.crpix = np.array([crpix1, crpix2])
    self.cd = np.array([[cd11, cd12], [cd21, cd22]])
    self.center = CelestialCoord(crval1 * ra_units, crval2 * dec_units)

    # There was an older proposed standard that used TAN with PV values, which is used by
    # SCamp, so we want to support it if possible. The standard is now called TPV, so
    # use that for our wcs_type if we see the PV values with TAN.
    if self.wcs_type == "TAN" and "PV1_1" in header:
        self.wcs_type = "TPV"

    self.pv = None
    self.ab = None
    self.abp = None

    if self.wcs_type == "TPV":
        self._read_tpv(header)
    elif self.wcs_type == "TNX":
        self._read_tnx(header)
    elif self.wcs_type in ("TAN-SIP", "STG-SIP", "ZEA-SIP", "ARC-SIP"):
        self._read_sip(header)

    # I think the CUNIT specification applies to the CD matrix as well, but I couldn't actually
    # find good documentation for this. Plus all the examples I saw used degrees anyway, so
    # it's hard to tell. Hopefully this will never matter, but if CUNIT is not deg, this
    # next bit might be wrong.
    # I did see documentation that the PV matrices always use degrees, so at least we shouldn't
    # have to worry about that.
    if ra_units != degrees:  # pragma: no cover
        self.cd[0, :] *= 1.0 * ra_units / degrees
    if dec_units != degrees:  # pragma: no cover
        self.cd[1, :] *= 1.0 * dec_units / degrees

    # convert to JAX after reading
    self.cd = jnp.array(self.cd)
    self.crpix = jnp.array(self.crpix)
    self.cdinv = jnp.linalg.inv(self.cd)
def _read_tpv(self, header):
# See http://fits.gsfc.nasa.gov/registry/tpvwcs/tpv.html for details about how
# the TPV standard is defined.
# The standard includes an option to have odd powers of r, which kind of screws
# up the numbering of these coefficients. We don't implement these terms, so
# before going further, check to make sure none are present.
odd_indices = [3, 11, 23, 39]
if any(
(
header.get("PV%s_%s" % (i, j), 0.0) != 0.0
for i in [1, 2]
for j in odd_indices
)
):
raise GalSimNotImplementedError("TPV not implemented for odd powers of r")
pv1 = [
float(header.get("PV1_%s" % k, 0.0))
for k in range(40)
if k not in odd_indices
]
pv2 = [
float(header.get("PV2_%s" % k, 0.0))
for k in range(40)
if k not in odd_indices
]
maxk = max(np.nonzero(pv1)[0][-1], np.nonzero(pv2)[0][-1])
# maxk = (order+1) * (order+2) / 2 - 1
order = int(np.floor(np.sqrt(2 * (maxk + 1)))) - 1
self.pv = np.zeros((2, order + 1, order + 1))
# Another strange thing is that the two matrices are defined in the opposite order
# with respect to their element ordering. But at least now, without the odd terms,
# we can just proceed in order in the k indices. So what we call k=3..9 here were
# originally PVi_4..10.
# For reference, here is what it would look like for order = 3:
# self.pv = np.array( [ [ [ pv1[0], pv1[2], pv1[5], pv1[9] ],
# [ pv1[1], pv1[4], pv1[8], 0. ],
# [ pv1[3], pv1[7], 0. , 0. ],
# [ pv1[6], 0. , 0. , 0. ] ],
# [ [ pv2[0], pv2[1], pv2[3], pv2[6] ],
# [ pv2[2], pv2[4], pv2[7], 0. ],
# [ pv2[5], pv2[8], 0. , 0. ],
# [ pv2[9], 0. , 0. , 0. ] ] ] )
k = 0
for N in range(order + 1):
for j in range(N + 1):
i = N - j
self.pv[0, i, j] = pv1[k]
self.pv[1, j, i] = pv2[k]
k = k + 1
# convert to JAX after reading
self.pv = jnp.array(self.pv)
def _read_sip(self, header):
a_order = int(header["A_ORDER"])
b_order = int(header["B_ORDER"])
order = max(a_order, b_order) # Use the same order for both
a = [
float(header.get("A_" + str(i) + "_" + str(j), 0.0))
for i in range(order + 1)
for j in range(order + 1)
]
a = np.array(a).reshape((order + 1, order + 1))
b = [
float(header.get("B_" + str(i) + "_" + str(j), 0.0))
for i in range(order + 1)
for j in range(order + 1)
]
b = np.array(b).reshape((order + 1, order + 1))
a[1, 0] += (
1 # Standard A,B are a differential calculation. It's more convenient to
)
b[0, 1] += 1 # keep this as an absolute calculation like PV does.
self.ab = np.array([a, b])
# The reverse transformation is not required to be there.
if "AP_ORDER" in header:
ap_order = int(header["AP_ORDER"])
bp_order = int(header["BP_ORDER"])
order = max(ap_order, bp_order) # Use the same order for both
ap = [
float(header.get("AP_" + str(i) + "_" + str(j), 0.0))
for i in range(order + 1)
for j in range(order + 1)
]
ap = np.array(ap).reshape((order + 1, order + 1))
bp = [
float(header.get("BP_" + str(i) + "_" + str(j), 0.0))
for i in range(order + 1)
for j in range(order + 1)
]
bp = np.array(bp).reshape((order + 1, order + 1))
ap[1, 0] += 1
bp[0, 1] += 1
self.abp = np.array([ap, bp])
# convert to JAX after reading
self.abp = jnp.array(self.abp)
# convert to JAX after reading
self.ab = jnp.array(self.ab)
def _read_tnx(self, header):
# TNX has a few different options. Rather than keep things in the native format,
# we actually convert to the equivalent of TPV to make the actual operations faster.
# See http://iraf.noao.edu/projects/ccdmosaic/tnx.html for details.
# First, parse the input values, which are stored in WAT keywords:
k = 1
wat1 = ""
key = "WAT1_%03d" % k
while key in header:
wat1 += header[key]
k = k + 1
key = "WAT1_%03d" % k
wat1 = wat1.split()
k = 1
wat2 = ""
key = "WAT2_%03d" % k
while key in header:
wat2 += header[key]
k = k + 1
key = "WAT2_%03d" % k
wat2 = wat2.split()
if (
len(wat1) < 12
or wat1[0] != "wtype=tnx"
or wat1[1] != "axtype=ra"
or wat1[2] != "lngcor"
or wat1[3] != "="
or not wat1[4].startswith('"')
or not wat1[-1].endswith('"')
): # pragma: no cover
raise GalSimError("TNX WAT1 was not as expected")
if (
len(wat2) < 12
or wat2[0] != "wtype=tnx"
or wat2[1] != "axtype=dec"
or wat2[2] != "latcor"
or wat2[3] != "="
or not wat2[4].startswith('"')
or not wat2[-1].endswith('"')
): # pragma: no cover
raise GalSimError("TNX WAT2 was not as expected")
# Break the next bit out into another function, since it is the same for x and y.
pv1 = self._parse_tnx_data(wat1[4:])
pv2 = self._parse_tnx_data(wat2[4:])
# Those just give the adjustments to the position, not the matrix that gives the final
# position. i.e. the TNX standard uses u = u + [1 u u^2 u^3] PV [1 v v^2 v^3]T.
# So we need to add 1 to the correct term in each matrix to get what we really want.
pv1[1, 0] += 1.0
pv2[0, 1] += 1.0
# Finally, store these as our pv 3-d array.
self.pv = np.array([pv1, pv2])
# We've now converted this to TPV, so call it that when we output to a fits header.
self.wcs_type = "TPV"
# convert to JAX after reading
self.pv = jnp.array(self.pv)
def _parse_tnx_data(self, data):
# I'm not sure if there is any requirement on there being a space before the final " and
# not before the initial ". But both the example in the description of the standard and
# the one we have in our test directory are this way. Here, if the " is by itself, I
# remove the item, and if it is part of a longer string, I just strip it off. Seems the
# most sensible thing to do.
if data[0] == '"': # pragma: no cover
data = data[1:]
else:
data[0] = data[0][1:]
if data[-1] == '"':
data = data[:-1]
else: # pragma: no cover
data[-1] = data[-1][:-1]
code = int(
data[0].strip(".")
) # Weirdly, these integers are given with decimal points.
xorder = int(data[1].strip("."))
yorder = int(data[2].strip("."))
cross = int(data[3].strip("."))
if cross != 2: # pragma: no cover
raise GalSimNotImplementedError(
"TNX only implemented for half-cross option."
)
if xorder != 4 or yorder != 4: # pragma: no cover
raise GalSimNotImplementedError("TNX only implemented for order = 4")
# Note: order = 4 really means cubic. order is how large the pv matrix is, i.e. 4x4.
xmin = float(data[4])
xmax = float(data[5])
ymin = float(data[6])
ymax = float(data[7])
pv1 = [float(x) for x in data[8:]]
if len(pv1) != 10: # pragma: no cover
raise GalSimError("Wrong number of items found in WAT data")
# Put these into our matrix formulation.
pv = np.array(
[
[pv1[0], pv1[4], pv1[7], pv1[9]],
[pv1[1], pv1[5], pv1[8], 0.0],
[pv1[2], pv1[6], 0.0, 0.0],
[pv1[3], 0.0, 0.0, 0.0],
]
)
# Convert from Legendre or Chebyshev polynomials into regular polynomials.
if code < 3: # pragma: no branch (The only test file I can find has code = 1)
# Instead of 1, x, x^2, x^3, Chebyshev uses: 1, x', 2x'^2 - 1, 4x'^3 - 3x
# where x' = (2x - xmin - xmax) / (xmax-xmin).
# Similarly, with y' = (2y - ymin - ymin) / (ymax-ymin)
# We'd like to convert the pv matrix from being in terms of x' and y' to being
# in terms of just x, y. To see how this works, look at what pv[1,1] means:
#
# First, let's say we can write x as (a + bx), and we can write y' as (c + dy).
# Then the term for pv[1,1] is:
#
# term = x' * pv[1,1] * y'
# = (a + bx) * pv[1,1] * (d + ey)
# = a * pv[1,1] * c + a * pv[1,1] * d * y
# + x * b * pv[1,1] * c + x * b * pv[1,1] * d * y
#
# So the single term initially will contribute to 4 different terms in the final
# matrix. And the contributions will just be pv[1,1] times the outer product
# [a b]T [d e]. So if we can determine the matrix that converts from
# [1, x, x^2, x^3] to the Chebyshev vector, the the matrix we want is simply
# xmT pv ym.
a = -(xmax + xmin) / (xmax - xmin)
b = 2.0 / (xmax - xmin)
c = -(ymax + ymin) / (ymax - ymin)
d = 2.0 / (ymax - ymin)
xm = np.zeros((4, 4))
ym = np.zeros((4, 4))
xm[0, 0] = 1.0
xm[1, 0] = a
xm[1, 1] = b
ym[0, 0] = 1.0
ym[1, 0] = c
ym[1, 1] = d
if code == 1:
for m in range(2, 4):
# The recursion rule is Pm = 2 x' Pm-1 - Pm-2
# Pm = 2 a Pm-1 - Pm-2 + x * 2 b Pm-1
xm[m] = 2.0 * a * xm[m - 1] - xm[m - 2]
xm[m, 1:] += 2.0 * b * xm[m - 1, :-1]
ym[m] = 2.0 * c * ym[m - 1] - ym[m - 2]
ym[m, 1:] += 2.0 * d * ym[m - 1, :-1]
else: # pragma: no cover
# code == 2 means Legendre. The same argument applies, but we have a
# different recursion rule.
# WARNING: This branch has not been tested! I don't have any TNX files
# with Legendre functions to test it on. I think it's right, but beware!
for m in range(2, 4):
# The recursion rule is Pm = ((2m-1) x' Pm-1 - (m-1) Pm-2) / m
# Pm = ((2m-1) a Pm-1 - (m-1) Pm-2) / m
# + x * ((2m-1) b Pm-1) / m
xm[m] = (
(2.0 * m - 1.0) * a * xm[m - 1] - (m - 1.0) * xm[m - 2]
) / m
xm[m, 1:] += ((2.0 * m - 1.0) * b * xm[m - 1, :-1]) / m
ym[m] = (
(2.0 * m - 1.0) * c * ym[m - 1] - (m - 1.0) * ym[m - 2]
) / m
ym[m, 1:] += ((2.0 * m - 1.0) * d * ym[m - 1, :-1]) / m
pv2 = np.dot(xm.T, np.dot(pv, ym))
return pv2
def _apply_ab(self, x, y, ab):
    """Evaluate the pair of 2-D polynomials ``ab[0]`` and ``ab[1]`` at ``(x, y)``.

    This is used for both pv and ab, since the action is the same; they just
    occur at two different places in the calculation.

    Returns:
        The tuple ``(x1, y1)`` of the two polynomial values.
    """
    return tuple(horner2d(x, y, coef, triangle=True) for coef in (ab[0], ab[1]))
def _uv(self, x, y):
    """Map an image position to tangent-plane ``(u, v)`` in radians.

    This is most of the work for ``_radec``, stopping at ``(u, v)``.
    """
    # Start from the image position measured relative to crpix.
    px = cast_to_float(x)
    py = cast_to_float(y)
    px -= self.crpix[0]
    py -= self.crpix[1]

    if self.ab is not None:
        px, py = self._apply_ab(px, py, self.ab)

    # This converts to (u,v) in the tangent plane.
    # Expanding this out is a bit faster than using np.dot for 2x2 matrix.
    cd = self.cd
    u = cd[0, 0] * px + cd[0, 1] * py
    v = cd[1, 0] * px + cd[1, 1] * py

    if self.pv is not None:
        u, v = self._apply_ab(u, v, self.pv)

    # Convert (u,v) from degrees to radians.
    # Also, the FITS standard defines u,v backwards relative to our standard.
    # They have +u increasing to the east, not west. Hence the - for u.
    deg_to_rad = 1.0 * degrees / radians
    u *= -deg_to_rad
    v *= deg_to_rad
    return u, v
def _radec(self, x, y, color=None):
    """Return ``(ra, dec)`` for the image position ``(x, y)``.

    ``color`` is accepted but not used by this method.
    """
    # Position in the tangent plane ...
    tangent = self._uv(x, y)
    # ... then deproject from (u,v) to (ra, dec) with the appropriate projection.
    ra, dec = self.center.deproject_rad(
        tangent[0], tangent[1], projection=self.projection
    )
    return ra, dec
def _xy(self, ra, dec, color=None):
    """Return the image position ``(x, y)`` for the sky position ``(ra, dec)``.

    ``color`` is accepted but not used by this method.
    """
    u, v = self.center.project_rad(ra, dec, projection=self.projection)

    # Again, FITS has +u increasing to the east, not west. Hence the - for u.
    rad_to_deg = radians / degrees
    u *= -rad_to_deg
    v *= rad_to_deg

    # Undo the pv distortion, if there is one.
    if self.pv is not None:
        u, v = _invert_ab_noraise(u, v, self.pv)

    # This is a bit faster than using np.dot for 2x2 matrix.
    inv = self.cdinv
    x = inv[0, 0] * u + inv[0, 1] * v
    y = inv[1, 0] * u + inv[1, 1] * v

    # Undo the ab distortion, seeding the solver with abp when we have it.
    if self.ab is not None:
        x, y = _invert_ab_noraise(x, y, self.ab, abp=self.abp)

    x += self.crpix[0]
    y += self.crpix[1]
    return x, y
# Override the version in CelestialWCS, since we can do this more efficiently.
def _local(self, image_pos, color=None):
    """Return the local ``JacobianWCS`` (arcsec/pixel) at ``image_pos``.

    Raises:
        TypeError: if ``image_pos`` is None.
    """
    if image_pos is None:
        raise TypeError("origin must be a PositionD or PositionI argument")

    # The key lemma here is that chain rule for jacobians is just matrix multiplication.
    # i.e. if s = s(u,v), t = t(u,v) and u = u(x,y), v = v(x,y), then
    # ( dsdx  dsdy ) = ( dsdu dudx + dsdv dvdx   dsdu dudy + dsdv dvdy )
    # ( dtdx  dtdy ) = ( dtdu dudx + dtdv dvdx   dtdu dudy + dtdv dvdy )
    #                = ( dsdu  dsdv )  ( dudx  dudy )
    #                  ( dtdu  dtdv )  ( dvdx  dvdy )
    #
    # So if we can find the jacobian for each step of the process, we just multiply the
    # jacobians.
    #
    # We also need to keep track of the position along the way, so we have to repeat many
    # of the steps in _radec.

    def _polynomial_step(coef, point):
        # Apply the two 2-D polynomials in ``coef`` to ``point`` and return
        # the new point together with the 2x2 jacobian of that step.
        size = len(coef[0])
        first_pow = point[0] ** jnp.arange(size)
        second_pow = point[1] ** jnp.arange(size)
        contracted = jnp.dot(coef, second_pow)
        moved = jnp.dot(contracted, first_pow)
        # Derivatives of the power vectors: d(t^k)/dt = k t^(k-1).
        d_first = jnp.zeros(size)
        d_second = jnp.zeros(size)
        d_first = d_first.at[1:].set((jnp.arange(size - 1) + 1.0) * first_pow[:-1])
        d_second = d_second.at[1:].set(
            (jnp.arange(size - 1) + 1.0) * second_pow[:-1]
        )
        # The columns of the jacobian are the same function evaluated with
        # the derivative power vector substituted for one of the inputs.
        step_jac = jnp.transpose(
            jnp.array(
                [
                    jnp.dot(contracted, d_first),
                    jnp.dot(jnp.dot(coef, d_second), first_pow),
                ]
            )
        )
        return moved, step_jac

    pos = jnp.array([image_pos.x, image_pos.y], dtype=float)

    # Start with unit jacobian
    jac = jnp.eye(2)

    # No effect on the jacobian from this step.
    pos -= self.crpix

    if self.ab is not None:
        pos, step_jac = _polynomial_step(self.ab, pos)
        jac = jnp.dot(step_jac, jac)

    # The jacobian here is just the cd matrix.
    pos = jnp.dot(self.cd, pos)
    jac = jnp.dot(self.cd, jac)

    if self.pv is not None:
        # Now we apply the distortion terms
        pos, step_jac = _polynomial_step(self.pv, pos)
        jac = jnp.dot(step_jac, jac)

    unit_convert = jnp.array([-1 * degrees / radians, 1 * degrees / radians])
    pos *= unit_convert
    # Subtle point: Don't use jac *= ..., because jac might currently be self.cd, and
    # that would change self.cd!
    jac = jac * jnp.transpose(jnp.array([unit_convert]))

    # Finally convert from (u,v) to (ra, dec). We have a special function that computes
    # the jacobian of this step in the CelestialCoord class.
    sky_jac = self.center.jac_deproject_rad(
        pos[0], pos[1], projection=self.projection
    )
    jac = jnp.dot(sky_jac, jac)

    # This now has units of radians/pixel. We want instead arcsec/pixel.
    jac *= radians / arcsec

    return JacobianWCS(jac[0, 0], jac[0, 1], jac[1, 0], jac[1, 1])
def _newOrigin(self, origin):
    """Return a copy of this WCS with ``crpix`` shifted by ``origin``."""
    shifted = self.copy()
    offset = jnp.array([origin.x, origin.y])
    shifted.crpix = shifted.crpix + offset
    return shifted
def _writeHeader(self, header, bounds):
    """Write this WCS into ``header`` and return it.

    ``bounds`` is accepted but not used by this method.
    """

    def _write_sip(prefix, matrix, unit_term):
        # Write PREFIX_ORDER and every non-zero PREFIX_i_j. ``unit_term`` is
        # the (i, j) slot that holds the added 1, which is removed again to
        # turn the coefficient back into standard (differential) form.
        order = len(matrix) - 1
        header[prefix + "_ORDER"] = order
        for i in range(order + 1):
            for j in range(order + 1):
                value = matrix[i, j]
                if (i, j) == unit_term:
                    value -= 1
                if value != 0.0:
                    header[prefix + "_" + str(i) + "_" + str(j)] = (
                        cast_to_python_float(value)
                    )

    header["GS_WCS"] = ("GSFitsWCS", "GalSim WCS name")
    header["CTYPE1"] = "RA---" + self.wcs_type
    header["CTYPE2"] = "DEC--" + self.wcs_type
    for axis in (0, 1):
        header["CRPIX%d" % (axis + 1)] = cast_to_python_float(self.crpix[axis])
    for row in (0, 1):
        for col in (0, 1):
            header["CD%d_%d" % (row + 1, col + 1)] = cast_to_python_float(
                self.cd[row][col]
            )
    header["CUNIT1"] = "deg"
    header["CUNIT2"] = "deg"
    header["CRVAL1"] = cast_to_python_float(self.center.ra / degrees)
    header["CRVAL2"] = cast_to_python_float(self.center.dec / degrees)

    if self.pv is not None:
        order = len(self.pv[0]) - 1
        skipped = [3, 11, 23, 39]  # odd-power-of-r slots, never written
        slot = 0
        for total in range(order + 1):
            for j in range(total + 1):
                i = total - j
                header["PV1_" + str(slot)] = cast_to_python_float(self.pv[0, i, j])
                header["PV2_" + str(slot)] = cast_to_python_float(self.pv[1, j, i])
                slot += 1
                if slot in skipped:
                    slot += 1

    if self.ab is not None:
        _write_sip("A", self.ab[0], (1, 0))
        _write_sip("B", self.ab[1], (0, 1))

    if self.abp is not None:
        _write_sip("AP", self.abp[0], (1, 0))
        _write_sip("BP", self.abp[1], (0, 1))

    return header
@staticmethod
def _readHeader(header):
    """Build a ``GSFitsWCS`` from an already-loaded FITS header."""
    wcs = GSFitsWCS(header=header)
    return wcs
[docs]
@implements(_galsim.GSFitsWCS.copy)
def copy(self):
    # A shallow copy of the instance (the copy module's version of copying
    # the dict) works fine here.
    duplicate = copy.copy(self)
    return duplicate
def __eq__(self, other):
return self is other or (
isinstance(other, GSFitsWCS)
and self.wcs_type == other.wcs_type
and jnp.array_equal(self.crpix, other.crpix)
and jnp.array_equal(self.cd, other.cd)
and self.center == other.center
and (
(self.pv is None and other.pv is None)
or jnp.array_equal(self.pv, other.pv)
)
and (
(self.ab is None and other.ab is None)
or jnp.array_equal(self.ab, other.ab)
)
and (
(self.abp is None and other.abp is None)
or jnp.array_equal(self.abp, other.abp)
)
)
def __repr__(self):
    """Return the ``_data``-style constructor expression for this WCS."""
    distortion_reprs = tuple(
        repr(ensure_hashable(term)) for term in (self.pv, self.ab, self.abp)
    )
    core = (
        self.wcs_type,
        ensure_hashable(self.crpix),
        ensure_hashable(self.cd),
        self.center,
    )
    template = "galsim.GSFitsWCS(_data = [%r, %r, %r, %r, %s, %s, %s])"
    return template % (core + distortion_reprs)
def __str__(self):
if self._tag is None:
return self.__repr__()
else:
return "galsim.GSFitsWCS(%s)" % (self._tag)
def __hash__(self):
    """Hash the instance through its repr string."""
    text = repr(self)
    return hash(text)
[docs]
@implements(_galsim.fitswcs.TanWCS)
def TanWCS(affine, world_origin, units=arcsec):
    def _in_degrees(value):
        # Interpret ``value`` in ``units`` and express it in degrees.
        return value * units / degrees

    # These will raise the appropriate errors if affine is not the right type.
    dudx = _in_degrees(affine.dudx)
    dudy = _in_degrees(affine.dudy)
    dvdx = _in_degrees(affine.dvdx)
    dvdy = _in_degrees(affine.dvdy)
    image_origin = affine.origin

    # The - signs are because the Fits standard is in terms of +u going east, rather than west
    # as we have defined. So just switch the sign in the CD matrix.
    cd = jnp.array([[-dudx, -dudy], [dvdx, dvdy]], dtype=float)
    crpix = jnp.array([image_origin.x, image_origin.y], dtype=float)

    # We also need to absorb the affine world_origin back into crpix, since GSFits is expecting
    # crpix to be the location of the tangent point in image coordinates. i.e. where (u,v) = (0,0)
    # (u,v) = CD * (x-x0,y-y0) + (u0,v0)
    # (0,0) = CD * (x0',y0') - CD * (x0,y0) + (u0,v0)
    # CD (x0',y0') = CD (x0,y0) - (u0,v0)
    # (x0',y0') = (x0,y0) - CD^-1 (u0,v0)
    uv = jnp.array(
        [
            _in_degrees(affine.world_origin.x),
            _in_degrees(affine.world_origin.y),
        ]
    )
    crpix -= jnp.dot(jnp.linalg.inv(cd), uv)

    # Invoke the private constructor of GSFits using the _data kwarg.
    return GSFitsWCS(_data=("TAN", crpix, cd, world_origin, None, None, None))
# This is a list of all the WCS types that can potentially read a WCS from a FITS file.
# The function FitsWCS will try each of these in order (calling each entry's
# _readHeader) and return the first one that succeeds.
# Note: AffineTransform is not an entry of this list. FitsWCS calls
# AffineTransform._readHeader itself, as the last resort, after every entry here
# has failed.
# The list is defined here at global scope so that external modules can add extra
# WCS types to the list if desired.
fits_wcs_types = [
GSFitsWCS, # This doesn't work for very many WCS types, but it works for the very common
# TAN projection, and also TPV, which is used by SCamp. If it does work, it
# is a good choice, since it is easily the fastest of any of these.
]
[docs]
@implements(
    _galsim.fitswcs.FitsWCS,
    lax_description="JAX-GalSim only supports the GSFitsWCS class for celestial WCS types.",
)
def FitsWCS(
    file_name=None,
    dir=None,
    hdu=None,
    header=None,
    compression="auto",
    text_file=False,
    suppress_warning=False,
):
    # Resolve the header: load it from the file, or use the one passed in.
    if file_name is None:
        file_name = "header"  # For sensible error messages below.
    else:
        if header is not None:
            raise GalSimIncompatibleValuesError(
                "Cannot provide both file_name and pyfits header",
                file_name=file_name,
                header=header,
            )
        header = fits.FitsHeader(
            file_name=file_name,
            dir=dir,
            hdu=hdu,
            compression=compression,
            text_file=text_file,
        )

    if header is None:
        raise GalSimIncompatibleValuesError(
            "Must provide either file_name or header",
            file_name=file_name,
            header=header,
        )
    if not isinstance(header, fits.FitsHeader):
        header = fits.FitsHeader(header)

    # No WCS keywords at all: fall back to a unit pixel scale.
    if not ("CTYPE1" in header or "CDELT1" in header):
        if not suppress_warning:
            galsim_warn(
                "No WCS information found in %r. Defaulting to PixelScale(1.0)"
                % (file_name)
            )
        return PixelScale(1.0)

    # For linear WCS specifications, AffineTransformation should work.
    # Note: Most files will have CTYPE1,2, but old style with only CDELT1,2 sometimes omits it.
    if header.get("CTYPE1", "LINEAR") == "LINEAR":
        wcs = AffineTransform._readHeader(header)
        # Convert to PixelScale if possible.
        # TODO: Should we do this check in JAX-GalSim or maybe always return AffineTransform?
        if not (wcs.dudx == wcs.dvdy and wcs.dudy == wcs.dvdx == 0):
            return wcs
        if wcs.x0 == wcs.y0 == wcs.u0 == wcs.v0 == 0:
            return PixelScale(wcs.dudx)
        return OffsetWCS(wcs.dudx, wcs.origin, wcs.world_origin)

    # Otherwise (and typically), try the various wcs types that can read celestial coordinates.
    for wcs_type in fits_wcs_types:
        try:
            wcs = wcs_type._readHeader(header)
            # Give it a better tag for the repr if appropriate.
            if hasattr(wcs, "_tag") and file_name != "header":
                path = file_name if dir is None else os.path.join(dir, file_name)
                wcs._tag = repr(path)
                if hdu is not None:
                    wcs._tag += ", hdu=%r" % hdu
                if compression != "auto":
                    wcs._tag += ", compression=%r" % compression
            return wcs
        except Exception:
            # Deliberate best effort: this reader failed, so try the next one.
            continue

    # Finally, this one is really the last resort, since it only reads in the linear part of the
    # WCS. It defaults to the equivalent of a pixel scale of 1.0 if even these are not present.
    if not suppress_warning:
        galsim_warn(
            "All the fits WCS types failed to read %r. Using AffineTransform "
            "instead, which will not really be correct." % (file_name)
        )
    return AffineTransform._readHeader(header)
# Let this function work like a class in config: attach the same kind of
# parameter specifications (argument name -> expected type) that GSFitsWCS
# declares as class attributes. ``file_name`` is required; the rest are optional.
# NOTE(review): presumably read by the config layer -- confirm.
FitsWCS._req_params = {"file_name": str}
FitsWCS._opt_params = {"dir": str, "hdu": int, "compression": str, "text_file": bool}
@jax.jit
def _invert_ab_noraise_loop_body(
    x, y, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef
):
    """Take one Newton step towards solving ``ab(x, y) = (u, v)``.

    Returns:
        ``(x, y, dx, dy)``: the updated position and the step just applied.
    """
    # Residual of the forward polynomial at the current guess.
    resid_u = horner2d(x, y, ab[0], triangle=True) - u
    resid_v = horner2d(x, y, ab[1], triangle=True) - v

    # Jacobian entries of the forward polynomial.
    j_ux = horner2d(x, y, dudxcoef, triangle=True)
    j_uy = horner2d(x, y, dudycoef, triangle=True)
    j_vx = horner2d(x, y, dvdxcoef, triangle=True)
    j_vy = horner2d(x, y, dvdycoef, triangle=True)

    # Step = -J^-1 . residual, with the 2x2 inverse written out.
    det = j_ux * j_vy - j_uy * j_vx
    step_x = -(resid_u * j_vy - resid_v * j_uy) / det
    step_y = -(-resid_u * j_vx + resid_v * j_ux) / det

    x += step_x
    y += step_y
    return x, y, step_x, step_y
@jax.jit
def _invert_ab_noraise(u, v, ab, abp=None):
    """Solve ``ab(x, y) = (u, v)`` for ``(x, y)`` with 10 Newton iterations.

    Parameters:
        u, v: target values of the two polynomials.
        ab: the forward polynomial coefficients, indexed ``ab[axis]``.
        abp: optional coefficients used only to seed the initial guess.

    Returns:
        ``(x, y)``. No error is raised on failure: if the largest final step
        still exceeds 2e-12 the results are replaced by NaN.
    """
    # Initial guess: from abp if we have it, otherwise (u, v) itself.
    if abp is not None:
        x = horner2d(u, v, abp[0])
        y = horner2d(u, v, abp[1])
    else:
        x = u.copy()
        y = v.copy()

    # Code below from galsim C++ layer and written by Josh Meyers
    # Matt Becker translated into jax

    # Assemble horner2d coefs for derivatives
    nab = ab.shape[1]
    col_powers = jnp.arange(nab)
    row_powers = jnp.arange(nab)[:, None]
    dudxcoef = (row_powers * ab[0])[1:, :-1]
    dudycoef = (col_powers * ab[0])[:-1, 1:]
    dvdxcoef = (row_powers * ab[1])[1:, :-1]
    dvdycoef = (col_powers * ab[1])[:-1, 1:]

    def _newton_step(_, carry):
        x, y, _, _, u, v, ab, c_ux, c_uy, c_vx, c_vy = carry

        # Residual of the forward polynomial at the current guess.
        resid_u = horner2d(x, y, ab[0], triangle=True) - u
        resid_v = horner2d(x, y, ab[1], triangle=True) - v

        # Jacobian entries.
        j_ux = horner2d(x, y, c_ux, triangle=True)
        j_uy = horner2d(x, y, c_uy, triangle=True)
        j_vx = horner2d(x, y, c_vx, triangle=True)
        j_vy = horner2d(x, y, c_vy, triangle=True)

        # Step = -J^-1 . residual
        det = j_ux * j_vy - j_uy * j_vx
        step_x = -(resid_u * j_vy - resid_v * j_uy) / det
        step_y = -(-resid_u * j_vx + resid_v * j_ux) / det

        x += step_x
        y += step_y
        return x, y, step_x, step_y, u, v, ab, c_ux, c_uy, c_vx, c_vy

    start = (
        x,
        y,
        jnp.zeros_like(x),
        jnp.zeros_like(y),
        u,
        v,
        ab,
        dudxcoef,
        dudycoef,
        dvdxcoef,
        dvdycoef,
    )
    finished = jax.lax.fori_loop(0, 10, _newton_step, start, unroll=True)
    x, y, dx, dy = finished[0:4]

    def _mark_failed(x, y):
        return (x * jnp.nan, y * jnp.nan)

    def _keep(x, y):
        return (x, y)

    # Not converged if the last step anywhere is still larger than 2e-12.
    largest_step = jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy)))
    x, y = jax.lax.cond(largest_step > 2e-12, _mark_failed, _keep, x, y)

    return x, y