import copy
import math
from functools import partial
import galsim as _galsim
import jax
import jax.numpy as jnp
import jax.random as jrng
from galsim.errors import (
GalSimIncompatibleValuesError,
GalSimRangeError,
GalSimUndefinedBoundsError,
GalSimValueError,
)
from galsim.utilities import doc_inherit
from jax.tree_util import register_pytree_node_class
from jax_galsim import fits
from jax_galsim.bounds import BoundsI
from jax_galsim.core.utils import (
ensure_hashable,
implements,
)
from jax_galsim.gsobject import GSObject
from jax_galsim.gsparams import GSParams
from jax_galsim.image import Image
from jax_galsim.interpolant import Quintic
from jax_galsim.photon_array import PhotonArray
from jax_galsim.position import PositionD
from jax_galsim.transform import Transformation
from jax_galsim.utilities import convert_interpolant
from jax_galsim.wcs import BaseWCS, PixelScale
# Attribute names hidden from the public API of InterpolatedImage so that
# it matches the galsim class. The hiding is done by the DirMeta metaclass
# together with the __getattribute__ and __dir__ overrides on the class.
_KEYS_TO_REMOVE = ["flux_ratio", "jac", "offset", "original"]
# Technique taken from
# https://stackoverflow.com/questions/46120462/how-to-override-the-dir-method-for-a-class
class DirMeta(type):
    """Metaclass whose ``dir(cls)`` omits the names in ``_KEYS_TO_REMOVE``."""

    def __dir__(cls):
        # Names defined directly on the class plus whatever its first base lists.
        visible = set([*cls.__dict__.keys(), *dir(cls.__base__)])
        visible.difference_update(set(_KEYS_TO_REMOVE))
        return list(visible)
# Text handed to `implements` as the `lax_description` of InterpolatedImage:
# the galsim features that the JAX version leaves out.
LAX_INTERPOLATED_IMAGE = (
    "The JAX equivalent of galsim.InterpolatedImage does not support:\n"
    "- noise padding\n"
    "- the pad_image options\n"
    "- depixelize\n"
    "- most of the bounds checks, type checks, and dtype casts done by galsim\n"
)
@implements(_galsim.InterpolatedImage, lax_description=LAX_INTERPOLATED_IMAGE)
@register_pytree_node_class
class InterpolatedImage(Transformation, metaclass=DirMeta):
    # Wrapper around `_InterpolatedImageImpl`. The wrapped object (stored as
    # `_original`) holds the pixel data and interpolants; this class derives
    # the flux ratio, Jacobian, maxk and stepk from the wrapped object's WCS.
    #
    # NOTE(review): no class docstring is written here because `implements`
    # is given the galsim class and presumably supplies the public docs.

    _req_params = {"image": str}
    _opt_params = {
        "x_interpolant": str,
        "k_interpolant": str,
        "normalization": str,
        "scale": float,
        "flux": float,
        "pad_factor": float,
        "noise_pad_size": float,
        "noise_pad": str,
        "pad_image": str,
        "calculate_stepk": bool,
        "calculate_maxk": bool,
        "use_true_center": bool,
        "depixelize": bool,
        "offset": PositionD,
        "hdu": int,
    }
    _takes_rng = True

    def __init__(
        self,
        image,
        x_interpolant=None,
        k_interpolant=None,
        normalization="flux",
        scale=None,
        wcs=None,
        flux=None,
        pad_factor=4.0,
        noise_pad_size=0,
        noise_pad=0.0,
        rng=None,
        pad_image=None,
        calculate_stepk=True,
        calculate_maxk=True,
        use_cache=True,
        use_true_center=True,
        depixelize=False,
        offset=None,
        gsparams=None,
        _force_stepk=0.0,
        _force_maxk=0.0,
        _recenter_image=True,  # this option is used by _InterpolatedImage below
        hdu=None,
        _obj=None,
    ):
        """Build the profile from an `Image` or from a file name.

        A `str` is read with `fits.read(image, hdu=hdu)`. When `_obj` is given
        it is used as the wrapped implementation object instead of building a
        new `_InterpolatedImageImpl` from the other arguments.

        Raises:
            TypeError: if `image` is neither a `str` nor an `Image`.
        """
        # If the "image" is not actually an image, try to read the image as a file.
        if isinstance(image, str):
            image = fits.read(image, hdu=hdu)
        elif not isinstance(image, Image):
            raise TypeError("Supplied image must be an Image or file name")

        # Constructor arguments are kept so that tree_flatten/tree_unflatten
        # can rebuild the object; children are traced, aux data is static.
        self._jax_children = (
            image,
            dict(
                scale=scale,
                wcs=wcs,
                flux=flux,
                pad_image=pad_image,
                offset=offset,
            ),
        )
        self._jax_aux_data = dict(
            x_interpolant=x_interpolant,
            k_interpolant=k_interpolant,
            normalization=normalization,
            pad_factor=pad_factor,
            noise_pad_size=noise_pad_size,
            noise_pad=noise_pad,
            rng=rng,
            calculate_stepk=calculate_stepk,
            calculate_maxk=calculate_maxk,
            use_cache=use_cache,
            use_true_center=use_true_center,
            depixelize=depixelize,
            gsparams=GSParams.check(gsparams),
            _force_stepk=_force_stepk,
            _force_maxk=_force_maxk,
            _recenter_image=_recenter_image,
            hdu=hdu,
        )

        if _obj is not None:
            obj = _obj
        else:
            obj = _InterpolatedImageImpl(
                image,
                x_interpolant=x_interpolant,
                k_interpolant=k_interpolant,
                normalization=normalization,
                scale=scale,
                wcs=wcs,
                flux=flux,
                pad_factor=pad_factor,
                noise_pad_size=noise_pad_size,
                noise_pad=noise_pad,
                rng=rng,
                pad_image=pad_image,
                calculate_stepk=calculate_stepk,
                calculate_maxk=calculate_maxk,
                use_cache=use_cache,
                use_true_center=use_true_center,
                depixelize=depixelize,
                offset=offset,
                gsparams=GSParams.check(gsparams),
                hdu=hdu,
                _recenter_image=_recenter_image,
            )

        # we don't use the parent init but instead set things by hand to
        # avoid computations upon init
        self._gsparams = GSParams.check(gsparams, obj.gsparams)
        self._propagate_gsparams = True
        if self._propagate_gsparams:
            obj = obj.withGSParams(self._gsparams)
        self._original = obj
        self._params = {
            "offset": PositionD(0.0, 0.0),
        }
        # Carry the wrapped object along with the children so that
        # tree_unflatten hands it back as `_obj`.
        self._jax_children[1]["_obj"] = obj

    @property
    def _flux_ratio(self):
        """Flux ratio of the wrapped object divided by its WCS pixel area."""
        return self._original._flux_ratio / self._original._wcs.pixelArea()

    @property
    def _jac(self):
        """The wrapped object's Jacobian reshaped to 2x2."""
        return self._original._jac_arr.reshape((2, 2))

    def __getattribute__(self, name):
        """Look up `name`, treating names in `_KEYS_TO_REMOVE` as missing.

        Raises:
            AttributeError: if `name` is one of the hidden names.
        """
        if name in _KEYS_TO_REMOVE:
            raise AttributeError(f"{self.__class__} has no attribute '{name}'")
        return super().__getattribute__(name)

    def __dir__(self):
        """List instance and class attribute names, minus the hidden ones."""
        # A dict keys view cannot be concatenated to a list with `+`, so the
        # two sources of names are combined as sets instead.
        allattrs = set(self.__dict__) | set(dir(self.__class__))
        allattrs -= set(_KEYS_TO_REMOVE)
        return list(allattrs)

    # the galsim tests use this internal attribute
    # so we add it here
    @property
    def _xim(self):
        """The wrapped object's `_xim` image."""
        return self._original._xim

    @property
    def _maxk(self):
        """Forced maxk if one was given (> 0), else the wrapped object's maxk
        divided by the WCS `_maxScale()`."""
        if self._jax_aux_data["_force_maxk"] > 0:
            return self._jax_aux_data["_force_maxk"]
        else:
            # galsim uses a different way to handle the WCS effects on maxk
            # for interpolated images. IDK why. - MRB
            return self._original.maxk / self._original._wcs._maxScale()

    @property
    def _stepk(self):
        """Forced stepk if one was given (> 0), else the wrapped object's stepk
        divided by the WCS `_minScale()`."""
        if self._jax_aux_data["_force_stepk"] > 0:
            return self._jax_aux_data["_force_stepk"]
        else:
            # galsim uses a different way to handle the WCS effects on stepk
            # for interpolated images. IDK why. - MRB
            # super()._stepk
            return self._original.stepk / self._original._wcs._minScale()

    @property
    @implements(_galsim.interpolatedimage.InterpolatedImage.x_interpolant)
    def x_interpolant(self):
        return self._original._x_interpolant

    @property
    @implements(_galsim.interpolatedimage.InterpolatedImage.k_interpolant)
    def k_interpolant(self):
        return self._original._k_interpolant

    @property
    @implements(_galsim.interpolatedimage.InterpolatedImage.image)
    def image(self):
        return self._original._image

    def __hash__(self):
        """Hash the profile; the value is cached on the instance as `_hash`."""
        # Definitely want to cache this, since the size of the image could be large.
        if not hasattr(self, "_hash"):
            self._hash = hash(
                ("galsim.InterpolatedImage", self.x_interpolant, self.k_interpolant)
            )
            self._hash ^= hash(
                (
                    ensure_hashable(self.flux),
                    ensure_hashable(self._stepk),
                    ensure_hashable(self._maxk),
                    ensure_hashable(self._original._jax_aux_data["pad_factor"]),
                )
            )
            self._hash ^= hash(
                (
                    self._original._xim.bounds,
                    self._original._image.bounds,
                    self._original._pad_image.bounds,
                )
            )
            # A common offset is 0.5,0.5, and *sometimes* this produces the same hash as 0,0
            # (which is also common). I guess because they are only different in 2 bits.
            # This mucking of the numbers seems to help make the hash more reliably different for
            # these two cases. Note: "sometimes" because of this:
            # https://stackoverflow.com/questions/27522626/hash-function-in-python-3-3-returns-different-results-between-sessions
            self._hash ^= hash(
                (
                    ensure_hashable(self._original._offset.x * 1.234),
                    ensure_hashable(self._original._offset.y * 0.23424),
                )
            )
            self._hash ^= hash(self.gsparams)
            self._hash ^= hash(self._original._wcs)
            # The array of the wrapped object's `_pad_image` goes into the hash
            # via ensure_hashable.
            # (Let python handle collisions as needed if multiple similar IIs are used as keys.)
            self._hash ^= hash(ensure_hashable(self._original._pad_image.array))
        return self._hash

    def __repr__(self):
        """Return a galsim-style constructor string for this profile."""
        # this can happen due to incomplete initialization
        _original = getattr(self, "_original", None)
        if _original is None:
            return "galsim.InterpolatedImage(None)"

        s = "galsim.InterpolatedImage(%r, %r, %r, wcs=%r" % (
            self._original.image,
            self.x_interpolant,
            self.k_interpolant,
            self._original._wcs,
        )
        # Most things we keep even if not required, but the pad_image is large, so skip it
        # if it's really just the same as the main image.
        if self._original._pad_image.bounds != self._original.image.bounds:
            # `_pad_image` lives on the wrapped object, not on this class.
            s += ", pad_image=%r" % (self._original._pad_image,)
        s += ", pad_factor=%f, flux=%r, offset=%r" % (
            ensure_hashable(self._original._jax_aux_data["pad_factor"]),
            ensure_hashable(self.flux),
            self._original._offset,
        )
        s += (
            ", use_true_center=False, gsparams=%r, _force_stepk=%r, _force_maxk=%r)"
            % (
                self.gsparams,
                ensure_hashable(self._stepk),
                ensure_hashable(self._maxk),
            )
        )
        return s

    def __str__(self):
        """Return a short description with the image and flux."""
        return "galsim.InterpolatedImage(image=%s, flux=%s)" % (self.image, self.flux)

    def __eq__(self, other):
        """Compare image, interpolants, flux, offset, gsparams, stepk and maxk."""
        return self is other or (
            isinstance(other, InterpolatedImage)
            and self._xim == other._xim
            and self.x_interpolant == other.x_interpolant
            and self.k_interpolant == other.k_interpolant
            and self.flux == other.flux
            and self._original._offset == other._original._offset
            and self.gsparams == other.gsparams
            and self._stepk == other._stepk
            and self._maxk == other._maxk
        )

    def tree_flatten(self):
        """This function flattens the InterpolatedImage into a list of children
        nodes that will be traced by JAX and auxiliary static data."""
        return (self._jax_children, copy.copy(self._jax_aux_data))

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Recreates an instance of the class from flatten representation"""
        val = {}
        val.update(aux_data)
        val.update(children[1])
        return cls(children[0], **val)

    @implements(_galsim.InterpolatedImage.withGSParams)
    def withGSParams(self, gsparams=None, **kwargs):
        if gsparams == self.gsparams:
            return self
        # Checking gsparams
        gsparams = GSParams.check(gsparams, self.gsparams, **kwargs)
        # Flattening the representation to instantiate a clean new object
        children, aux_data = self.tree_flatten()
        aux_data["gsparams"] = gsparams
        ret = self.tree_unflatten(aux_data, children)
        return ret
@partial(jax.jit, static_argnums=(1,))
def _zeropad_image(arr, npad):
    """Return `arr` padded with zeros.

    `npad` is forwarded to `jnp.pad` as the pad width and is a static
    argument of the jitted function.
    """
    padded = jnp.pad(arr, pad_width=npad, mode="constant", constant_values=0.0)
    return padded
@register_pytree_node_class
class _InterpolatedImageImpl(GSObject):
    """Internal profile built directly on the pixels of an image.

    The image given by the user carries a WCS that can be anything from a
    plain pixel scale to a full Jacobian. This class handles only the pixel
    data and the interpolants; the public `InterpolatedImage` wrapper applies
    the WCS-dependent pieces (flux ratio, Jacobian) on top of it.

    The constructor stores its arguments and validates them; the derived
    quantities are exposed as properties computed from the stored arguments.
    """

    _cache_noise_pad = {}

    _has_hard_edges = False
    _is_axisymmetric = False
    _is_analytic_x = True
    _is_analytic_k = True

    def __init__(
        self,
        image,
        x_interpolant=None,
        k_interpolant=None,
        normalization="flux",
        scale=None,
        wcs=None,
        flux=None,
        pad_factor=4.0,
        noise_pad_size=0,
        noise_pad=0.0,
        rng=None,
        pad_image=None,
        calculate_stepk=True,
        calculate_maxk=True,
        use_cache=True,
        use_true_center=True,
        depixelize=False,
        offset=None,
        gsparams=None,
        hdu=None,
        _recenter_image=True,
    ):
        """Store the constructor arguments and validate them.

        Raises:
            GalSimUndefinedBoundsError: if `image.bounds` is not defined.
            GalSimValueError: if `normalization` is not one of the accepted names.
            NotImplementedError: if `pad_image`, `noise_pad_size` or `noise_pad`
                is requested.
            GalSimRangeError: if `pad_factor <= 0`.
            GalSimIncompatibleValuesError: if both `scale` and `wcs` are given,
                or if neither is given and the image has no WCS.
            TypeError: if `wcs` is given but is not a `BaseWCS`.
        """
        # The raw inputs are kept as-is so that flattening/unflattening the
        # object never has to undo any of the munging done by the properties.
        traced_options = dict(
            scale=scale,
            wcs=wcs,
            flux=flux,
            pad_image=pad_image,
            offset=offset,
        )
        self._jax_children = (image, traced_options)
        self._jax_aux_data = dict(
            x_interpolant=x_interpolant,
            k_interpolant=k_interpolant,
            normalization=normalization,
            pad_factor=pad_factor,
            noise_pad_size=noise_pad_size,
            noise_pad=noise_pad,
            rng=rng,
            calculate_stepk=calculate_stepk,
            calculate_maxk=calculate_maxk,
            use_cache=use_cache,
            use_true_center=use_true_center,
            depixelize=depixelize,
            gsparams=gsparams,
            _recenter_image=_recenter_image,
            hdu=hdu,
        )

        # The image must have well-defined bounds.
        if not image.bounds.isDefined():
            raise GalSimUndefinedBoundsError(
                "Supplied image does not have bounds defined."
            )

        # The image is either an image of flux or of surface brightness.
        valid_normalizations = ("flux", "f", "surface brightness", "sb")
        if normalization.lower() not in valid_normalizations:
            raise GalSimValueError(
                "Invalid normalization requested.",
                normalization,
                ("flux", "f", "surface brightness", "sb"),
            )

        # Default to Quintic interpolants; otherwise convert what the user gave.
        self._gsparams = GSParams.check(gsparams)
        self._x_interpolant = (
            Quintic(gsparams=self._gsparams)
            if x_interpolant is None
            else convert_interpolant(x_interpolant).withGSParams(self._gsparams)
        )
        self._k_interpolant = (
            Quintic(gsparams=self._gsparams)
            if k_interpolant is None
            else convert_interpolant(k_interpolant).withGSParams(self._gsparams)
        )

        if pad_image is not None:
            raise NotImplementedError("pad_image not implemented in jax_galsim.")

        if pad_factor <= 0.0:
            raise GalSimRangeError(
                "Invalid pad_factor <= 0 in InterpolatedImage", pad_factor, 0.0
            )

        if noise_pad_size or noise_pad:
            raise NotImplementedError(
                "InterpolatedImages do not support noise padding in jax_galsim."
            )

        # Exactly one source of pixel-scale information is accepted:
        # `scale`, `wcs`, or the WCS attached to the image.
        has_scale = scale is not None
        has_wcs = wcs is not None
        if has_scale and has_wcs:
            raise GalSimIncompatibleValuesError(
                "Cannot provide both scale and wcs to InterpolatedImage",
                scale=scale,
                wcs=wcs,
            )
        if has_wcs and not isinstance(wcs, BaseWCS):
            raise TypeError("wcs parameter is not a galsim.BaseWCS instance")
        if not (has_scale or has_wcs) and image.wcs is None:
            raise GalSimIncompatibleValuesError(
                "No information given with Image or keywords about pixel scale!",
                scale=scale,
                wcs=wcs,
                image=image,
            )

    @doc_inherit
    def withGSParams(self, gsparams=None, **kwargs):
        if gsparams == self.gsparams:
            return self
        checked = GSParams.check(gsparams, self.gsparams, **kwargs)
        # Rebuild a clean object from the flattened representation.
        children, aux_data = self.tree_flatten()
        aux_data["gsparams"] = checked
        return self.tree_unflatten(aux_data, children)

    def tree_flatten(self):
        """Return `(children, aux_data)`: the traced children and a shallow
        copy of the static auxiliary data."""
        static_data = copy.copy(self._jax_aux_data)
        return (self._jax_children, static_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from the output of `tree_flatten`."""
        # Entries of the traced-options dict override those of the aux data.
        init_kwargs = {**aux_data, **children[1]}
        return cls(children[0], **init_kwargs)

    def __getstate__(self):
        """Return a shallow copy of the instance dict."""
        return dict(self.__dict__)

    def __setstate__(self, d):
        """Restore the instance dict from `d`."""
        self.__dict__ = d

    @property
    def x_interpolant(self):
        """The real-space `Interpolant` for this profile."""
        return self._x_interpolant

    @property
    def k_interpolant(self):
        """The Fourier-space `Interpolant` for this profile."""
        return self._k_interpolant

    @property
    def image(self):
        """The underlying `Image` being interpolated."""
        return self._xim[self._image.bounds]

    @property
    def _flux(self):
        """Sum of the pixels of the stored image."""
        return self._image_flux

    @property
    def _centroid(self):
        """Flux-weighted mean pixel position of `_pad_image`."""
        pix = self._pad_image
        x, y = pix.get_pixel_centers()
        tot = jnp.sum(pix.array)
        xpos = jnp.sum(x * pix.array) / tot
        ypos = jnp.sum(y * pix.array) / tot
        return PositionD(xpos, ypos)

    @property
    def _max_sb(self):
        """Largest absolute pixel value of `_pad_image`."""
        return jnp.max(jnp.abs(self._pad_image.array))

    @property
    def _flux_ratio(self):
        """Ratio of the requested flux to the flux of the image.

        The requested flux is the user-specified `flux` if given. Otherwise it
        is the image flux, multiplied by the WCS pixel area when the
        normalization is surface brightness.
        """
        requested = self._jax_children[1]["flux"]
        if requested is not None:
            flux = requested
        else:
            flux = self._image_flux
            is_sb = self._jax_aux_data["normalization"].lower() in (
                "surface brightness",
                "sb",
            )
            if is_sb:
                flux = flux * self._wcs.pixelArea()
        # The wrapping transform applies this ratio.
        return flux / self._image_flux

    @property
    def _image_flux(self):
        """Sum of the pixels of `_image`, accumulated as float."""
        return jnp.sum(self._image.array, dtype=float)

    @property
    def _offset(self):
        """Offset to apply, parsed and adjusted against the bounds of the
        original (unpadded) image."""
        parsed = self._parse_offset(self._jax_children[1]["offset"])
        return self._adjust_offset(
            self._image.bounds, parsed, None, self._jax_aux_data["use_true_center"]
        )

    @property
    def _image(self):
        """Float copy of the input image, recentered at (0, 0) when
        `_recenter_image` is set.

        Raises:
            NotImplementedError: if `depixelize` was requested.
        """
        if self._jax_aux_data["depixelize"]:
            # TODO: no depixelize in jax_galsim
            raise NotImplementedError(
                "InterpolatedImages do not support 'depixelize' in jax_galsim."
            )
        # Work on a copy so the caller's image is never modified.
        image = self._jax_children[0].copy(dtype=float)
        if self._jax_aux_data["_recenter_image"]:
            image.setCenter(0, 0)
        return image

    @property
    def _wcs(self):
        """Local WCS at the image center, taken from `scale`, `wcs`, or the
        image itself (in that order of precedence)."""
        source = self._jax_children[0]
        options = self._jax_children[1]
        if self._jax_aux_data["use_true_center"]:
            im_cen = source.true_center
        else:
            im_cen = source.center

        # error checking was done on init
        if options["scale"] is not None:
            wcs = PixelScale(options["scale"])
        elif options["wcs"] is not None:
            wcs = options["wcs"]
        else:
            wcs = source.wcs
        return wcs.local(image_pos=im_cen)

    @property
    def _jac_arr(self):
        """Flattened Jacobian matrix of `_wcs` at the image center."""
        source = self._jax_children[0]
        if self._jax_aux_data["use_true_center"]:
            im_cen = source.true_center
        else:
            im_cen = source.center
        return self._wcs.jacobian(image_pos=im_cen).getMatrix().ravel()

    @property
    def _xim(self):
        """Zero-padded copy of `_image`, centered at (0, 0) with unit pixel scale."""
        pad_factor = self._jax_aux_data["pad_factor"]
        base = self._image
        largest_side = max(base.array.shape)

        # Target size: the largest image side, scaled by pad_factor when that
        # is above 1, then rounded up to a good FFT size.
        pad_size = largest_side
        if pad_factor > 1.0:
            pad_size = int(math.ceil(pad_factor * pad_size))
        pad_size = Image.good_fft_size(pad_size)

        npad = (pad_size - largest_side) // 2
        xim = Image(_zeropad_image(base.array, npad), wcs=PixelScale(1.0))
        xim.setCenter(0, 0)
        # after the call to setCenter you get a WCS with an offset in
        # it instead of a pure pixel scale
        xim.wcs = PixelScale(1.0)

        # Place the given image in the center of the padded image.
        xim[base.bounds] = base
        return xim

    @property
    def _pad_image(self):
        """The part of `_xim` covered by the bounds of `_image`."""
        # Only the region inside the original bounds is kept; the zeros around
        # the edge do not need to be carried for pickling or repr.
        padded = self._xim
        return padded[self._image.bounds]

    @property
    def _kim(self):
        """FFT of `_xim`."""
        return self._xim.calculate_fft()

    @property
    def _maxk(self):
        """maxk computed according to the stored `calculate_maxk` option."""
        return self._getMaxK(self._jax_aux_data["calculate_maxk"])

    @property
    def _stepk(self):
        """stepk computed according to the stored `calculate_stepk` option."""
        return self._getStepK(self._jax_aux_data["calculate_stepk"])

    def _getStepK(self, calculate_stepk):
        """Return stepk.

        A falsy `calculate_stepk` uses half the largest image side. `True`
        measures the size containing the required flux on the whole image.
        Any other truthy value is treated as a maximum stepk that restricts
        the region measured.
        """
        if not calculate_stepk:
            R = max(*self.image.array.shape) / 2.0 - 0.5
            return self._getSimpleStepK(R)

        if calculate_stepk is True:
            im = self.image
        else:
            # If not a bool, then value is max_stepk
            half = (jnp.ceil(jnp.pi / calculate_stepk)).astype(int)
            box = BoundsI(xmin=-half, deltax=2 * half + 1, ymin=-half, deltay=2 * half + 1)
            box = self.image.bounds & box
            im = self.image[box]

        thresh = (1.0 - self.gsparams.folding_threshold) * self._image_flux
        # galsim's version of this call looks buggy; `im` is used here, which
        # is presumably what was intended there.
        R = _calculate_size_containing_flux(im, thresh)
        return self._getSimpleStepK(R)

    def _getSimpleStepK(self, R):
        """Return pi over `R` combined in quadrature with the x-interpolant range."""
        # The interpolant range is added in quadrature, just like convolution.
        total_radius = jnp.hypot(R, self._x_interpolant.xrange)
        return jnp.pi / total_radius

    def _getMaxK(self, calculate_maxk):
        """Return maxk.

        A falsy `calculate_maxk` returns the x-interpolant `krange`. Otherwise
        maxk is measured from the k-space image, capped by the interpolant's
        own maxk when `calculate_maxk is True`, or by the given value.
        """
        if not calculate_maxk:
            return self._x_interpolant.krange

        _uscale = 1 / (2 * jnp.pi)
        interp_maxk = self._x_interpolant.urange() / _uscale
        cap = interp_maxk if calculate_maxk is True else calculate_maxk
        return _find_maxk(
            self._kim, cap, self._gsparams.maxk_threshold * self.flux
        )

    def _xValue(self, pos):
        """Real-space value of the profile at the single position `pos`."""
        offset = self._offset
        pix = self._pad_image
        values = _xValue_arr(
            jnp.array([pos.x], dtype=float),
            jnp.array([pos.y], dtype=float),
            offset.x,
            offset.y,
            pix.bounds.xmin,
            pix.bounds.ymin,
            pix.array,
            self._x_interpolant,
        )
        return values[0]

    def _kValue(self, kpos):
        """Fourier-space value of the profile at the single position `kpos`."""
        offset = self._offset
        kim = self._kim
        values = _kValue_arr(
            jnp.array([kpos.x], dtype=float),
            jnp.array([kpos.y], dtype=float),
            offset.x,
            offset.y,
            kim.bounds.xmin,
            kim.bounds.ymin,
            kim.array,
            kim.scale,
            self._x_interpolant,
            self._k_interpolant,
        )
        return values[0]

    def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
        """Draw the profile in real space on the pixel grid of `image`.

        Returns a new `Image` with the bounds and WCS of `image`.
        """
        jacobian = jac if jac is not None else jnp.eye(2)
        flux_scaling = flux_scaling * image.scale**2

        # Pixel centers in scaled units, with the offset subtracted.
        coords = jnp.stack(image.get_pixel_centers(), axis=-1)
        coords = coords * image.scale
        coords = coords - jnp.asarray(offset)

        # Map the coordinates through the inverse Jacobian and fold its
        # determinant into the flux scaling.
        inv_jacobian = jnp.linalg.inv(jacobian)
        _, logdet = jnp.linalg.slogdet(inv_jacobian)
        coords = jnp.dot(coords, inv_jacobian.T)
        flux_scaling = flux_scaling * jnp.exp(logdet)

        own_offset = self._offset
        pix = self._pad_image
        im = _xValue_arr(
            coords[..., 0],
            coords[..., 1],
            own_offset.x,
            own_offset.y,
            pix.bounds.xmin,
            pix.bounds.ymin,
            pix.array,
            self._x_interpolant,
        )
        im = im * flux_scaling

        # jax-galsim's rounding of float-to-int is platform dependent
        # so we explicitly round to ints if needed
        float_to_int = jnp.issubdtype(im.dtype, jnp.floating) and jnp.issubdtype(
            image.dtype, jnp.integer
        )
        if float_to_int:
            im = jnp.around(im)

        return Image(array=im, bounds=image.bounds, wcs=image.wcs, _check_bounds=False)

    def _drawKImage(self, image, jac=None):
        """Draw the profile in Fourier space on the pixel grid of `image`.

        Returns a new `Image` with the bounds, WCS and dtype of `image`.
        """
        jacobian = jac if jac is not None else jnp.eye(2)

        # Pixel centers in scaled units, mapped through the Jacobian.
        coords = jnp.stack(image.get_pixel_centers(), axis=-1)
        coords = coords * image.scale
        coords = jnp.dot(coords, jacobian)

        own_offset = self._offset
        kim = self._kim
        im = _kValue_arr(
            coords[..., 0],
            coords[..., 1],
            own_offset.x,
            own_offset.y,
            kim.bounds.xmin,
            kim.bounds.ymin,
            kim.array,
            kim.scale,
            self._x_interpolant,
            self._k_interpolant,
        )
        im = im.astype(image.dtype)

        return Image(array=im, bounds=image.bounds, wcs=image.wcs, _check_bounds=False)

    @property
    def _pos_neg_fluxes(self):
        """Two-element list `[positive, negative]` combining the image's
        positive/negative pixel sums with the interpolant's positive/negative
        fluxes."""
        arr = self._pad_image.array
        pflux = jnp.sum(jnp.where(arr > 0, arr, 0))
        nflux = jnp.abs(jnp.sum(jnp.where(arr < 0, arr, 0)))

        pint = self._x_interpolant.positive_flux
        nint = self._x_interpolant.negative_flux
        pint2d = pint * pint + nint * nint
        nint2d = 2 * pint * nint
        return [
            pint2d * pflux + nint2d * nflux,
            pint2d * nflux + nint2d * pflux,
        ]

    @property
    def _positive_flux(self):
        """First entry of `_pos_neg_fluxes`."""
        return self._pos_neg_fluxes[0]

    @property
    def _negative_flux(self):
        """Second entry of `_pos_neg_fluxes`."""
        return self._pos_neg_fluxes[1]

    def _flux_per_photon(self):
        """Delegate to `_calculate_flux_per_photon`."""
        return self._calculate_flux_per_photon()

    def _shoot(self, photons, rng):
        """Fill `photons` with positions and fluxes sampled from the image,
        then convolve them with photons shot from the x interpolant."""
        # Draw pixel indices with probability proportional to |pixel value|.
        img = self._pad_image
        subkey = rng._state.split_one()
        flat = img.array.ravel()
        inds = jrng.choice(
            subkey,
            img.array.size,
            shape=(photons.size(),),
            replace=True,
            # abs is used since some of the pixels could be negative
            # and for a noise image this procedure results in a fair
            # sampling of the noise
            p=jnp.abs(flat) / jnp.sum(jnp.abs(img.array)),
        ).astype(int)
        yinds, xinds = jnp.unravel_index(inds, img.array.shape)

        # Each photon starts at the center of the pixel it was drawn from.
        xcens = jnp.arange(0, img.bounds.deltax + 1) + img.bounds.xmin
        ycens = jnp.arange(0, img.bounds.deltay + 1) + img.bounds.ymin
        photons.x = xcens[xinds]
        photons.y = ycens[yinds]

        # This set of factors comes from the galsim C++ code. It is
        #
        #  - the sign of the photon flux
        #  - the flux per photon = 1 - 2 neg / (pos + neg)
        #  - the total absolute flux in the image = (pos + neg)
        #  - the number of photons to draw = photons.size()
        #
        # Unpacked, this is
        #
        #   sign * (1 - 2 neg / (pos + neg)) * (pos + neg) / photons.size()
        #     = sign * (pos + neg - 2 neg) / (pos + neg) * (pos + neg) / photons.size()
        #     = sign * (pos - neg) / photons.size()
        #
        # So each photon carries the flux of the object divided by the number
        # of photons, with a sign of -1 or +1 (the photon count is inflated by
        # other bits of the code to get the total flux correct).
        photons.flux = (
            jnp.sign(flat)[inds]
            * self._flux_per_photon()
            * (self.positive_flux + self.negative_flux)
            / photons.size()
        )

        # The xValue path adds the offset to get to image pixels. Here the
        # photons come from image pixels, so the offset is subtracted to get
        # back to x as it would be given to xValue.
        photons.x -= self._offset.x
        photons.y -= self._offset.y

        # Convolve with the x interpolant.
        x_photons = PhotonArray(photons.size())
        self._x_interpolant._shoot(x_photons, rng)
        photons.convolve(x_photons)
@implements(_galsim._InterpolatedImage)
def _InterpolatedImage(
    image,
    x_interpolant=Quintic(),
    k_interpolant=Quintic(),
    use_true_center=True,
    offset=None,
    gsparams=None,
    force_stepk=0.0,
    force_maxk=0.0,
):
    # Options fixed by this constructor: no stepk/maxk calculation, no
    # padding, and the image is not recentered.
    fixed_options = dict(
        calculate_maxk=False,
        calculate_stepk=False,
        pad_factor=1.0,
        _recenter_image=False,
    )
    # The flux is set to the plain sum of the image pixels.
    total_flux = jnp.sum(image.array)
    return InterpolatedImage(
        image,
        x_interpolant=x_interpolant,
        k_interpolant=k_interpolant,
        use_true_center=use_true_center,
        offset=offset,
        gsparams=gsparams,
        flux=total_flux,
        _force_stepk=force_stepk,
        _force_maxk=force_maxk,
        **fixed_options,
    )
def _xValue_arr(x, y, x_offset, y_offset, xmin, ymin, arr, x_interpolant):
    """Interpolate `arr` at the positions `(x + x_offset, y + y_offset)`.

    `xmin` / `ymin` are passed through to `_draw_with_interpolant_xval` as
    the coordinates of the lower-left pixel of `arr`.
    """
    shifted_x = x + x_offset
    shifted_y = y + y_offset
    return _draw_with_interpolant_xval(
        shifted_x, shifted_y, xmin, ymin, arr, x_interpolant
    )
@partial(jax.vmap, in_axes=(0, None, None, None, None, None))
@partial(jax.jit, static_argnames=("interp",))
def _interp_weight_1d_xval(ioff, xi, xp, x, nx, interp):
    """Return interpolation weights and pixel indices along one axis.

    Mapped over `ioff` (the offsets within the interpolation kernel). For
    each offset the weight is `interp._xval_noraise(x - (xp + ioff))`, set to
    zero where the pixel index `xi + ioff` falls outside `[0, nx)`. The
    indices are returned as int32 and are not clipped.
    """
    pixel_index = xi + ioff
    inside = (pixel_index >= 0) & (pixel_index < nx)
    weight = interp._xval_noraise(x - (xp + ioff))
    masked_weight = jnp.where(inside, weight, 0)
    return masked_weight, pixel_index.astype(jnp.int32)
@partial(jax.jit, static_argnames=("interp",))
def _draw_with_interpolant_xval(x, y, xmin, ymin, zp, interp):
    """Interpolate the image `zp` with `interp` at the locations `x`, `y`.

    `xmin` / `ymin` give the lower-left corner of the image. The result has
    the same shape as `x`.

    A typical C/C++ version would nest a loop over the points to interpolate
    around loops over the interpolation kernel. Here the points are handled
    as one flat vector, and the kernel offsets form the two leading axes of
    the weight and index arrays, which are summed away at the end.
    """
    # Flattening makes the vectorization over points easier to follow; the
    # original shape is restored on the way out.
    out_shape = x.shape
    ny, nx = zp.shape[0], zp.shape[1]

    # *_flat: coordinates of the points to interpolate
    # *_idx:  index of the nearest pixel below each point
    x_flat = x.ravel()
    x_idx = jnp.floor(x_flat - xmin).astype(jnp.int32)
    y_flat = y.ravel()
    y_idx = jnp.floor(y_flat - ymin).astype(jnp.int32)

    half_width = interp.ixrange // 2
    kernel_offsets = jnp.arange(-half_width, half_width + 1)

    # The third argument is the coordinate of the nearest pixel below each point.
    wx, xind = _interp_weight_1d_xval(
        kernel_offsets, x_idx, x_idx + xmin, x_flat, nx, interp
    )
    wy, yind = _interp_weight_1d_xval(
        kernel_offsets, y_idx, y_idx + ymin, y_flat, ny, interp
    )

    # Axes are (y kernel offset, x kernel offset, point).
    weights = wx[None, :, :] * wy[:, None, :]
    pixels = zp[yind[:, None, :], xind[None, :, :]]
    summed = jnp.sum(weights * pixels, axis=(0, 1))
    return summed.reshape(out_shape)
def _kValue_arr(
    kx,
    ky,
    x_offset,
    y_offset,
    kxmin,
    kymin,
    arr,
    scale,
    x_interpolant,
    k_interpolant,
):
    """Evaluate the Fourier-space profile at `(kx, ky)`.

    The k-space image `arr` (pixel size `scale`) is interpolated with
    `k_interpolant`, multiplied by the Fourier transform of `x_interpolant`
    and by the phase factor for the offset. Points beyond the range of the
    x interpolant's transform evaluate to 0.

    `kxmin` is accepted but not used: `kymin` is deliberately passed for both
    minima.
    """
    # Phase factor due to the offset. Note that the shift is by -offset,
    # which explains the sign in the exponent.
    phase = jnp.exp(1j * (kx * x_offset + ky * y_offset))

    kx_pix = kx / scale
    ky_pix = ky / scale

    inv_two_pi = 1.0 / (2.0 * jnp.pi)
    cutoff = x_interpolant.urange() / inv_two_pi / scale

    # The actual interpolation in k space.
    interpolated = _draw_with_interpolant_kval(
        kx_pix,
        ky_pix,
        kymin,  # this is not a bug! we need the minimum for the full periodic space
        kymin,
        arr,
        k_interpolant,
    )

    # Multiply by the FFT of the real-space interpolation function and mask
    # anything outside the range of that FFT.
    within = (jnp.abs(kx_pix) <= cutoff) & (jnp.abs(ky_pix) <= cutoff)
    kernel_ft = x_interpolant._kval_noraise(kx) * x_interpolant._kval_noraise(ky)
    return jnp.where(within, interpolated * kernel_ft * phase, 0.0)
@partial(jax.vmap, in_axes=(0, None, None, None, None, None))
@partial(jax.jit, static_argnames=("interp",))
def _interp_weight_1d_kval(ioff, kxi, kxp, kx, nkx, interp):
    """Return interpolation weights and wrapped pixel indices along one axis.

    Mapped over `ioff` (the offsets within the interpolation kernel). The
    weight is `interp._xval_noraise(kx - (kxp + ioff))`. The index
    `kxi + ioff` is wrapped modulo `nkx` and returned as int32; no weights
    are masked.
    """
    wrapped_index = (kxi + ioff) % nkx
    weight = interp._xval_noraise(kx - (kxp + ioff))
    return weight, wrapped_index.astype(jnp.int32)
@partial(jax.jit, static_argnames=("interp",))
def _draw_with_interpolant_kval(kx, ky, kxmin, kymin, zp, interp):
    """This function interpolates complex k-space images and follows the
    same basic structure as _draw_with_interpolant_xval above.

    The key difference is that the k-space images are Hermitian and so
    only half of the data is actually in memory. We account for this by
    computing all of the interpolation weights and indices as if we had
    the full image. Then finally, if we need a value that is not in memory,
    we get it from the values we have via the Hermitian symmetry.

    `kx` / `ky` are the locations to interpolate at, `kxmin` / `kymin` the
    coordinates subtracted from them to form pixel indices, `zp` the stored
    half image and `interp` the interpolant (a static jit argument). The
    result has the same shape as `kx`.
    """
    # all of the code below is almost line-for-line the same as the
    # _draw_with_interpolant_xval function above.
    orig_shape = kx.shape
    kx = kx.ravel()
    kxi = jnp.floor(kx - kxmin).astype(jnp.int32)
    kxp = kxi + kxmin
    # this is the number of pixels in the half image and is needed
    # for computing values via Hermitian symmetry below
    nkx_2 = zp.shape[1] - 1
    nkx = nkx_2 * 2
    ky = ky.ravel()
    kyi = jnp.floor(ky - kymin).astype(jnp.int32)
    kyp = kyi + kymin
    nky = zp.shape[0]
    irange = interp.ixrange // 2
    iinds = jnp.arange(-irange, irange + 1)
    wkx, kxind = _interp_weight_1d_kval(
        iinds,
        kxi,
        kxp,
        kx,
        nkx,
        interp,
    )
    wky, kyind = _interp_weight_1d_kval(
        iinds,
        kyi,
        kyp,
        ky,
        nky,
        interp,
    )
    # arrange the axes as (ky kernel offset, kx kernel offset, point)
    wkx = wkx[None, :, :]
    kxind = kxind[None, :, :]
    wky = wky[:, None, :]
    kyind = kyind[:, None, :]
    # this is the key difference from the xval function
    # we need to use the Hermitian symmetry to get the
    # values that are not in memory
    # full-image x indices from nkx_2 upward are read directly from zp
    # (shifted down by nkx_2); smaller indices use the Hermitian symmetry
    #   f(ky, kx) = conjugate(f(-ky, -kx))
    # In indices this is a symmetric flip about the central
    # pixels at kx = ky = 0.
    # we do not need to mask any values that run off the edge of the image
    # since we rewrap them using the periodicity of the image.
    val = jnp.where(
        kxind < nkx_2,
        zp[(nky - kyind) % nky, nkx - kxind - nkx_2].conjugate(),
        zp[kyind, kxind - nkx_2],
    )
    z = jnp.sum(
        val * wkx * wky,
        axis=(0, 1),
    )
    # reshape on the way out to match the input shape
    return z.reshape(orig_shape)
@jax.jit
def _flux_frac(a, x, y, cenx, ceny):
    """Sum `a` inside centered square boxes of increasing half-width.

    Entry `d` of the result is the sum of the pixels of `a` whose coordinates
    satisfy `|x - cenx| <= d` and `|y - ceny| <= d`, for
    `d = 0, 1, ..., min(a.shape[0], a.shape[1]) - 1`.
    """
    nrow, ncol = a.shape[0], a.shape[1]
    column = (nrow, ncol, 1)

    # A trailing axis is added so that everything broadcasts against the
    # vector of half-widths.
    values = jnp.reshape(a, column)
    dx = jnp.reshape(x - cenx, column)
    dy = jnp.reshape(y - ceny, column)
    half_widths = jnp.reshape(jnp.arange(min(nrow, ncol)), (1, 1, -1))

    inside = (jnp.abs(dx) <= half_widths) & (jnp.abs(dy) <= half_widths)
    return jnp.sum(jnp.where(inside, values, 0.0), axis=(0, 1))
@jax.jit
def _calculate_size_containing_flux(image, thresh):
    """Return the index of the first centered box whose enclosed flux reaches
    `thresh`, plus 0.5.

    The enclosed fluxes come from `_flux_frac` evaluated about the image
    center. The comparison is multiplied by the sign of `thresh` on both
    sides.
    """
    center = image.center
    x, y = image.get_pixel_centers()
    fluxes = _flux_frac(image.array, x, y, center.x, center.y)

    # One is added since the flux fraction computation above starts at
    # one pixel and jnp.arange starts at zero.
    nrow, ncol = image.array.shape[0], image.array.shape[1]
    sizes = jnp.arange(min(nrow, ncol)) + 1.0

    sign = jnp.sign(thresh)
    reached = (sign * fluxes) >= (sign * thresh)
    first = jnp.argmin(jnp.where(reached, sizes, jnp.inf))
    return first + 0.5
# this version does not match galsim's maxk operation exactly,
# but is faster to compute. I am leaving it here for
# posterity. - MRB
# @jax.jit
# def _inner_comp_find_maxk(arr, thresh, kx, ky):
# msk = (arr * arr.conjugate()).real > thresh * thresh
# max_kx = jnp.max(
# jnp.where(
# msk,
# jnp.abs(kx),
# -jnp.inf,
# )
# )
# max_ky = jnp.max(
# jnp.where(
# msk,
# jnp.abs(ky),
# -jnp.inf,
# )
# )
# # galsim adds one pixel at the end so that maxk is
# # the k value where things do not pass the threshold,
# # so we do that here too.
# return jnp.maximum(max_kx, max_ky) + kx[0, 1] - kx[0, 0]
@jax.jit
def _inner_comp_find_maxk_scan(arr, thresh, kx, ky):
    """Find maxk by scanning the k-space array `arr` outward in |k|.

    For each value `x` in `kx[0, :]`, the scan counts the pixels whose
    squared modulus exceeds `thresh**2` and that satisfy `|kx| <= x` and
    `|ky| <= x`. The result is the `kx[0, :]` value at the first position
    where the windowed sum of count differences (window length five) is
    zero. If there is no such position, the result is the last value of
    `kx[0, :]` plus one pixel spacing.
    """
    # squared modulus of every pixel, compared against the squared threshold
    val = (arr * arr.conjugate()).real
    msk_thresh = val > thresh * thresh
    akx = jnp.abs(kx)
    aky = jnp.abs(ky)

    def _func(carry, x):
        # number of above-threshold pixels inside the box |kx|, |ky| <= x
        msk_kx = akx <= x
        msk_ky = aky <= x
        return carry, jnp.sum(msk_thresh & msk_kx & msk_ky)

    # despite its name, `msk` holds the per-step counts returned by _func
    _, msk = jax.lax.scan(_func, None, xs=kx[0, :])
    # We are searching for the location of the first string of
    # five locations in a row in `msk` where the value stays the
    # same.
    # We do this by putting the array through jnp.diff, which
    # computes the difference of adjacent elements. Then we convolve
    # with a filter of ones of length five to sum groups of five
    # elements together. The first location where the result is
    # zero is the location we want. The tricky bit however is getting
    # the indexing right.
    # step 1. compute the diff of adjacent elements
    # The function jnp.diff returns an array of size one less than
    # the input. So we concatenate a zero at the front. This makes
    # sense since if the original array is all constant, then the
    # location of the first five zeros is at the start of the array.
    delta_msk = jnp.concatenate(
        [jnp.array([0], dtype=int), jnp.diff(msk)],
        axis=0,
        dtype=int,
    )
    # step 2. convolve with the filter
    # In the discrete convolution, you have to deal with edge
    # behavior where the filter only partially overlaps the arrays.
    # We use the mode `full` which returns an array containing
    # every possible combination with missing elements set to zero.
    # We cut the first `length of filter - 1` elements so that
    # index i of the result is the sum of the filter starting
    # at index i of the input.
    sums = jnp.convolve(delta_msk, jnp.ones(5, dtype=int), mode="full")[4:]
    # step 3. find first location of zero in the convolution
    # Finally, we use jnp.argmin to find the location of the first
    # zero. Per the doc string, if there is more than one zero, this
    # function returns the first location (i.e., smallest index)
    # which is what we want.
    msk_zero = sums == 0
    sind, dk = jax.lax.cond(
        jnp.any(msk_zero),
        # if we find a set of zeros, the code computes the next pixel past
        # the pixels where |kval| > thresh. So we set dk = 0 since we don't
        # need to shift things.
        lambda x: (jnp.argmin(jnp.where(x, 0, 1)), 0.0),
        # if we get to the end of the array, we add one pixel spacing
        # so we match galsim
        lambda x: (-1, kx[0, -1] - kx[0, -2]),
        msk_zero,
    )
    return kx[0, sind] + dk
@jax.jit
def _find_maxk(kim, max_maxk, thresh):
    """Return maxk measured from the k-space image `kim`, capped at `max_maxk`.

    The pixel centers of `kim` are multiplied by `kim.scale` and passed with
    `kim.array` and `thresh` to `_inner_comp_find_maxk_scan`.
    """
    kx_pix, ky_pix = kim.get_pixel_centers()
    kx = kx_pix * kim.scale
    ky = ky_pix * kim.scale
    measured = _inner_comp_find_maxk_scan(kim.array, thresh, kx, ky)
    # The empirically determined maxk is bounded above by max_maxk.
    return jnp.minimum(measured, max_maxk)