# Source code for jax_galsim.core.math
import jax
import jax.numpy as jnp
# [docs]
@jax.jit
def safe_sqrt(x):
    """Square root that is zero, with zero gradient, for non-positive input.

    Elementwise, the result is ``sqrt(x)`` wherever ``x > 0`` and ``0.0``
    everywhere else: zero, negative values, and NaN (``NaN > 0`` is False).
    Unlike a bare ``jnp.sqrt``, differentiating through this function at
    ``x == 0`` yields a derivative of zero instead of inf/NaN.
    """
    is_positive = x > 0
    # "Double where": hand sqrt a dummy 1.0 wherever the input is not
    # positive. The discarded branch then has a finite gradient, so it cannot
    # poison the backward pass with inf * 0 = NaN.
    sqrt_input = jnp.where(is_positive, x, 1.0)
    roots = jnp.sqrt(sqrt_input)
    return jnp.where(is_positive, roots, 0.0)