Source code for jax_galsim.exponential

from functools import lru_cache

import galsim as _galsim
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.utils import ensure_hashable, implements
from jax_galsim.gsobject import GSObject
from jax_galsim.random import UniformDeviate


@implements(_galsim.Exponential)
@register_pytree_node_class
class Exponential(GSObject):
    # Exponential profile: the real-space value is
    #     _norm * exp(-r / scale_radius)
    # (see ``_xValue``) and the Fourier-space value is
    #     flux / (1 + k^2 * scale_radius^2)^(3/2)
    # (see ``_kValue``).

    # Ratio half_light_radius / scale_radius.
    # The half-light-radius is not analytic, but can be calculated numerically
    # by iterative solution of equation:
    #     (re / r0) = ln[(re / r0) + 1] + ln(2)
    _hlr_factor = 1.6783469900166605
    _one_third = 1.0 / 3.0
    # 1 / (2 * pi)
    _inv_twopi = 0.15915494309189535

    # Profile capability flags.
    # NOTE(review): presumably consumed by GSObject / the drawing code, which
    # is not visible in this file -- confirm.
    _has_hard_edges = False
    _is_axisymmetric = True
    _is_analytic_x = True
    _is_analytic_k = True

    def __init__(
        self, half_light_radius=None, scale_radius=None, flux=1.0, gsparams=None
    ):
        """Build the profile from exactly one of the two size parameters.

        Parameters:
            half_light_radius: Size given as a half-light radius; converted to
                               a scale radius by dividing by ``_hlr_factor``.
            scale_radius:      Size given directly as the scale radius.
            flux:              Total flux, forwarded to the base class.
            gsparams:          Forwarded unchanged to the base class.

        Raises:
            galsim.GalSimIncompatibleValuesError: if both or neither of
                ``half_light_radius`` and ``scale_radius`` are given.
        """
        if half_light_radius is not None and scale_radius is not None:
            raise _galsim.GalSimIncompatibleValuesError(
                "Only one of scale_radius and half_light_radius may be specified",
                half_light_radius=half_light_radius,
                scale_radius=scale_radius,
            )
        if half_light_radius is None and scale_radius is None:
            raise _galsim.GalSimIncompatibleValuesError(
                "Either scale_radius or half_light_radius must be specified",
                half_light_radius=half_light_radius,
                scale_radius=scale_radius,
            )

        # Only the scale radius is handed to the base class; a half-light
        # radius is converted first.
        if half_light_radius is not None:
            scale_radius = half_light_radius / Exponential._hlr_factor
        super().__init__(scale_radius=scale_radius, flux=flux, gsparams=gsparams)

    @property
    @implements(_galsim.Exponential.scale_radius)
    def scale_radius(self):
        return self.params["scale_radius"]

    @property
    def _r0(self):
        """Short alias for ``scale_radius``."""
        return self.scale_radius

    @property
    def _inv_r0(self):
        """Reciprocal of the scale radius."""
        return 1.0 / self._r0

    @property
    def _norm(self):
        """Real-space normalisation: ``flux / (2 * pi * r0**2)``."""
        return self.flux * Exponential._inv_twopi * self._inv_r0**2

    @property
    @implements(_galsim.Exponential.half_light_radius)
    def half_light_radius(self):
        return self.params["scale_radius"] * Exponential._hlr_factor

    def __hash__(self):
        # ensure_hashable is applied to the numeric parameters before hashing.
        return hash(
            (
                "galsim.Exponential",
                ensure_hashable(self.scale_radius),
                ensure_hashable(self.flux),
                self.gsparams,
            )
        )

    def __repr__(self):
        return "galsim.Exponential(scale_radius=%r, flux=%r, gsparams=%r)" % (
            ensure_hashable(self.scale_radius),
            ensure_hashable(self.flux),
            self.gsparams,
        )

    def __str__(self):
        s = "galsim.Exponential(scale_radius=%s" % (
            ensure_hashable(self.scale_radius),
        )
        s += ", flux=%s" % (ensure_hashable(self.flux),)
        s += ")"
        return s

    @property
    def _maxk(self):
        """``maxk_threshold ** (-1/3) / scale_radius``."""
        _maxk = self.gsparams.maxk_threshold**-Exponential._one_third
        return _maxk / self.scale_radius

    @property
    def _stepk(self):
        """``pi / (R * r0)``, with R found from the folding threshold."""
        # The content of this function is inherited from the GalSim C++ layer
        # https://github.com/GalSim-developers/GalSim/blob/ece3bd32c1ae6ed771f2b489c5ab1b25729e0ea4/src/SBExponential.cpp#L530
        # https://github.com/GalSim-developers/GalSim/blob/ece3bd32c1ae6ed771f2b489c5ab1b25729e0ea4/src/SBExponential.cpp#L97

        # Calculate stepk:
        # int( exp(-r) r, r=0..R) = (1 - exp(-R) - Rexp(-R))
        # Fraction excluded is thus (1+R) exp(-R)
        # A fast solution to (1+R)exp(-R) = x:
        # log(1+R) - R = log(x)
        # R = log(1+R) - log(x)
        logx = jnp.log(self.gsparams.folding_threshold)
        R = -logx
        # Three fixed-point iterations of R = log(1+R) - log(x).
        for _ in range(3):
            R = jnp.log(1.0 + R) - logx
        # Make sure R is at least ``stepk_minimum_hlr`` half-light radii.
        # R is in units of r0, so one half-light radius is ``_hlr_factor``.
        R = jnp.maximum(
            R, self.gsparams.stepk_minimum_hlr * Exponential._hlr_factor
        )
        return jnp.pi / R * self._inv_r0

    @property
    def _max_sb(self):
        """Maximum surface brightness; equal to ``_norm`` (the value at r = 0)."""
        return self._norm

    def _xValue(self, pos):
        """Real-space value at ``pos`` (an object with ``x`` and ``y``)."""
        r = jnp.sqrt(pos.x**2 + pos.y**2)
        return self._norm * jnp.exp(-r * self._inv_r0)

    def _kValue(self, kpos):
        """Fourier-space value at ``kpos`` (an object with ``x`` and ``y``)."""
        # ksqp1 = 1 + (k * r0)^2 ; the result is flux / ksqp1^(3/2).
        ksqp1 = (kpos.x**2 + kpos.y**2) * self._r0**2 + 1.0
        return self.flux / (ksqp1 * jnp.sqrt(ksqp1))

    def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
        """Delegate to ``draw_by_xValue``; ``jac=None`` means the 2x2 identity."""
        _jac = jnp.eye(2) if jac is None else jac
        return draw_by_xValue(self, image, _jac, jnp.asarray(offset), flux_scaling)

    def _drawKImage(self, image, jac=None):
        """Delegate to ``draw_by_kValue``; ``jac=None`` means the 2x2 identity."""
        _jac = jnp.eye(2) if jac is None else jac
        return draw_by_kValue(self, image, _jac)

    @implements(_galsim.Exponential.withFlux)
    def withFlux(self, flux):
        # New instance with the same scale radius and gsparams, new flux.
        return Exponential(
            scale_radius=self.scale_radius, flux=flux, gsparams=self.gsparams
        )

    @implements(_galsim.Exponential._shoot)
    def _shoot(self, photons, rng):
        ud = UniformDeviate(rng)

        u = ud.generate(
            photons.x
        )  # this does not fill arrays like in galsim so is safe
        _u_cdf, _cdf = _shoot_cdf(self.gsparams.shoot_accuracy)
        # this interpolation inverts the CDF
        u = jnp.interp(u, _cdf, _u_cdf)
        # this converts from u (see above) to r and scales by the actual size of
        # the object r0.
        r = -jnp.log(1.0 - u) * self._r0

        # The angle is drawn after the radius; the order of the two
        # ``generate`` calls is kept as-is so the random stream is unchanged.
        ang = (
            ud.generate(photons.x) * 2.0 * jnp.pi
        )  # this does not fill arrays like in galsim so is safe

        photons.x = r * jnp.cos(ang)
        photons.y = r * jnp.sin(ang)
        # Every photon carries an equal share of the total flux.
        photons.flux = self.flux / photons.size()
@lru_cache(maxsize=8)
def _shoot_cdf(shoot_accuracy):
    """Tabulate the normalised radial CDF used for photon shooting.

    This routine produces a CPU-side cache of the CDF that is embedded
    into JIT-compiled code as needed.

    Parameters:
        shoot_accuracy: Photons are drawn between r = 0 and
                        rmax = -log(shoot_accuracy). Must be hashable, since
                        it is the ``lru_cache`` key.

    Returns:
        A tuple ``(u_cdf, cdf)`` of two NumPy arrays of length 10000:
        ``u_cdf`` is a uniform grid on [0, umax] and ``cdf`` is the CDF on
        that grid, normalised so that ``cdf[-1] == 1``. Both arrays are
        read-only because the same objects are returned to every caller
        that hits the cache.
    """
    # Comments on the math here:
    #
    # We are looking to draw from a distribution that is r * exp(-r).
    # This distribution is the radial PDF of an Exponential profile.
    # The factor of r comes from the area element r * dr.
    #
    # We can compute the CDF of this distribution analytically, but we cannot
    # invert the CDF in closed form. Thus we invert it numerically using a table.
    #
    # One final detail is that we want the inversion to be accurate and are using
    # linear interpolation. Thus we use a change of variables r = -ln(1 - u)
    # to make the CDF more linear and map its domain to [0, 1) instead of [0, inf).
    #
    # Putting this all together, we get
    #
    #   r * exp(-r) dr = -ln(1-u) (1-u) dr/du du
    #                  = -ln(1-u) (1-u) * 1 / (1-u) du
    #                  = -ln(1-u) du
    #
    # The new range of integration is u = 0 to u = 1. Thus the CDF is
    #
    #   CDF = -int_0^u ln(1-u') du'
    #       = u - (u - 1) ln(1 - u)
    #
    # The final detail is that galsim defines a shoot accuracy and draws photons
    # between r = 0 and rmax = -log(shoot_accuracy). Thus we normalize the CDF to
    # its value at umax = 1 - exp(-rmax) and then finally invert the CDF numerically.
    _rmax = -np.log(shoot_accuracy)
    _umax = 1.0 - np.exp(-_rmax)
    _u_cdf = np.linspace(0, _umax, 10000)
    _cdf = _u_cdf - (_u_cdf - 1) * np.log(1 - _u_cdf)
    _cdf /= _cdf[-1]

    # The lru_cache hands the very same array objects to every caller, so an
    # in-place write would silently corrupt the table for all later uses.
    # Freeze them so that such a write raises instead.
    _u_cdf.setflags(write=False)
    _cdf.setflags(write=False)
    return _u_cdf, _cdf