Source code for jax_galsim.transform
import galsim as _galsim
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from jax_galsim.core.utils import (
compute_major_minor_from_jacobian,
ensure_hashable,
implements,
)
from jax_galsim.gsobject import GSObject
from jax_galsim.gsparams import GSParams
from jax_galsim.position import PositionD
[docs]
@implements(
    _galsim.Transform,
    lax_description="Does not support Chromatic Objects or Convolutions.",
)
def Transform(
    obj,
    jac=(1.0, 0.0, 0.0, 1.0),
    offset=PositionD(0.0, 0.0),
    flux_ratio=1.0,
    gsparams=None,
    propagate_gsparams=True,
):
    # Validate the inputs up front, then hand off to the Transformation class.
    # Only GSObject instances can be transformed by this implementation.
    if not isinstance(obj, GSObject):
        raise TypeError("Argument to Transform must be a GSObject.")

    # Callable transformation parameters are rejected rather than evaluated.
    # The checks run in the same order as the arguments: jac, offset, flux_ratio.
    candidates = (jac, offset, flux_ratio)
    if any(hasattr(candidate, "__call__") for candidate in candidates):
        raise NotImplementedError("Transform does not support callable arguments.")

    return Transformation(
        obj,
        jac,
        offset,
        flux_ratio,
        gsparams,
        propagate_gsparams,
    )
[docs]
@implements(_galsim.Transformation)
@register_pytree_node_class
class Transformation(GSObject):
    # An affine transformation of a wrapped GSObject: a 2x2 jacobian, a shift
    # (offset) and an amplitude scaling (flux_ratio).  Nested Transformations
    # (other than interpolated images) are collapsed into a single one.

    def __init__(
        self,
        obj,
        jac=(1.0, 0.0, 0.0, 1.0),
        offset=PositionD(0.0, 0.0),
        flux_ratio=1.0,
        gsparams=None,
        propagate_gsparams=True,
    ):
        self._gsparams = GSParams.check(gsparams, obj.gsparams)
        self._propagate_gsparams = propagate_gsparams
        if self._propagate_gsparams:
            obj = obj.withGSParams(self._gsparams)
        if jac is None:
            jac = jnp.array([1.0, 0.0, 0.0, 1.0])
        # PositionD(offset) builds a new position object, so the in-place
        # offset updates below operate on this copy rather than on the
        # caller's argument (or the shared default argument).
        self._params = {
            "jac": jac,
            "offset": PositionD(offset),
            "flux_ratio": flux_ratio,
        }

        # this import is here to avoid circular imports
        # we do not want to mess with the transform properties of the interpolated image
        from .interpolatedimage import InterpolatedImage

        if isinstance(obj, Transformation) and not isinstance(obj, InterpolatedImage):
            # Combine the two affine transformations into one.  The inner
            # offset is mapped through *this* (outer) jacobian, so it must be
            # computed before self._params["jac"] is replaced below.
            # Go through the _offset hook (not _params) as documented below.
            inner_offset = obj._offset
            dx, dy = self._fwd(inner_offset.x, inner_offset.y)
            self._offset.x += dx
            self._offset.y += dy
            self._params["jac"] = self._jac.dot(obj._jac)
            self._params["flux_ratio"] *= obj._flux_ratio
            self._original = obj._original
        else:
            self._original = obj

    ##############################################################
    # The internal code of the methods of the Transform class
    # should only access _offset, _flux_ratio, and _jac. It
    # should not pull these directly from _params.
    # Things are structured this way since the interpolated image
    # class inherits and overrides these methods.
    @property
    def _offset(self):
        return self._params["offset"]

    # we use this property so that the interpolated image can override
    # how flux ratio is computed / stored
    @property
    def _flux_ratio(self):
        return self._params["flux_ratio"]

    @property
    def _jac(self):
        # Normalize the stored jacobian: flatten it, broadcast the result to
        # four elements, and return it as a 2x2 float array.
        jac = self._params["jac"]
        return jnp.asarray(
            jnp.broadcast_to(jnp.array(jac, dtype=float).ravel(), (4,)),
            dtype=float,
        ).reshape((2, 2))

    @property
    @implements(_galsim.transform.Transformation.original)
    def original(self):
        return self._original

    @property
    @implements(_galsim.transform.Transformation.jac)
    def jac(self):
        return self._jac

    @property
    @implements(_galsim.transform.Transformation.offset)
    def offset(self):
        return self._offset

    @property
    @implements(_galsim.transform.Transformation.flux_ratio)
    def flux_ratio(self):
        return self._flux_ratio

    @property
    def _flux(self):
        return self._flux_scaling * self._original.flux

    @implements(_galsim.Transformation.withGSParams)
    def withGSParams(self, gsparams=None, **kwargs):
        """Create a version of the current object with the given gsparams.

        .. note::

            Unless you set ``propagate_gsparams=False``, this method will also
            update the gsparams of the wrapped component object.
        """
        if gsparams == self._gsparams:
            return self
        chld, aux = self.tree_flatten()
        aux["gsparams"] = GSParams.check(gsparams, self._gsparams, **kwargs)
        if self._propagate_gsparams:
            # chld[0] is the wrapped object (see tree_flatten).
            new_obj = chld[0].withGSParams(aux["gsparams"])
            chld = (new_obj,) + chld[1:]
        return self.tree_unflatten(aux, chld)

    def __eq__(self, other):
        # Compare through the _offset/_jac/_flux_ratio hooks on *both* sides so
        # subclasses that override them (e.g. the interpolated image) compare
        # correctly instead of reading possibly-absent _params entries.
        return self is other or (
            isinstance(other, Transformation)
            and self._original == other._original
            and jnp.array_equal(self._jac, other._jac)
            and self._offset == other._offset
            and self._flux_ratio == other._flux_ratio
            and self._gsparams == other._gsparams
            and self._propagate_gsparams == other._propagate_gsparams
        )

    def __hash__(self):
        # Array-valued fields are converted with ensure_hashable before hashing.
        return hash(
            (
                "galsim.Transformation",
                self._original,
                ensure_hashable(self._jac.ravel()),
                ensure_hashable(self._offset.x),
                ensure_hashable(self._offset.y),
                ensure_hashable(self._flux_ratio),
                self._gsparams,
                self._propagate_gsparams,
            )
        )

    def __repr__(self):
        return (
            "galsim.Transformation(%r, jac=%r, offset=%r, flux_ratio=%r, gsparams=%r, "
            "propagate_gsparams=%r)"
        ) % (
            self._original,
            ensure_hashable(self._jac.ravel()),
            self._offset,
            ensure_hashable(self._flux_ratio),
            self._gsparams,
            self._propagate_gsparams,
        )

    @classmethod
    def _str_from_jac(cls, jac):
        # Return a method-chain suffix (".rotate(...)", ".shear(...)",
        # ".expand(...)" or ".transform(...)") describing ``jac`` for __str__,
        # or "" when ``jac`` is exactly the identity.
        from jax_galsim.wcs import JacobianWCS

        dudx, dudy, dvdx, dvdy = jac.ravel()
        if dudx != 1 or dudy != 0 or dvdx != 0 or dvdy != 1:
            # Figure out the shear/rotate/dilate calls that are equivalent.
            jac = JacobianWCS(dudx, dudy, dvdx, dvdy)
            scale, shear, theta, flip = jac.getDecomposition()
            s = None
            if flip:
                s = 0  # Special value indicating to just use transform.
            if abs(theta.rad) > 1.0e-12:
                if s is None:
                    s = ".rotate(%s)" % theta
                else:
                    s = 0
            if shear.g > 1.0e-12:
                if s is None:
                    s = ".shear(%s)" % shear
                else:
                    s = 0
            if abs(scale - 1.0) > 1.0e-12:
                if s is None:
                    s = ".expand(%s)" % scale
                else:
                    s = 0
            if s == 0:
                # If flip or there are two components, then revert to transform as simpler.
                s = ".transform(%s,%s,%s,%s)" % (dudx, dudy, dvdx, dvdy)
            if s is None:
                # If nothing is large enough to show up above, give full detail of transform
                s = ".transform(%r,%r,%r,%r)" % (dudx, dudy, dvdx, dvdy)
            return s
        else:
            return ""

    def __str__(self):
        s = str(self._original)
        s += self._str_from_jac(self._jac)
        if self._offset.x != 0 or self._offset.y != 0:
            s += ".shift(%s,%s)" % (
                ensure_hashable(self._offset.x),
                ensure_hashable(self._offset.y),
            )
        if self._flux_ratio != 1.0:
            s += " * %s" % (ensure_hashable(self._flux_ratio),)
        return s

    @property
    def _det(self):
        return jnp.linalg.det(self._jac)

    @property
    def _invdet(self):
        return 1.0 / self._det

    @property
    def _invjac(self):
        return jnp.linalg.inv(self._jac)

    # To avoid confusion with the flux vs amplitude scaling, we use these names below, rather
    # than flux_ratio, which is really an amplitude scaling.
    @property
    def _amp_scaling(self):
        return self._flux_ratio

    @property
    def _flux_scaling(self):
        # Total flux scales by |det(jac)| times the amplitude ratio.
        return jnp.abs(self._det) * self._flux_ratio

    def _fwd(self, x, y):
        # Apply the jacobian: (x, y) -> jac @ (x, y).
        res = jnp.dot(self._jac, jnp.array([x, y]))
        return res[0], res[1]

    def _fwdT(self, x, y):
        # Apply the transposed jacobian (used to map k-space positions).
        res = jnp.dot(self._jac.T, jnp.array([x, y]))
        return res[0], res[1]

    def _inv(self, x, y):
        # Apply the inverse jacobian.
        res = jnp.dot(self._invjac, jnp.array([x, y]))
        return res[0], res[1]

    def _kfactor(self, kx, ky):
        # Fourier phase of the shift, exp(-i (kx*ox + ky*oy)), times the flux
        # scaling.  Out-of-place arithmetic: augmented assignment would modify
        # NumPy inputs in place (or fail casting complex into a float array).
        phase = kx * (-1j * self._offset.x) + ky * (-1j * self._offset.y)
        return self._flux_scaling * jnp.exp(phase)

    @property
    def _maxk(self):
        _, minor = compute_major_minor_from_jacobian(self._jac)
        return self._original.maxk / minor

    @property
    def _stepk(self):
        major, _ = compute_major_minor_from_jacobian(self._jac)
        stepk = self._original.stepk / major
        # If we have a shift, we need to further modify stepk
        #     stepk = Pi/R
        #     R <- R + |shift|
        #     stepk <- Pi/(Pi/stepk + |shift|)
        dr = jnp.hypot(self._offset.x, self._offset.y)
        stepk = jnp.pi / (jnp.pi / stepk + dr)
        return stepk

    @property
    def _has_hard_edges(self):
        return self._original.has_hard_edges

    @property
    def _is_axisymmetric(self):
        # Axisymmetric only if the original is, the jacobian has the form of a
        # rotation+dilation ([[a, b], [-b, a]]), and there is no shift.
        return bool(
            self._original.is_axisymmetric
            and self._jac[0, 0] == self._jac[1, 1]
            and self._jac[0, 1] == -self._jac[1, 0]
            and self._offset == PositionD(0.0, 0.0)
        )

    @property
    def _is_analytic_x(self):
        return self._original.is_analytic_x

    @property
    def _is_analytic_k(self):
        return self._original.is_analytic_k

    @property
    def _centroid(self):
        cen = self._original.centroid
        cen = PositionD(self._fwd(cen.x, cen.y))
        cen += self._offset
        return cen

    @property
    def _positive_flux(self):
        return self._flux_scaling * self._original.positive_flux

    @property
    def _negative_flux(self):
        return self._flux_scaling * self._original.negative_flux

    @property
    def _flux_per_photon(self):
        return self._calculate_flux_per_photon()

    @property
    def _max_sb(self):
        return self._amp_scaling * self._original.max_sb

    def _xValue(self, pos):
        # Undo the shift, then the jacobian, and evaluate the original there.
        # Out-of-place subtraction so the caller's position is never modified.
        pos = pos - self._offset
        inv_pos = PositionD(self._inv(pos.x, pos.y))
        return self._original._xValue(inv_pos) * self._amp_scaling

    def _kValue(self, kpos):
        fwdT_kpos = PositionD(self._fwdT(kpos.x, kpos.y))
        return self._original._kValue(fwdT_kpos) * self._kfactor(kpos.x, kpos.y)

    def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
        # Fold this transformation into the (jac, offset, flux_scaling) that
        # is passed down to the original object's _drawReal.
        dx, dy = offset
        if jac is None:
            dx += self._offset.x
            dy += self._offset.y
            total_jac = self._jac
        else:
            # An outer jacobian also maps this object's shift.
            shift = jac.dot(self._offset._array)
            dx += shift[0]
            dy += shift[1]
            own_jac = self._jac
            total_jac = jac if own_jac is None else jac.dot(own_jac)
        flux_scaling *= self._flux_scaling
        return self._original._drawReal(image, total_jac, (dx, dy), flux_scaling)

    def _drawKImage(self, image, jac=None):
        from jax_galsim.core.draw import apply_kImage_phases

        # Jacobian seen by the original: the outer jac (if any) composed with ours.
        if jac is None:
            inner_jac = self._jac
        else:
            own_jac = self._jac
            inner_jac = jac if own_jac is None else jac.dot(own_jac)
        image = self._original._drawKImage(image, inner_jac)
        # The shift phases use only the outer jacobian (identity if none).
        phase_jac = jnp.eye(2) if jac is None else jac
        image = apply_kImage_phases(self._offset, image, phase_jac)
        image = image * self._flux_scaling
        return image

    @implements(_galsim.Transformation._shoot)
    def _shoot(self, photons, rng):
        self._original._shoot(photons, rng)
        photons.x, photons.y = self._fwd(photons.x, photons.y)
        photons.x += self._offset.x
        photons.y += self._offset.y
        photons.scaleFlux(self._flux_scaling)

    def tree_flatten(self):
        """This function flattens the Transform into a list of children
        nodes that will be traced by JAX and auxiliary static data."""
        # Define the children nodes of the PyTree that need tracing
        children = (self._original, self._params)
        # Define auxiliary static data that doesn't need to be traced
        aux_data = {
            "gsparams": self._gsparams,
            "propagate_gsparams": self._propagate_gsparams,
        }
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Recreates an instance of the class from flatten representation"""
        # __new__ bypasses __init__, so no re-validation or re-combination of
        # nested transformations happens here.
        obj = cls.__new__(cls)
        obj._gsparams = aux_data["gsparams"]
        obj._propagate_gsparams = aux_data["propagate_gsparams"]
        obj._original, obj._params = children
        return obj
def _Transform(
    obj,
    jac=(1.0, 0.0, 0.0, 1.0),
    offset=PositionD(0.0, 0.0),
    flux_ratio=1.0,
    gsparams=None,
):
    """Construct a `Transformation` directly, without the checks done by `Transform`.

    In particular, ``obj`` is not checked to be a `GSObject`, and callable
    ``jac``/``offset``/``flux_ratio`` values are not rejected.  The wrapped
    object always receives the resulting gsparams (``propagate_gsparams`` is
    left at its default).
    """
    return Transformation(
        obj,
        jac=jac,
        offset=offset,
        flux_ratio=flux_ratio,
        gsparams=gsparams,
    )