import functools
import jax
import jax.numpy as jnp
import numpy as np
[docs]
def akima_interp_coeffs(x, y, use_jax=True):
    """Return the per-interval polynomial coefficients of an Akima cubic spline.

    The Akima spline is a piecewise cubic through the points ``(x, y)`` whose
    first derivative is continuous (C1) but whose second derivative generally
    is not. The slope at each data point is derived from a local weighting of
    the neighbouring secant slopes, so, unlike a classic cubic spline, no
    linear system of equations has to be solved.

    References: https://en.wikipedia.org/wiki/Akima_spline ; H. Akima (1970),
    "A new method of interpolation and smooth curve fitting based on local
    procedures", Journal of the ACM 17: 589-602.

    Parameters:
        x: The x-coordinates of the data points, in strictly increasing order
            (sorted, with no duplicates).
        y: The y-coordinates of the data points.
        use_jax: If true, compute with JAX. If false, compute with NumPy on the
            host instead, which is handy for pre-computing the coefficients
            from inside JIT-compiled JAX code. [default: True]

    Returns:
        A tuple ``(a, b, c, d)`` of arrays, each of shape ``(N-1,)``, such that
        on ``[x[i], x[i+1]]`` the spline equals
        ``a[i] + b[i]*t + c[i]*t**2 + d[i]*t**3`` with ``t = x - x[i]``.
    """
    backend = _akima_interp_coeffs_jax if use_jax else _akima_interp_coeffs_nojax
    return backend(x, y)
def _akima_interp_coeffs_nojax(x, y):
dx = x[1:] - x[:-1]
mi = (y[1:] - y[:-1]) / dx
# these values are imposed for points
# at the ends
s0 = mi[0:1]
s1 = (mi[0:1] + mi[1:2]) / 2.0
snm2 = (mi[-3:-2] + mi[-2:-1]) / 2.0
snm1 = mi[-2:-1]
wim1 = np.abs(mi[3:] - mi[2:-1])
wi = np.abs(mi[1:-2] - mi[0:-3])
denom = wim1 + wi
numer = wim1 * mi[1:-2] + wi * mi[2:-1]
msk_denom = np.abs(denom) >= 1e-12
smid = np.zeros_like(denom)
smid[msk_denom] = numer[msk_denom] / denom[msk_denom]
smid[~msk_denom] = (mi[1:-2][~msk_denom] + mi[2:-1][~msk_denom]) / 2.0
s = np.concatenate([s0, s1, smid, snm2, snm1])
# these coeffs are for
# P(x) = a + b * (x-xi) + c * (x-xi)**2 + d * (x-xi)**3
# for a point x that falls in [xi, xip1]
a = y[:-1]
b = s[:-1]
c = (3 * mi - 2 * s[:-1] - s[1:]) / dx
d = (s[:-1] + s[1:] - 2 * mi) / dx / dx
return (a, b, c, d)
@jax.jit
def _akima_interp_coeffs_jax(x, y):
dx = x[1:] - x[:-1]
mi = (y[1:] - y[:-1]) / dx
# these values are imposed for points
# at the ends
s0 = mi[0:1]
s1 = (mi[0:1] + mi[1:2]) / 2.0
snm2 = (mi[-3:-2] + mi[-2:-1]) / 2.0
snm1 = mi[-2:-1]
wim1 = jnp.abs(mi[3:] - mi[2:-1])
wi = jnp.abs(mi[1:-2] - mi[0:-3])
denom = wim1 + wi
numer = wim1 * mi[1:-2] + wi * mi[2:-1]
smid = jnp.where(
jnp.abs(denom) >= 1e-12,
numer / denom,
(mi[1:-2] + mi[2:-1]) / 2.0,
)
s = jnp.concatenate([s0, s1, smid, snm2, snm1])
# these coeffs are for
# P(x) = a + b * (x-xi) + c * (x-xi)**2 + d * (x-xi)**3
# for a point x that falls in [xi, xip1]
a = y[:-1]
b = s[:-1]
c = (3 * mi - 2 * s[:-1] - s[1:]) / dx
d = (s[:-1] + s[1:] - 2 * mi) / dx / dx
return (a, b, c, d)
[docs]
@functools.partial(jax.jit, static_argnames=("fixed_spacing",))
def akima_interp(x, xp, yp, coeffs, fixed_spacing=False):
    """Compute the values of an Akima cubic spline at a set of points given the
    interpolation coefficients.

    Parameters:
        x: The x-coordinates of the points where the interpolation is computed.
            May be a scalar, a sequence, or an array.
        xp: The x-coordinates of the data points. These must be sorted into
            increasing order and cannot contain any duplicates.
        yp: The y-coordinates of the data points. Not used currently; the
            data values are carried by ``coeffs``.
        coeffs: The interpolation coefficients returned by ``akima_interp_coeffs``.
        fixed_spacing: Whether the data points are evenly spaced. If True, a faster
            technique is used to find the index ``i`` such that
            ``xp[i] <= x < xp[i+1]``. [default: False]

    Returns:
        The values of the Akima cubic spline at the points ``x``, with the same
        shape as ``x``. Points outside the closed interval ``[xp[0], xp[-1]]``
        evaluate to 0 (no extrapolation is performed).
    """
    # accept scalars and Python sequences as well as arrays
    x = jnp.asarray(x)
    xp = jnp.asarray(xp)
    if fixed_spacing:
        # evenly spaced knots: the interval index follows from the offset alone
        dxp = xp[1] - xp[0]
        i = jnp.floor((x - xp[0]) / dxp).astype(jnp.int32)
        i = jnp.clip(i, 0, len(xp) - 2)
    else:
        # searchsorted(side="right") gives k with xp[k-1] <= x < xp[k]; the clip
        # keeps x == xp[-1] (and out-of-range x) on a valid index in [0, N-2]
        i = jnp.clip(jnp.searchsorted(xp, x, side="right"), 1, len(xp) - 1) - 1

    # these coeffs are for
    # P(x) = a + b * (x-xi) + c * (x-xi)**2 + d * (x-xi)**3
    # for a point x that falls in [xi, xip1]
    a, b, c, d = (jnp.asarray(coef) for coef in coeffs)
    dx = x - xp[i]
    dx2 = dx * dx
    dx3 = dx2 * dx
    xval = a[i] + b[i] * dx + c[i] * dx2 + d[i] * dx3
    # outside the data range the spline is defined to be zero
    xval = jnp.where(x < xp[0], 0, xval)
    xval = jnp.where(x > xp[-1], 0, xval)
    return xval