# Source code for jax_galsim.core.integrate

from functools import partial
from typing import NamedTuple, Tuple

import jax.numpy as jnp
from jax import Array, jit


def abs_weights(n: int):
    """Return the ``n`` Clenshaw-Curtis nodes and weights on ``[-1, 1]``.

    The nodes are ``-cos(pi * j / (n - 1))`` for ``j = 0, ..., n - 1``, i.e. the
    Chebyshev extrema in increasing order.  The weights are obtained from an
    inverse Hermitian FFT of a short coefficient vector (this matches the
    FFT-based construction of Waldvogel, 2006); for example ``n = 3`` gives
    ``[1/3, 4/3, 1/3]`` and ``n = 5`` gives ``[1, 8, 12, 8, 1] / 15``.

    Parameters
    ----------
    n : int
        Number of nodes, at least 2.  It sizes arrays and drives Python-level
        branching, so it must be a concrete int, not a traced value.

    Returns
    -------
    points, weights : Array
        Both of shape ``(n,)``.

    Raises
    ------
    ValueError
        If ``n < 2``.
    """
    # An explicit check (rather than ``assert``) survives ``python -O``; without
    # it, n == 1 fails later with an unrelated IndexError on ``odd[-1]``.
    if n < 2:
        raise ValueError(f"abs_weights requires n >= 2, got {n}")

    points = -jnp.cos(jnp.linspace(0, jnp.pi, n))
    if n == 2:
        # Only the two end points: the trapezoidal rule.
        return points, jnp.array([1.0, 1.0])

    # The weight construction is expressed in terms of the number of intervals
    # between consecutive nodes.
    n_int = n - 1
    odd = jnp.arange(1, n_int, 2)
    n_odd = len(odd)
    n_rest = n_int - n_odd
    v0 = jnp.concatenate(
        [2.0 / odd / (odd - 2), jnp.array([1.0 / odd[-1]]), jnp.zeros(n_rest)]
    )
    v2 = -v0[:-1] - v0[:0:-1]
    g0 = -jnp.ones(n_int)
    g0 = g0.at[n_odd].add(n_int)
    g0 = g0.at[n_rest].add(n_int)
    g = g0 / (n_int**2 - 1 + (n_int % 2))

    # ihfft yields the first n_int // 2 + 1 weights; the imaginary part is only
    # round-off.  The remaining weights follow from the rule's symmetry.
    w = jnp.fft.ihfft(v2 + g).real
    if n_int % 2 == 1:
        weights = jnp.concatenate([w, w[::-1]])
    else:
        # Even interval count: the middle weight appears once, so mirror
        # everything except the last computed entry.
        weights = jnp.concatenate([w, w[-2::-1]])
    return points, weights
class ClenshawCurtisQuad(NamedTuple):
    """Clenshaw-Curtis quadrature rule stored as nodes plus weights.

    Instances built with :meth:`init` hold nodes and weights mapped to
    ``[0, 1]``; the plain tuple constructor stores whatever it is given.
    """

    # Total number of nodes (``2 * order + 1`` for the ``order`` passed to init).
    order: int
    # Node positions (abscissas); mapped to [0, 1] by init.
    absc: Array
    # Weights matching ``absc``; mapped to [0, 1] by init.
    absw: Array
    # Full-rule weights minus the embedded coarse rule's weights on every other
    # node.  NOTE(review): init does not rescale these to [0, 1] the way absw
    # is rescaled; confirm that is what consumers of errw expect.
    errw: Array

    @classmethod
    def init(cls, order: int):
        """Build a rule with ``2 * order + 1`` nodes on ``[0, 1]``.

        ``order`` must be at least 1: the embedded coarse rule then has
        ``order + 1 >= 2`` nodes, the minimum ``abs_weights`` supports.  The
        returned instance's ``order`` field is the node count, not the argument.
        """
        n_nodes = 2 * order + 1
        nodes, weights, err_weights = cls.compute_weights(n_nodes)
        nodes, weights = cls.rescale_weights(nodes, weights)
        return cls(order=n_nodes, absc=nodes, absw=weights, errw=err_weights)

    @staticmethod
    def compute_weights(order: int):
        """Return ``(nodes, weights, err_weights)`` of the ``order``-point rule on ``[-1, 1]``.

        ``err_weights`` equals ``weights`` with the ``(order + 1) // 2``-point
        rule's weights subtracted at the even-indexed nodes.  When ``order`` is
        odd (as init guarantees), those even-indexed nodes are exactly the
        coarse rule's nodes.
        """
        nodes, fine = abs_weights(order)
        _, coarse = abs_weights((order + 1) // 2)
        return nodes, fine, fine.at[::2].add(-coarse)

    @staticmethod
    def rescale_weights(
        absc: Array,
        absw: Array,
        *,
        interval_in: Tuple[float, float] = (-1, 1),
        interval_out: Tuple[float, float] = (0, 1),
    ):
        """Affinely map a rule from ``interval_in`` onto ``interval_out``.

        Nodes go linearly from ``interval_in`` to ``interval_out`` (end points
        to end points).  Weights are multiplied by the ratio of the interval
        lengths, out / in.
        """
        lo_in, hi_in = interval_in
        lo_out, hi_out = interval_out
        span_in = hi_in - lo_in
        span_out = hi_out - lo_out
        mapped = ((absc - lo_in) * hi_out - (absc - hi_in) * lo_out) / span_in
        scaled = span_out / span_in * absw
        return mapped, scaled
@partial(jit, static_argnums=(0,))
def quad_integral(f, a, b, quad: ClenshawCurtisQuad):
    """Integrate ``f`` over ``[a, b]`` with a fixed-node quadrature rule.

    Each reference node ``quad.absc[i]`` is placed at
    ``a + quad.absc[i] * (b - a)``, and the result is
    ``(b - a) * sum_i quad.absw[i] * f(node_i)``.  This is the integral over
    ``[a, b]`` when the rule is defined on ``[0, 1]``, which is what
    ``ClenshawCurtisQuad.init`` produces.

    Args:
        f: Integrand, called once with the array of all nodes.  It is a static
            argument of ``jit`` (``static_argnums=(0,)``), so it must be
            hashable, and passing a different callable causes a re-trace.
        a, b: Lower and upper limits, scalars or arrays.  After
            ``jnp.atleast_1d``, their leading axis acts as a batch of intervals.
        quad: Rule supplying ``absc`` (nodes) and ``absw`` (weights).

    Returns:
        The integral(s).  Every size-1 axis is squeezed away, so scalar limits
        give a scalar result and ``f`` then receives nodes of shape
        ``(n_nodes,)``.
    """
    lower = jnp.atleast_1d(a)
    upper = jnp.atleast_1d(b)
    width = upper - lower

    # Node i of interval k sits at lower[k] + absc[i] * width[k]; axis 0 of the
    # node array runs over quadrature nodes.
    offsets = jnp.einsum("i...,k...->ik...", quad.absc, width)
    # squeeze() drops *all* size-1 axes, including a length-1 batch axis.
    nodes = (lower[None, :] + offsets).squeeze()
    values = f(nodes)

    # Contract the node axis against the weights, then scale by interval width.
    integral = width * jnp.einsum("i...,i...->...", quad.absw, values)
    return integral.squeeze()