import re
import textwrap
from functools import partial
from typing import NamedTuple
import equinox
import jax
import jax.numpy as jnp
import numpy as np
# Scalar types whose values can be inspected directly in python (rather than via
# JAX operations), so checks and conversions on them can be done eagerly.
STATIC_SCALAR_TYPES = (
    int,
    float,
    np.integer,
    np.floating,
)
[docs]
def check_is_int_then_cast(val, msg):
    """Check if `val` is an integer, raise if not, otherwise cast to int.

    Parameters:
        val:    The value to check. It is first cast to float. Static python/NumPy
                scalars are checked immediately in python; anything else (e.g.
                arrays) is converted with ``jnp.array`` and checked through
                ``equinox.error_if``, so the error surfaces when the computation runs.
        msg:    Error message used when `val` is not a finite integral value.

    Returns:
        A python ``int`` for scalar inputs, otherwise an array cast to ``int`` dtype.

    Raises:
        TypeError: for scalar inputs that are not finite integral values
            (including NaN and +/-inf).
    """
    val = cast_to_float(val)
    if isinstance(val, STATIC_SCALAR_TYPES):
        # For simple inputs we can check directly in python. `is_integer` is False
        # for NaN and +/-inf, which `int()` would otherwise reject with
        # ValueError / OverflowError instead of the documented TypeError.
        if not float(val).is_integer():
            raise TypeError(msg)
        val = int(val)
    else:
        # otherwise we use more opaque checking upon jit via equinox
        val = jnp.array(val)
        # non-finite entries are rejected explicitly: trunc(inf) == inf, so the
        # integrality test alone would let +/-inf through to the int cast
        val = equinox.error_if(
            val,
            jnp.any(~jnp.isfinite(val) | (val != jnp.trunc(val))),
            msg,
        )
        val = val.astype(int)
    return val
[docs]
def cast_numpy_array_to_native_byte_order(arr):
    """Return ``arr`` with a dtype in native byte order.

    Inputs that are not NumPy arrays, and arrays whose dtype is already native,
    are returned as the very same object; otherwise a converted copy is returned.
    """
    needs_swap = isinstance(arr, np.ndarray) and not arr.dtype.isnative
    if needs_swap:
        native_dtype = arr.dtype.newbyteorder("=")
        return arr.astype(native_dtype)
    return arr
def _cast_to_type(x, typ, accept_strings=False):
    """Convert ``x`` to ``typ``.

    Static python/NumPy scalars (and strings, when ``accept_strings`` is True)
    go through the python constructor ``typ(x)``; every other input goes through
    ``jnp.astype`` and therefore comes back as a JAX array.
    """
    is_static_scalar = isinstance(x, STATIC_SCALAR_TYPES)
    if not is_static_scalar and not (accept_strings and isinstance(x, str)):
        return jnp.astype(x, typ)
    return typ(x)
[docs]
def cast_to_float(x, accept_strings=False):
    """Cast the input to a float. Works on python floats/ints, numpy scalars, and jax/numpy arrays.

    Parameters:
        accept_strings: If True, allow string to ``float`` conversion. [default: False]

    Returns:
        Input value ``x`` casted to a ``float``: a python ``float`` for static
        scalars (and accepted strings), a float-dtype JAX array otherwise.
    """
    # The python builtin `float` is used as the target instead of a concrete dtype
    # so that JAX promotes to the highest precision available without warning.
    target_type = float
    return _cast_to_type(x, target_type, accept_strings=accept_strings)
[docs]
def cast_to_int(x, accept_strings=False):
    """Cast the input to an int. Works on python floats/ints, numpy scalars, and jax/numpy arrays.

    Parameters:
        accept_strings: If True, allow string to ``int`` conversion. [default: False]

    Returns:
        Input value ``x`` casted to an ``int``: a python ``int`` for static
        scalars (and accepted strings), an int-dtype JAX array otherwise.
    """
    target_type = int
    return _cast_to_type(x, target_type, accept_strings=accept_strings)
[docs]
def _is_nd_array(v):
    """Return True if ``v`` is a JAX or NumPy array with at least one dimension."""
    return isinstance(v, (jax.Array, np.ndarray)) and v.ndim > 0


def is_equal_with_arrays(x, y, no_jax=False):
    """Return True if the data is equal, False otherwise. Handles jax.Array types.

    Lists, tuples, sets and dicts are compared recursively (a list only equals a
    list, a tuple only a tuple, and so on). Arrays with one or more dimensions
    (JAX or NumPy) are equal only to another such array of the same shape with
    equal elements; checking both operands keeps the comparison symmetric. A 0-d
    JAX array on either side is compared with ``array_equal`` so that array
    scalars and python scalars compare by value. Everything else uses ``==``.

    Parameters:
        x:       First object to compare.
        y:       Second object to compare.
        no_jax:  If True, build the result with NumPy (``np.array`` /
                 ``np.array_equal``) instead of JAX. [default: False]

    Returns:
        A scalar boolean array (JAX, or NumPy when ``no_jax`` is True).
    """
    if no_jax:
        arr_func = np.array
        arr_eq_func = np.array_equal
    else:
        arr_func = jnp.array
        arr_eq_func = jnp.array_equal
    currval = arr_func(True)
    if isinstance(x, list):
        if isinstance(y, list) and len(x) == len(y):
            for vx, vy in zip(x, y):
                currval &= is_equal_with_arrays(vx, vy, no_jax=no_jax)
        else:
            currval &= arr_func(False)
    elif isinstance(x, tuple):
        if isinstance(y, tuple) and len(x) == len(y):
            for vx, vy in zip(x, y):
                currval &= is_equal_with_arrays(vx, vy, no_jax=no_jax)
        else:
            currval &= arr_func(False)
    elif isinstance(x, set):
        # sets are unordered, so their elements are paired up after sorting
        if isinstance(y, set) and len(x) == len(y):
            for vx, vy in zip(sorted(x), sorted(y)):
                currval &= is_equal_with_arrays(vx, vy, no_jax=no_jax)
        else:
            currval &= arr_func(False)
    elif isinstance(x, dict):
        if isinstance(y, dict) and len(x) == len(y):
            for kx, vx in x.items():
                if kx not in y:
                    currval &= arr_func(False)
                else:
                    currval &= is_equal_with_arrays(vx, y[kx], no_jax=no_jax)
        else:
            currval &= arr_func(False)
    elif _is_nd_array(x) or _is_nd_array(y):
        # An n-d array never equals a scalar or an array of a different shape.
        # Testing both operands keeps the result symmetric and prevents an
        # elementwise ``x == y`` from producing a non-scalar result.
        if _is_nd_array(x) and _is_nd_array(y) and x.shape == y.shape:
            currval &= arr_eq_func(x, y)
        else:
            currval &= arr_func(False)
    elif isinstance(x, jax.Array) or isinstance(y, jax.Array):
        # at least one side is a 0-d jax array (n-d arrays were handled above);
        # this covers comparing an array scalar to a python scalar or vice versa
        currval &= arr_eq_func(x, y)
    else:
        currval &= arr_func(x == y)
    return currval
def _convert_to_numpy_nan(x):
"""Convert input to numpy.nan if it is a NaN, otherwise return it unchanged
so that we get consistent hashing."""
try:
if np.isnan(x):
return np.nan
else:
return x
except Exception:
return x
def _recurse_list_to_tuple(x):
    """Turn (possibly nested) lists into tuples; non-list leaves go through ``_convert_to_numpy_nan``."""
    if not isinstance(x, list):
        return _convert_to_numpy_nan(x)
    return tuple(map(_recurse_list_to_tuple, x))
[docs]
def ensure_hashable(v):
    """Return a hashable stand-in for ``v``.

    A JAX array becomes a python scalar (0-d arrays) or a possibly nested tuple
    (arrays with one or more dimensions). If that conversion raises, or ``v`` is
    not a JAX array, ``v`` is returned as-is. In all cases NaNs are replaced by
    ``numpy.nan`` so that hashing is consistent.
    """
    if not isinstance(v, jax.Array):
        return _convert_to_numpy_nan(v)
    try:
        if v.ndim == 0:
            return _convert_to_numpy_nan(v.item())
        return _recurse_list_to_tuple(v.tolist())
    except Exception:
        return _convert_to_numpy_nan(v)
[docs]
@partial(jax.jit, static_argnames=("niter",))
def bisect_for_root(func, low, high, niter=75):
    """Find a root of ``func`` between ``low`` and ``high`` by bisection.

    ``func(low)`` and ``func(high)`` are expected to have opposite signs (or one
    of them to be exactly zero). Each iteration halves the interval and keeps the
    half whose end points still bracket the root. A midpoint where ``func`` is
    exactly zero is therefore kept as the upper end point instead of being lost.

    Parameters:
        func:   Function whose root is sought. It is passed to ``jax.jit`` as a
                non-static argument, so it must be a JAX pytree (for example
                ``jax.tree_util.Partial``) rather than a plain python function.
        low:    Lower end of the bracketing interval.
        high:   Upper end of the bracketing interval.
        niter:  Number of bisection iterations (static). [default: 75]

    Returns:
        The upper end of the final bracketing interval. Selection is done with
        ``jnp.where``, so array-valued ``low``/``high`` are bisected elementwise
        when ``func`` acts elementwise.
    """

    def _bisect_step(i, carry):
        low, flow, high, fhigh = carry
        mid = (low + high) / 2.0
        fmid = func(mid)
        # The root lies in [mid, high] iff f(mid) has the same (nonzero) sign as
        # f(low). Comparing signs against f(low) keeps an exact zero at `mid` as
        # the new upper end, and sign products cannot underflow like fmid * flow.
        move_low = jnp.sign(fmid) * jnp.sign(flow) > 0
        return (
            jnp.where(move_low, mid, low),
            jnp.where(move_low, fmid, flow),
            jnp.where(move_low, high, mid),
            jnp.where(move_low, fhigh, fmid),
        )

    flow = func(low)
    fhigh = func(high)
    _, _, high, _ = jax.lax.fori_loop(
        0, niter, _bisect_step, (low, flow, high, fhigh), unroll=15
    )
    return high
# start of code from https://github.com/google/jax/blob/main/jax/_src/numpy/util.py #
# used with modifications for galsim under the following license:
# fmt: off
#
# Copyright 2020 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# fmt: on
_docreference = re.compile(r":doc:`(.*?)\s*<.*?>`")
[docs]
class ParsedDoc(NamedTuple):
    """Pieces of a parsed docstring.

    docstr: full docstring (may be None when built from an empty docstring).
    signature: signature from docstring (``_parse_galsimdoc`` always leaves it empty).
    summary: summary from docstring.
    front_matter: front matter before sections.
    sections: dictionary of section titles to section content (``_parse_galsimdoc``
        currently always leaves it empty).
    """
    docstr: str = ""
    signature: str = ""
    summary: str = ""
    front_matter: str = ""
    # NOTE(review): this default dict is one object shared by every instance that
    # is created without an explicit ``sections`` argument -- never mutate it in place.
    sections: dict[str, str] = {}
def _break_off_body_section_by_newline(body, double_check_first_indent=False):
first_lines = []
body_lines = []
found_first_break = False
for line in body.split("\n"):
if not first_lines:
first_lines.append(line)
continue
if not line.strip() and not found_first_break:
found_first_break = True
continue
if found_first_break:
body_lines.append(line)
else:
first_lines.append(line)
if double_check_first_indent and len(first_lines) > 1:
len_first_indent = len(first_lines[1]) - len(first_lines[1].lstrip())
if len_first_indent > 0:
first_indent = first_lines[1][:len_first_indent]
first_lines[0] = first_indent + first_lines[0].lstrip()
firstline = "\n".join(first_lines)
firstline = textwrap.dedent(firstline)
body = "\n".join(body_lines)
body = textwrap.dedent(body.lstrip("\n"))
return firstline, body
def _parse_galsimdoc(docstr):
    """Split a standard galsim-style docstring into its parts.

    Args:
        docstr: the raw docstring from a function

    Returns:
        ParsedDoc: ``summary`` is the first paragraph of the docstring (or the next
        paragraph when the first one is empty) and ``front_matter`` is the text
        after it. ``signature`` is always empty and ``sections`` is always an
        empty dict. Empty or None docstrings yield ``ParsedDoc(docstr)``.
    """
    if docstr is None or not docstr.strip():
        return ParsedDoc(docstr)

    # Remove any :doc: directives (keeping only their text) to avoid sphinx errors
    docstr = _docreference.sub(lambda m: m.group(1), docstr)

    summary, remainder = _break_off_body_section_by_newline(
        docstr, double_check_first_indent=True
    )
    if not summary:
        # the first paragraph was empty, so the next one serves as the summary
        summary, remainder = _break_off_body_section_by_newline(remainder)

    lines = remainder.split("\n")
    params_start = next(
        (
            idx
            for idx, line in enumerate(lines)
            if line.lstrip().startswith("Parameters:")
        ),
        len(lines),
    )
    before_params = "\n".join(lines[:params_start])
    params_section = "\n".join(lines[params_start:])
    # The parameter section is split off so it can be parsed on its own in the
    # future; for now it is simply appended back onto the front matter.
    front_matter = f"{before_params}\n{params_section}"

    return ParsedDoc(
        docstr=docstr,
        signature="",
        summary=summary,
        front_matter=front_matter,
        sections={},
    )
[docs]
def implements(
    original_fun,
    lax_description="",
    module=None,
):
    """Decorator for JAX functions which implement a specified GalSim function.

    This mainly contains logic to copy and modify the docstring of the original
    function. The new docstring consists of the summary of the original
    docstring, a line stating that this is the LAX-backend implementation, the
    optional ``lax_description``, and then the rest of the original docstring.
    If the original has no docstring, the LAX-backend line and
    ``lax_description`` (when given) are still used. The ``__name__`` and
    ``__qualname__`` of the original are copied onto the decorated object, and
    the original is stored as ``__galsim_wrapped__``.

    Parameters:
        original_fun:       The original function being implemented. May be None,
                            in which case only ``lax_description`` (if given) is
                            used as the docstring.
        lax_description:    A string description of the JAX implementation that is
                            inserted after the summary of the original docstring.
        module:             An optional string specifying the module from which the
                            original function is imported. This is useful for
                            objects, where the module cannot be determined from the
                            original function itself.
    """

    def decorator(wrapped_fun):
        wrapped_fun.__galsim_wrapped__ = original_fun
        # Allows this pattern: @implements(getattr(np, 'new_function', None))
        if original_fun is None:
            if lax_description:
                wrapped_fun.__doc__ = lax_description
            return wrapped_fun
        docstr = getattr(original_fun, "__doc__", None)
        name = getattr(
            original_fun, "__name__", getattr(wrapped_fun, "__name__", str(wrapped_fun))
        )
        try:
            mod = module or original_fun.__module__
        except AttributeError:
            pass
        else:
            name = f"{mod}.{name}"
        if docstr:
            try:
                parsed = _parse_galsimdoc(docstr)
                docstr = parsed.summary.strip() + "\n" if parsed.summary else ""
                docstr += f"\nLAX-backend implementation of :func:`{name}`.\n"
                if lax_description:
                    docstr += "\n" + lax_description.strip() + "\n"
                if parsed.front_matter:
                    docstr += "\n*Original docstring below.*\n"
                    docstr += "\n" + parsed.front_matter.strip() + "\n"
            except Exception:
                # best effort: fall back to the unmodified original docstring
                docstr = original_fun.__doc__
        elif lax_description:
            # The original has no docstring; keep the JAX-specific description
            # rather than silently discarding it.
            docstr = f"\nLAX-backend implementation of :func:`{name}`.\n"
            docstr += "\n" + lax_description.strip() + "\n"
        wrapped_fun.__doc__ = docstr
        for attr in ["__name__", "__qualname__"]:
            try:
                value = getattr(original_fun, attr)
            except AttributeError:
                pass
            else:
                setattr(wrapped_fun, attr, value)
        return wrapped_fun

    return decorator
# end of code from https://github.com/google/jax/blob/main/jax/_src/numpy/util.py #