Source code for jax_galsim.spergel

import galsim as _galsim
import jax
import jax.numpy as jnp
from jax.tree_util import Partial as partial
from jax.tree_util import register_pytree_node_class

from jax_galsim.bessel import kv
from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.utils import bisect_for_root, ensure_hashable, implements
from jax_galsim.gsobject import GSObject
from jax_galsim.random import UniformDeviate


@jax.jit
def gamma(x):
    """Gamma(x)"""
    x = x * 1.0
    return jnp.exp(jax.lax.lgamma(x))


@jax.jit
def _gamma(nu):
    """Gamma(nu) with care for integer nu in [0,5]"""
    return jnp.select(
        [nu == 0, nu == 1, nu == 2, nu == 3, nu == 4, nu == 5],
        [jnp.inf, 1.0, 1.0, 2.0, 6.0, 24.0],
        default=gamma(nu),
    )


@jax.jit
def _gammap1(nu):
    """Gamma(nu+1)"""
    return _gamma(nu + 1.0)


@jax.jit
def z2lz(z):
    """return z^2 * log(z)"""
    return jnp.where(z <= 1e-40, 0.0, z * z * jnp.log(z))


@jax.jit
def f0(z):
    """K_0[z] with z -> 0  O(z^4)"""
    z2 = z * z
    z4 = z2 * z2
    c0 = 0.11593151565841244881
    c1 = 0.27898287891460311220
    c2 = 0.025248929932162694513
    return c0 + c1 * z2 + c2 * z4 - jnp.power(1.0 + 0.125 * z2, 2.0) * jnp.log(z)


@jax.jit
def f1(z):
    """z^1 K_1[z] with z -> 0  O(z^4)"""
    z2 = z * z
    z4 = z2 * z2
    c0 = z2lz(z)  # z^2 log(z)
    c1 = 0.30796575782920622441
    c2 = 0.08537071972865077805
    return 1.0 - c1 * z2 - c2 * z4 + c0 * (0.5 + 0.0625 * z2)


@jax.jit
def f2(z):
    """z^2 K_2[z] with z -> 0  O(z^4)"""
    c1 = 0.10824143945730155610
    z2 = z * z
    z4 = z2 * z2
    c0 = z2lz(z) * z2  # z^4*log(z)
    return 2.0 - 0.5 * z2 + c1 * z4 - 0.125 * c0


@jax.jit
def f3(z):
    """z^3 K_3[z] with z -> 0  O(z^4)"""
    z2 = z * z
    z4 = z2 * z2
    return 8.0 - z2 + 0.125 * z4


@jax.jit
def f4(z):
    """z^4 K_4[z] with z -> 0 O(z^4)"""
    z2 = z * z
    z4 = z2 * z2
    return 48.0 - 4 * z2 + 0.25 * z4


@jax.jit
def f5(z):
    """z^5 K_5[z] with z -> 0 O(z^4)"""
    z2 = z * z
    z4 = z2 * z2
    return 384.0 - 24.0 * z2 + z4


@jax.jit
def fsmallz_nu(z, nu):
    def fnu(z, nu):
        """z^nu K_nu[z] with z -> 0 O(z^4) z > 0"""
        nu += 1.0e-10  # to garanty that nu is not an integer
        z2 = z * z
        z4 = z2 * z2
        c1 = jnp.power(2.0, -6.0 - nu)
        c2 = _gamma(-2.0 - nu)
        c3 = _gamma(-2.0 + nu)
        c4 = jnp.power(z, 2.0 * nu)
        c5 = z4 * 8.0 * z2 * (2.0 + nu) + 32.0 * (1.0 + nu) * (2.0 + nu)
        c6 = z2 * (16.0 + z2 - 8.0 * nu) * c3
        return c1 * (c4 * c5 * c2 + jnp.power(4.0, nu) * (c6 + 32.0 * _gamma(nu)))

    return jnp.select(
        [nu == 0, nu == 1, nu == 2, nu == 3, nu == 4],
        [f0(z), f1(z), f2(z), f3(z), f4(z)],
        default=fnu(z, nu),
    )


@jax.jit
def fz_nu(z, nu):
    """z^nu K_nu[z] with z > 0"""
    return jnp.where(z <= 1.0e-10, fsmallz_nu(z, nu), jnp.power(z, nu) * kv(nu, z))


@jax.jit
def fsmallz_nup1(z, nu):
    def fnu(z, nu):
        """z^(nu+1) K_(nu+1)[z] with  z -> 0"""
        z2 = z * z
        z4 = z2 * z2
        c1 = -jnp.power(2.0, -4.0 - nu)
        c2 = _gamma(-2.0 - nu)
        c3 = c1 * c2 * (8.0 + 4.0 * nu + z2) * jnp.power(z, 2.0 * (1.0 + nu))
        c4 = jnp.power(2.0, nu)
        c5 = _gammap1(nu)
        c6 = c4 * c5 * (1.0 - 0.25 * z2 / nu + z4 * 0.03125 / (nu * (nu - 1.0)))
        return c3 + c6

    return jnp.select(
        [nu == 0, nu == 1, nu == 2, nu == 3, nu == 4],
        [f1(z), f2(z), f3(z), f4(z), f5(z)],
        default=fnu(z, nu),
    )


@jax.jit
def fz_nup1(z, nu):
    """z^(nu+1) K_{nu+1}(z)"""
    return jnp.where(
        z <= 1.0e-10, fsmallz_nup1(z, nu), jnp.power(z, nu + 1.0) * kv(nu + 1.0, z)
    )


@jax.jit
def fluxfractionFunc(z, nu, alpha):
    """1 - z^(nu+1) K_{nu+1}(z) / (2^nu Gamma(nu+1)) - alpha"""
    return 1.0 - fz_nup1(z, nu) / (jnp.power(2.0, nu) * _gammap1(nu)) - alpha


@jax.jit
def reducedfluxfractionFunc(z, nu, norm):
    """(1 - z^(nu+1) K_{nu+1}(z) / (2^nu Gamma(nu+1)))/norm"""
    return fluxfractionFunc(z, nu, alpha=0.0) / norm


@jax.jit
def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=40.0):
    """Return radius R enclosing flux fraction alpha in unit of the scale radius r0

    Method: Solve  F(R/r0=z)/Flux - alpha = 0 using bisection algorithm

    F(R)/F =  int( 1/(2^nu Gamma(nu+1)) (r/r0)^(nu+1) K_nu(r/r0) dr/r0; r=0..R) = alpha
    =>
    z=R/r0 such that
    1 - z^(nu+1) K_{nu+1}(z) / (2^nu Gamma(nu+1)) = alpha

    Typical use cases:
     o alpha = 1/2 => R = Half-Light-Radius,
     o alpha = 1 - folding-thresold => R used for stepk computation

     nu: the Spergel index

     nb. it is supposed that nu is in [-0.85, 4.0] checked in the Spergel class init
    """
    return bisect_for_root(
        partial(fluxfractionFunc, nu=nu, alpha=alpha),
        zmin,
        zmax,
        niter=75,
    )


def _spergel_hlr_pade(x):
    """A Pseudo-Pade approximation for the HLR of the Spergel profile as a function of nu.

    See dev/notebooks/spergel_hlr_flux_radius_approx.ipynb for code to generate this routine.
    """
    # fmt: off
    pm = 1.2571513771129166 + x * (
        3.7059053890269102 + x * (
            2.8577090425861944 + x * (
                -0.30570486567039273 + x * (
                    0.6589831675940833 + x * (
                        3.375577680133867 + x * (
                            2.8143565844741403 + x * (
                                0.9292378858457211 + x * (
                                    0.12096941981286179 + x * (
                                        0.004206502758293099
                                    )
                                )
                            )
                        )
                    )
                )
            )
        )
    )
    qm = 1.0 + x * (
        2.1939178810491837 + x * (
            0.8281034080784796 + x * (
                -0.5163329765186994 + x * (
                    0.9164871490929886 + x * (
                        1.8988551389326231 + x * (
                            1.042688817291684 + x * (
                                0.22580140592548198 + x * (
                                    0.01681923980317362 + x * (
                                        0.00018168506955933716
                                    )
                                )
                            )
                        )
                    )
                )
            )
        )
    )
    # fmt: on
    return pm / qm


LAX_SPERGEL_DESCRIPTION = r"""
The fully normalized Spergel profile (used in both standard GalSim and JAX-GalSim) is

.. math::
    I(r) = flux \times \left(2\pi 2^\nu \Gamma(1+\nu) r_0^2\right)^{-1} \times \left(\frac{r}{r_0}\right)^\nu K_\nu\left(\frac{r}{r_0}\right)

with the following Fourier expression

.. math::
    \hat{I}(k) = flux / (1 + (k r_0)^2)^{1+\nu}

where :math:`r_0` is the ``scale_radius``, and :math:`\nu` mandatory to be in [-0.85,4.0]

The JAX-GalSim implementation does not support autodiff with respect to :math:`\nu` for
real-space evaluations.
"""


[docs] @implements(_galsim.Spergel, lax_description=LAX_SPERGEL_DESCRIPTION) @register_pytree_node_class class Spergel(GSObject): _has_hard_edges = False _is_axisymmetric = True _is_analytic_x = True _is_analytic_k = True _minimum_nu = -0.85 _maximum_nu = 4.0 def __init__( self, nu, scale_radius=None, half_light_radius=None, flux=1.0, gsparams=None, ): # Parse the radius options if half_light_radius is not None: if scale_radius is not None: raise _galsim.GalSimIncompatibleValuesError( "Only one of scale_radius, half_light_radius may be specified", half_light_radius=half_light_radius, scale_radius=scale_radius, ) else: super().__init__( nu=nu, scale_radius=half_light_radius / _spergel_hlr_pade(nu), flux=flux, gsparams=gsparams, ) elif scale_radius is None: raise _galsim.GalSimIncompatibleValuesError( "One of scale_radius, half_light_radius must be specified", half_light_radius=half_light_radius, scale_radius=scale_radius, ) else: super().__init__( nu=nu, scale_radius=scale_radius, flux=flux, gsparams=gsparams, ) @property @implements(_galsim.spergel.Spergel.nu) def nu(self): return self._params["nu"] @property @implements(_galsim.spergel.Spergel.scale_radius) def scale_radius(self): return self.params["scale_radius"] @property def _r0(self): return self.scale_radius @property def _inv_r0(self): return 1.0 / self._r0 @property def _r0_sq(self): return self._r0 * self._r0 @property def _inv_r0_sq(self): return self._inv_r0 * self._inv_r0 @property @implements(_galsim.spergel.Spergel.half_light_radius) def half_light_radius(self): return self._r0 * _spergel_hlr_pade(self.nu) @property def _shootxnorm(self): """Normalization for photon shooting""" return 1.0 / (2.0 * jnp.pi * jnp.power(2.0, self.nu) * _gammap1(self.nu)) @property def _xnorm(self): """Normalization of xValue""" return self._shootxnorm * self.flux * self._inv_r0_sq @property def _xnorm0(self): """return z^nu K_nu(z) for z=0""" return jax.lax.select( self.nu > 0, _gamma(self.nu) * jnp.power(2.0, self.nu - 1.0), jnp.inf )
[docs] @implements(_galsim.spergel.Spergel.calculateFluxRadius) def calculateFluxRadius(self, f): return self._r0 * calculateFluxRadius(f, self.nu)
[docs] @implements(_galsim.spergel.Spergel.calculateIntegratedFlux) def calculateIntegratedFlux(self, r): return fluxfractionFunc(r / self._r0, self.nu, 0.0)
def __hash__(self): return hash( ( "galsim.Spergel", ensure_hashable(self.nu), ensure_hashable(self.scale_radius), ensure_hashable(self.flux), self.gsparams, ) ) def __repr__(self): return "galsim.Spergel(nu=%r, scale_radius=%r, flux=%r, gsparams=%r)" % ( ensure_hashable(self.nu), ensure_hashable(self.scale_radius), ensure_hashable(self.flux), self.gsparams, ) def __str__(self): s = "galsim.Spergel(nu=%s, half_light_radius=%s" % ( ensure_hashable(self.nu), ensure_hashable(self.half_light_radius), ) if self.flux != 1.0: s += ", flux=%s" % (ensure_hashable(self.flux),) s += ")" return s @property def _maxk(self): """(1+ (k r0)^2)^(-1-nu) = maxk_threshold""" res = jnp.power(self.gsparams.maxk_threshold, -1.0 / (1.0 + self.nu)) - 1.0 return jnp.sqrt(res) / self._r0 @property def _stepk(self): R = calculateFluxRadius(1.0 - self.gsparams.folding_threshold, self.nu) R *= self._r0 # Go to at least 5*hlr R = jnp.maximum(R, self.gsparams.stepk_minimum_hlr * self.half_light_radius) return jnp.pi / R @property def _max_sb(self): # from SBSpergelImpl.h return jnp.abs(self._xnorm) * self._xnorm0 @jax.jit def _xValue(self, pos): r = jnp.sqrt(pos.x**2 + pos.y**2) * self._inv_r0 res = jnp.where(r == 0, self._xnorm0, fz_nu(r, jax.lax.stop_gradient(self.nu))) return self._xnorm * res @jax.jit def _kValue(self, kpos): ksq = (kpos.x**2 + kpos.y**2) * self._r0_sq return self.flux * jnp.power(1.0 + ksq, -1.0 - self.nu) def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0): _jac = jnp.eye(2) if jac is None else jac return draw_by_xValue(self, image, _jac, jnp.asarray(offset), flux_scaling) def _drawKImage(self, image, jac=None): _jac = jnp.eye(2) if jac is None else jac return draw_by_kValue(self, image, _jac)
[docs] @implements(_galsim.Spergel.withFlux) def withFlux(self, flux): return Spergel( nu=self.nu, scale_radius=self.scale_radius, flux=flux, gsparams=self.gsparams, )
@property def _shoot_pos_cdf(self): zmax = calculateFluxRadius( 1.0 - self.gsparams.shoot_accuracy, self.nu, zmax=30.0 ) flux_max = fluxfractionFunc(zmax, self.nu, alpha=0.0) preducedfluxfractionFunc = partial( reducedfluxfractionFunc, nu=self.nu, norm=flux_max ) z_cdf = jnp.linspace(0, zmax, 10_000) cdf = preducedfluxfractionFunc(z_cdf) return z_cdf, cdf def _shoot_pos(self, u): # shoot r in case of nu>0 z_cdf, cdf = self._shoot_pos_cdf z = jnp.interp(u, cdf, z_cdf) # linear inversion of the CDF r = z * self._r0 return r @property def _shoot_neg_cdf(self): # comment: # In the Galsim code the profile below rmin is linearized such that # call zmin = rmin/r0 such that # Int_0^zmin 2pi u x I(u) du = shoot_accuracy # Then let (a,b) such that # 1) Int_0^zmin 2pi u x (a + b u) du = shoot_accuracy # 2) a + b zmin = zmin^nu K_nu(zmin) # Now, noticing that # I(z) = z^nu K_nu(z) / (2pi 2^nu Gamma(nu+1)) = z^nu K_nu(z)/(2 pi Nnu) # there is a problem with eq. 1 as we would have expected # 1b) Int_0^zmin 2pi u x (a + b u)/(2 pi Nnu) du = shoot_accuracy # so the corrFact is there to signal the changement in this implementation zmax = calculateFluxRadius( 1.0 - self.gsparams.shoot_accuracy, self.nu, zmax=30.0 ) flux_target = self.gsparams.shoot_accuracy shoot_rmin = calculateFluxRadius(flux_target, self.nu) knur = fz_nu(shoot_rmin, self.nu) corrFact = self._shootxnorm # this is the correct normalisation b = knur - flux_target / (jnp.pi * shoot_rmin * shoot_rmin * corrFact) b = 3.0 * b / shoot_rmin a = knur - shoot_rmin * b def cumulflux(z, a, b, zmin, nu, norm=1.0): flux_min = a / 3.0 * zmin * zmin * zmin + b / 2.0 * zmin * zmin c1 = fz_nup1(zmin, nu) res = jnp.where( z <= zmin, a / 3.0 * z * z * z + b / 2.0 * z * z, flux_min + c1 - fz_nup1(z, nu), ) return res / norm flux_max = cumulflux(zmax, a, b, shoot_rmin, self.nu) preducedfluxfractionFunc = partial( cumulflux, a=a, b=b, zmin=shoot_rmin, nu=self.nu, norm=flux_max ) z_cdf = jnp.linspace(0, zmax, 10_000) cdf = preducedfluxfractionFunc(z_cdf) return z_cdf, cdf def _shoot_neg(self, u): # shoot r in case of nu<=0 z_cdf, cdf = self._shoot_neg_cdf z = jnp.interp(u, cdf, z_cdf) # linear inversion of the CDF r = z * self._r0 return r @implements(_galsim.Spergel._shoot) def _shoot(self, photons, rng): ud = UniformDeviate(rng) u = ud.generate(photons.x) r = jax.lax.select(self.nu > 0, self._shoot_pos(u), self._shoot_neg(u)) ang = ud.generate(photons.x) * 2.0 * jnp.pi photons.x = r * jnp.cos(ang) photons.y = r * jnp.sin(ang) photons.flux = self.flux / photons.size()