Source code for jax_galsim.shear

import galsim as _galsim
import jax.numpy as jnp
from galsim.errors import GalSimIncompatibleValuesError
from jax.tree_util import register_pytree_node_class

from jax_galsim.angle import Angle, _Angle, radians
from jax_galsim.core.utils import ensure_hashable, implements


@register_pytree_node_class
@implements(
    _galsim.Shear,
    lax_description="""\
The jax_galsim implementation of ``Shear`` does not perform range checking of the \
shear (e.g., ``|g| <= 1``) upon construction.""",
)
class Shear(object):
    # The only instance state is ``self._g``: the shear stored as one complex value
    # ``g1 + 1j * g2``.  Every other parameterization exposed below (e, eta, q, beta)
    # is converted to ``_g`` in ``__init__`` and recomputed from ``_g`` on access.

    def __init__(self, *args, **kwargs):
        """Build a shear from one complex positional value or from keyword pairs.

        Accepted forms (as handled by the branches below):

        * no arguments: zero shear
        * one positional argument: complex ``g1 + 1j * g2``
        * ``g1`` and/or ``g2``; ``e1`` and/or ``e2``; ``eta1`` and/or ``eta2``
          (a missing component defaults to 0.0)
        * ``g``, ``e``, ``eta`` or ``q`` together with ``beta`` (an ``Angle``)

        Raises:
            TypeError: for more than one positional argument, more than two keyword
                arguments, a non-complex positional argument, a ``beta`` that is not
                an ``Angle``, or leftover unrecognized keyword arguments.
            GalSimIncompatibleValuesError: when a magnitude is given without ``beta``
                or ``beta`` is given without a magnitude.
        """
        # There is no valid set of >2 keyword arguments, so raise an exception in this case:
        if len(kwargs) > 2:
            raise TypeError(
                "Shear constructor received >2 keyword arguments: %s" % kwargs.keys()
            )

        if len(args) > 1:
            # Wrap ``args`` in a 1-tuple: ``"%s" % args`` with a multi-element tuple
            # fails inside the formatting itself and hides this message.
            raise TypeError(
                "Shear constructor received >1 non-keyword arguments: %s" % (args,)
            )

        # If a component of e, g, or eta, then require that the other component is zero if not set,
        # and don't allow specification of mixed pairs like e1 and g2.
        # Also, require a position angle if we didn't get g1/g2, e1/e2, or eta1/eta2

        # Unnamed arg must be a complex shear
        if len(args) == 1:
            self._g = args[0]
            # Check the dtype, not the values: ``jnp.iscomplex`` is True only where
            # the imaginary part is non-zero, which would reject a complex array
            # whose imaginary part is zero.
            if not (jnp.iscomplexobj(self._g) or isinstance(self._g, complex)):
                raise TypeError(
                    "Non-keyword argument to Shear must be complex g1 + 1j * g2"
                )

        # Empty constructor means shear == (0,0)
        elif not kwargs:
            self._g = 0j

        # g1,g2
        elif "g1" in kwargs or "g2" in kwargs:
            g1 = kwargs.pop("g1", 0.0)
            g2 = kwargs.pop("g2", 0.0)
            self._g = g1 + 1j * g2

        # e1,e2
        elif "e1" in kwargs or "e2" in kwargs:
            e1 = kwargs.pop("e1", 0.0)
            e2 = kwargs.pop("e2", 0.0)
            absesq = e1**2 + e2**2
            self._g = (e1 + 1j * e2) * self._e2g(absesq)

        # eta1,eta2
        elif "eta1" in kwargs or "eta2" in kwargs:
            eta1 = kwargs.pop("eta1", 0.0)
            eta2 = kwargs.pop("eta2", 0.0)
            eta = eta1 + 1j * eta2
            abseta = abs(eta)
            self._g = eta * self._eta2g(abseta)

        # g,beta
        elif "g" in kwargs:
            if "beta" not in kwargs:
                raise GalSimIncompatibleValuesError(
                    "Shear constructor requires beta when g is specified.",
                    g=kwargs["g"],
                    beta=None,
                )
            beta = kwargs.pop("beta")
            if not isinstance(beta, Angle):
                raise TypeError("beta must be an Angle instance.")
            g = kwargs.pop("g")
            self._g = g * jnp.exp(2j * beta.rad)

        # e,beta
        elif "e" in kwargs:
            if "beta" not in kwargs:
                raise GalSimIncompatibleValuesError(
                    "Shear constructor requires beta when e is specified.",
                    e=kwargs["e"],
                    beta=None,
                )
            beta = kwargs.pop("beta")
            if not isinstance(beta, Angle):
                raise TypeError("beta must be an Angle instance.")
            e = kwargs.pop("e")
            self._g = self._e2g(e**2) * e * jnp.exp(2j * beta.rad)

        # eta,beta
        elif "eta" in kwargs:
            if "beta" not in kwargs:
                raise GalSimIncompatibleValuesError(
                    "Shear constructor requires beta when eta is specified.",
                    eta=kwargs["eta"],
                    beta=None,
                )
            beta = kwargs.pop("beta")
            if not isinstance(beta, Angle):
                raise TypeError("beta must be an Angle instance.")
            eta = kwargs.pop("eta")
            self._g = self._eta2g(eta) * eta * jnp.exp(2j * beta.rad)

        # q,beta
        elif "q" in kwargs:
            if "beta" not in kwargs:
                raise GalSimIncompatibleValuesError(
                    "Shear constructor requires beta when q is specified.",
                    q=kwargs["q"],
                    beta=None,
                )
            beta = kwargs.pop("beta")
            if not isinstance(beta, Angle):
                raise TypeError("beta must be an Angle instance.")
            q = kwargs.pop("q")
            # The axis ratio is converted to eta = -ln(q), then handled like eta,beta.
            eta = -jnp.log(q)
            self._g = self._eta2g(eta) * eta * jnp.exp(2j * beta.rad)

        elif "beta" in kwargs:
            raise GalSimIncompatibleValuesError(
                "beta provided to Shear constructor, but not g/e/eta/q",
                beta=kwargs["beta"],
                e=None,
                g=None,
                q=None,
                eta=None,
            )

        # check for the case where there are 1 or 2 kwargs that are not valid ones for
        # initializing a Shear
        # (every branch above pops what it consumed, so anything left is unexpected)
        if kwargs:
            raise TypeError(
                "Shear constructor got unexpected extra argument(s): %s"
                % kwargs.keys()
            )

    # Real part of the stored complex shear.
    @property
    @implements(_galsim.Shear.g1)
    def g1(self):
        return self._g.real

    # Imaginary part of the stored complex shear.
    @property
    @implements(_galsim.Shear.g2)
    def g2(self):
        return self._g.imag

    # Magnitude |g| of the stored complex shear.
    @property
    @implements(_galsim.Shear.g)
    def g(self):
        return jnp.abs(self._g)

    # Position angle: half the complex argument of the stored shear, as an Angle.
    @property
    @implements(_galsim.Shear.beta)
    def beta(self):
        return _Angle(0.5 * jnp.angle(self._g))

    # The stored complex shear itself.
    @property
    @implements(_galsim.Shear.shear)
    def shear(self):
        return self._g

    # e-type components and magnitude: g scaled by the factor from ``_g2e``.
    @property
    @implements(_galsim.Shear.e1)
    def e1(self):
        return self._g.real * self._g2e(self.g**2)

    @property
    @implements(_galsim.Shear.e2)
    def e2(self):
        return self._g.imag * self._g2e(self.g**2)

    @property
    @implements(_galsim.Shear.e)
    def e(self):
        return self.g * self._g2e(self.g**2)

    @property
    @implements(_galsim.Shear.esq)
    def esq(self):
        return self.e**2

    # eta-type components and magnitude: g scaled by the factor from ``_g2eta``.
    @property
    @implements(_galsim.Shear.eta1)
    def eta1(self):
        return self._g.real * self._g2eta(self.g)

    @property
    @implements(_galsim.Shear.eta2)
    def eta2(self):
        return self._g.imag * self._g2eta(self.g)

    @property
    @implements(_galsim.Shear.eta)
    def eta(self):
        return self.g * self._g2eta(self.g)

    # Computed from the shear magnitude as (1 - |g|) / (1 + |g|).
    @property
    @implements(_galsim.Shear.q)
    def q(self):
        return (1.0 - self.g) / (1.0 + self.g)

    # Helpers to convert between different conventions
    # Note: These return the scale factor by which to multiply.  Not the final value.
    def _g2e(self, absgsq):
        """Return the factor e/g given ``absgsq`` = |g|**2."""
        return 2.0 / (1.0 + absgsq)

    def _e2g(self, absesq):
        """Return the factor g/e given ``absesq`` = |e|**2."""
        return jnp.where(
            absesq > 1.0e-4,
            1.0 / (1.0 + jnp.sqrt(1.0 - absesq)),
            # Avoid numerical issues near e=0 using Taylor expansion
            0.5 + absesq * (0.125 + absesq * (0.0625 + absesq * 0.0390625)),
        )

    def _g2eta(self, absg):
        """Return the factor eta/g given ``absg`` = |g|."""
        absgsq = absg * absg
        # NOTE(review): jnp.where evaluates both branches; the first one is 0/0 at
        # absg == 0.  The forward value is taken from the Taylor branch there, but
        # gradients through this expression may be NaN at zero shear -- confirm
        # whether differentiability at g == 0 matters to callers.
        return jnp.where(
            absg > 1.0e-4,
            2.0 * jnp.arctanh(absg) / absg,
            # This doesn't have as much trouble with accuracy, but have to avoid absg=0,
            # so might as well Taylor expand for small values.
            2.0 + absgsq * ((2.0 / 3.0) + absgsq * 0.4),
        )

    def _eta2g(self, abseta):
        """Return the factor g/eta given ``abseta`` = |eta|."""
        absetasq = abseta * abseta
        return jnp.where(
            abseta > 1.0e-4,
            jnp.tanh(0.5 * abseta) / abseta,
            # Taylor expansion of tanh(eta/2)/eta for small eta (avoids 0/0 at eta=0)
            0.5 + absetasq * ((-1.0 / 24.0) + absetasq * (1.0 / 240.0)),
        )

    # define all the various operators on Shear objects
    def __neg__(self):
        """Return the shear with the opposite sign of the complex value."""
        return _Shear(-self._g)

    # order of operations: shear by other._shear, then by self._shear
    def __add__(self, other):
        """Compose two shears: (g_self + g_other) / (1 + conj(g_self) * g_other)."""
        return _Shear((self._g + other._g) / (1.0 + self._g.conjugate() * other._g))

    # order of operations: shear by -other._shear, then by self._shear
    def __sub__(self, other):
        """Compose this shear with the negation of ``other``."""
        return self + (-other)

    def __eq__(self, other):
        """Equal if the same object, or another Shear with an equal complex value."""
        return self is other or (isinstance(other, Shear) and self._g == other._g)

    def __ne__(self, other):
        return not self.__eq__(other)

    # 2x2 matrix [[1 + g1, g2], [g2, 1 - g1]] divided by sqrt(1 - |g|**2).
    @implements(_galsim.Shear.getMatrix)
    def getMatrix(self):
        return jnp.array(
            [[1.0 + self.g1, self.g2], [self.g2, 1.0 - self.g1]]
        ) / jnp.sqrt(1.0 - self.g**2)

    # Angle extracted from the first column of
    # (-(self + other)).getMatrix() @ self.getMatrix() @ other.getMatrix().
    @implements(_galsim.Shear.rotationWith)
    def rotationWith(self, other):
        # Save a little time by only working on the first column.
        S3 = self.getMatrix().dot(other.getMatrix()[:, :1])
        R = (-(self + other)).getMatrix().dot(S3)
        theta = jnp.arctan2(R[1, 0], R[0, 0])
        return theta * radians

    def __repr__(self):
        return "galsim.Shear(%r)" % (ensure_hashable(self.shear))

    def __str__(self):
        return "galsim.Shear(g1=%s,g2=%s)" % (
            ensure_hashable(self.g1),
            ensure_hashable(self.g2),
        )

    def __hash__(self):
        return hash(ensure_hashable(self._g))

    def tree_flatten(self):
        """Flatten into ``(children, aux_data)``: the complex shear and no aux data."""
        children = (self._g,)
        return (children, None)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Recreates an instance of the class from flatten representation"""
        del aux_data  # unused in this class
        # Bypass __init__ so the stored value is restored without argument checks.
        obj = cls.__new__(cls)
        obj._g = children[0]
        return obj

    @classmethod
    def from_galsim(cls, galsim_shear):
        """Create a jax_galsim `Shear` from a `galsim.Shear` object."""
        return cls(g1=galsim_shear.g1, g2=galsim_shear.g2)

    def to_galsim(self):
        """Create a galsim `Shear` from a `jax_galsim.Shear` object."""
        return _galsim.Shear(g1=float(self.g1), g2=float(self.g2))
# Fast-path factory: wraps an already-complex shear value in a Shear instance
# without going through Shear.__init__ (so no argument parsing or type checks).
@implements(_galsim._Shear)
def _Shear(shear):
    instance = Shear.__new__(Shear)
    # The complex shear is the only state a Shear carries.
    instance._g = shear
    return instance