from collections import namedtuple
from functools import partial
import galsim as _galsim
import jax
import jax.numpy as jnp
import numpy as np
import jax_galsim.photon_array as pa
from jax_galsim.core.draw import calculate_n_photons
from jax_galsim.core.utils import implements, is_equal_with_arrays
from jax_galsim.errors import (
GalSimError,
GalSimIncompatibleValuesError,
GalSimNotImplementedError,
GalSimValueError,
galsim_warn,
)
from jax_galsim.gsparams import GSParams
from jax_galsim.photon_array import PhotonArray
from jax_galsim.position import Position, PositionD, PositionI
from jax_galsim.random import BaseDeviate
from jax_galsim.sensor import Sensor
from jax_galsim.utilities import parse_pos_args
[docs]
@implements(_galsim.GSObject)
class GSObject:
def __init__(self, *, gsparams=None, **params):
    """Store the profile parameters and validated GSParams.

    Parameters:
        gsparams:   Optional GSParams (or None for defaults); normalised through
                    ``GSParams.check``.  Kept as static (non-traced) state.
        **params:   Keyword parameters of the profile, kept verbatim in ``_params``;
                    these are the traced parameters.
    """
    checked_gsparams = GSParams.check(gsparams)
    self._params = params
    self._gsparams = checked_gsparams
def __getstate__(self):
    """Return a picklable shallow copy of the instance dict.

    Any ``_workspace`` entry is dropped.  A boolean ``had_workspace`` key records
    whether one was present, so ``__setstate__`` can recreate an empty one.
    """
    state = dict(self.__dict__)
    state["had_workspace"] = "_workspace" in state
    state.pop("_workspace", None)
    return state
def __setstate__(self, d):
    """Restore state produced by ``__getstate__``.

    If the pickled object had a ``_workspace``, an empty dict is put back in its
    place.  ``d`` itself becomes the new instance dict.
    """
    restore_workspace = d.pop("had_workspace", False)
    if restore_workspace:
        d["_workspace"] = {}
    self.__dict__ = d
@property
@implements(_galsim.GSObject.flux)
def flux(self):
    # Public accessor; subclasses customise the value by overriding ``_flux``.
    value = self._flux
    return value
@property
def _flux(self):
    """Default flux lookup: the ``"flux"`` entry of the parameters dictionary."""
    params = self._params
    return params["flux"]
@property
@implements(_galsim.GSObject.gsparams)
def gsparams(self):
    # Static (non-traced) accuracy/size settings validated in __init__.
    current = self._gsparams
    return current
@property
def params(self):
    """The dictionary of all parameters of this object's internal representation."""
    internal = self._params
    return internal
@property
@implements(_galsim.GSObject.maxk)
def maxk(self):
    # Delegates to the subclass-provided ``_maxk``.
    value = self._maxk
    return value
@property
@implements(_galsim.GSObject.stepk)
def stepk(self):
    # Delegates to the subclass-provided ``_stepk``.
    value = self._stepk
    return value
@property
@implements(_galsim.GSObject.nyquist_scale)
def nyquist_scale(self):
    # Pixel scale that samples the profile at its band limit: pi / maxk.
    band_limit = self.maxk
    return jnp.pi / band_limit
@property
@implements(_galsim.GSObject.has_hard_edges)
def has_hard_edges(self):
    # Delegates to the subclass-provided ``_has_hard_edges`` flag.
    flag = self._has_hard_edges
    return flag
@property
@implements(_galsim.GSObject.is_axisymmetric)
def is_axisymmetric(self):
    # Delegates to the subclass-provided ``_is_axisymmetric`` flag.
    flag = self._is_axisymmetric
    return flag
@property
@implements(_galsim.GSObject.is_analytic_x)
def is_analytic_x(self):
    # Delegates to the subclass-provided ``_is_analytic_x`` flag.
    flag = self._is_analytic_x
    return flag
@property
@implements(_galsim.GSObject.is_analytic_k)
def is_analytic_k(self):
    # Delegates to the subclass-provided ``_is_analytic_k`` flag.
    flag = self._is_analytic_k
    return flag
@property
@implements(_galsim.GSObject.centroid)
def centroid(self):
    # Delegates to ``_centroid`` so subclasses can override the default origin.
    position = self._centroid
    return position
@property
def _centroid(self):
    """Default centroid at the origin, which is correct for most profiles."""
    origin = PositionD(0, 0)
    return origin
@property
@implements(_galsim.GSObject.positive_flux)
def positive_flux(self):
    # Delegates to the overridable ``_positive_flux``.
    value = self._positive_flux
    return value
@property
@implements(_galsim.GSObject.negative_flux)
def negative_flux(self):
    # Delegates to the overridable ``_negative_flux``.
    value = self._negative_flux
    return value
@property
def _positive_flux(self):
    """Positive part of the flux: the net flux plus the (absolute) negative flux."""
    net = self.flux
    return net + self._negative_flux
@property
def _negative_flux(self):
    """Default negative flux: zero, for profiles that are non-negative everywhere."""
    no_negative_flux = 0.0
    return no_negative_flux
@property
def _flux_per_photon(self):
    """Default flux carried per photon: 1.0, the usual case with no negative flux.

    Profiles that override ``_negative_flux`` should override this too, typically
    returning ``self._calculate_flux_per_photon()``.
    """
    unit_flux = 1.0
    return unit_flux
def _calculate_flux_per_photon(self):
    """Compute the flux per photon when the profile has negative regions.

    With ``eta = negative_flux / (positive_flux + negative_flux)``, the result is
    ``1 - 2 * eta``.  Subclasses that override ``_negative_flux`` should return
    this from their ``_flux_per_photon`` override.
    """
    pos = self.positive_flux
    neg = self.negative_flux
    negative_fraction = neg / (pos + neg)
    return 1.0 - 2.0 * negative_fraction
@property
@implements(_galsim.GSObject.max_sb)
def max_sb(self):
    # Delegates to the overridable ``_max_sb`` estimate.
    value = self._max_sb
    return value
@property
def _max_sb(self):
    """Default (deliberately huge) upper bound on the surface brightness.

    Overestimates are conservative for the callers of ``max_sb``.  The literal
    ``1.0e500`` overflows to ``inf`` in Python, which effectively skips any
    optimisation that relies on the maximum surface brightness.
    """
    no_useful_bound = 1.0e500
    return no_useful_bound
def __add__(self, other):
    """Return the sum of this profile and ``other``.

    Equivalent to Add(self, other)
    """
    from jax_galsim.sum import Sum

    summands = [self, other]
    return Sum(summands)
# op- is unusual, but allowed. It subtracts off one profile from another.
def __sub__(self, other):
    """Return this profile minus ``other``.

    Equivalent to Add(self, -1 * other)
    """
    from .sum import Add

    negated = -1.0 * other
    return Add([self, negated])
# Make op* work to adjust the flux of an object
def __mul__(self, other):
    """Scale the flux of the object by the factor ``other``.

    ``obj * flux_ratio`` is equivalent to ``obj.withScaledFlux(flux_ratio)``: the
    result has the same profile shape as the original, with the surface brightness
    at every location multiplied by ``flux_ratio``.

    NOTE(review): upstream GalSim also allows multiplying by an ``SED`` to build a
    ``ChromaticObject``; whether that is supported here depends on
    ``withScaledFlux`` -- confirm before relying on it.
    """
    scaled = self.withScaledFlux(other)
    return scaled
def __rmul__(self, other):
    """Support ``factor * obj``; identical to ``obj * factor`` (see `__mul__`)."""
    forward = self.__mul__
    return forward(other)
# Likewise for op/
def __div__(self, other):
    """Divide the flux by ``other``; same as ``obj * (1/other)`` (see `__mul__`)."""
    reciprocal = 1.0 / other
    return self * reciprocal

__truediv__ = __div__
def __neg__(self):
    """Return the profile with its flux negated (``-1.0 * self``)."""
    sign = -1.0
    return sign * self
def __eq__(self, other):
    """Identity, or same exact class with equal flattened pytree contents."""
    if self is other:
        return True
    if type(other) is not self.__class__:
        return False
    return is_equal_with_arrays(self.tree_flatten(), other.tree_flatten())
[docs]
@implements(_galsim.GSObject.xValue)
def xValue(self, *args, **kwargs):
    # Accept (x, y), a Position, or x=/y= keywords, then defer to _xValue.
    return self._xValue(parse_pos_args(args, kwargs, "x", "y"))
@implements(_galsim.GSObject._xValue)
def _xValue(self, pos):
    # Abstract hook: profiles with a real-space representation override this.
    message = "%s does not implement xValue" % self.__class__.__name__
    raise NotImplementedError(message)
[docs]
@implements(_galsim.GSObject.kValue)
def kValue(self, *args, **kwargs):
    # Accept (kx, ky), a Position, or kx=/ky= keywords, then defer to _kValue.
    return self._kValue(parse_pos_args(args, kwargs, "kx", "ky"))
@implements(_galsim.GSObject._kValue)
def _kValue(self, kpos):
    # Abstract hook: profiles with a Fourier-space representation override this.
    message = "%s does not implement kValue" % self.__class__.__name__
    raise NotImplementedError(message)
[docs]
@implements(_galsim.GSObject.withGSParams)
def withGSParams(self, gsparams=None, **kwargs):
    # Short-circuit when nothing would change.
    if gsparams == self.gsparams:
        return self
    new_gsparams = GSParams.check(gsparams, self.gsparams, **kwargs)
    # Rebuild a clean instance through the pytree round-trip, swapping in the
    # new static gsparams in the auxiliary data.
    leaves, static = self.tree_flatten()
    static["gsparams"] = new_gsparams
    return self.tree_unflatten(static, leaves)
[docs]
@implements(_galsim.GSObject.withFlux)
def withFlux(self, flux):
    # Express the target flux as a ratio to the current flux.
    ratio = flux / self.flux
    return self.withScaledFlux(ratio)
[docs]
@implements(_galsim.GSObject.withScaledFlux)
def withScaledFlux(self, flux_ratio):
    # Flux scaling is implemented as a Transform with only a flux_ratio.
    from jax_galsim.transform import Transform

    scaled = Transform(self, flux_ratio=flux_ratio)
    return scaled
[docs]
@implements(_galsim.GSObject.expand)
def expand(self, scale):
    # Isotropic linear stretch by ``scale`` (diagonal Jacobian), flux untouched.
    from jax_galsim.transform import Transform

    isotropic_jac = [scale, 0.0, 0.0, scale]
    return Transform(self, jac=isotropic_jac)
[docs]
@implements(_galsim.GSObject.dilate)
def dilate(self, scale):
    # Same stretch as expand(scale), but with the flux divided by scale**2 so the
    # total flux is preserved: equivalent to self.expand(scale) * (1./scale**2).
    from jax_galsim.transform import Transform

    isotropic_jac = [scale, 0.0, 0.0, scale]
    return Transform(self, jac=isotropic_jac, flux_ratio=scale**-2)
[docs]
@implements(_galsim.GSObject.magnify)
def magnify(self, mu):
    # A magnification mu is an area factor, i.e. a linear expansion by sqrt(mu).
    linear_scale = jnp.sqrt(mu)
    return self.expand(linear_scale)
[docs]
@implements(_galsim.GSObject.shear)
def shear(self, *args, **kwargs):
    # Accepts exactly one of:
    #   - a single positional ``Shear`` instance, or
    #   - keyword arguments forwarded to the ``Shear`` constructor.
    # Raises TypeError for mixed, missing, too many, or non-Shear arguments.
    from jax_galsim.shear import Shear
    from jax_galsim.transform import Transform

    if len(args) == 1:
        if kwargs:
            raise TypeError(
                "Error, gave both unnamed and named arguments to GSObject.shear!"
            )
        if not isinstance(args[0], Shear):
            raise TypeError(
                "Error, unnamed argument to GSObject.shear is not a Shear!"
            )
        shear = args[0]
    elif len(args) > 1:
        raise TypeError("Error, too many unnamed arguments to GSObject.shear!")
    elif not kwargs:
        raise TypeError("Error, shear argument is required")
    else:
        shear = Shear(**kwargs)
    return Transform(self, shear.getMatrix())
@implements(_galsim.GSObject._shear)
def _shear(self, shear):
    # Fast path: no argument validation; ``shear`` must already be a Shear.
    from jax_galsim.transform import Transform

    distortion = shear.getMatrix()
    return Transform(self, distortion)
[docs]
@implements(_galsim.GSObject.lens)
def lens(self, g1, g2, mu):
    # Weak-lensing transform: reduced shear (g1, g2) combined with a linear
    # scaling of sqrt(mu) for magnification mu.
    from jax_galsim.shear import Shear
    from jax_galsim.transform import Transform

    lensing_shear = Shear(g1=g1, g2=g2)
    jac = lensing_shear.getMatrix() * jnp.sqrt(mu)
    return Transform(self, jac)
@implements(_galsim.GSObject._lens)
def _lens(self, g1, g2, mu):
    # Like ``lens`` but builds the shear directly from the complex value
    # g1 + i*g2 through ``_Shear``, skipping keyword-argument parsing.
    from .shear import _Shear
    from .transform import Transform

    complex_g = g1 + 1j * g2
    jac = _Shear(complex_g).getMatrix() * jnp.sqrt(mu)
    return Transform(self, jac)
[docs]
@implements(_galsim.GSObject.rotate)
def rotate(self, theta):
    # Rotate by the Angle ``theta``; raises TypeError for any other type.
    from jax_galsim.transform import Transform

    from .angle import Angle

    if not isinstance(theta, Angle):
        raise TypeError("Input theta should be an Angle")
    sin_t, cos_t = theta.sincos()
    # Entries of the standard 2x2 rotation matrix [[cos, -sin], [sin, cos]].
    rotation = [cos_t, -sin_t, sin_t, cos_t]
    return Transform(self, jac=rotation)
[docs]
@implements(_galsim.GSObject.shift)
def shift(self, *args, **kwargs):
    # Accept (dx, dy), a Position, or dx=/dy= keywords.
    from jax_galsim.transform import Transform

    return Transform(self, offset=parse_pos_args(args, kwargs, "dx", "dy"))
@implements(_galsim.GSObject._shift)
def _shift(self, dx, dy):
    # Fast path: offset given directly as two numbers, no argument parsing.
    from jax_galsim.transform import Transform

    return Transform(self, offset=(dx, dy))
# Make sure the image is defined with the right size and wcs for drawImage()
def _setup_image(
    self, image, nx, ny, bounds, add_to_image, dtype, center, odd=False
):
    """Return an Image ready to be drawn into, validating the size arguments.

    If ``image`` is given, it is checked against ``nx``/``ny``/``bounds``/``dtype``.
    When its bounds are undefined, it is resized in place to a square of side
    ``getGoodImageSize(1.0)`` (plus one if ``odd``).  Otherwise a new Image is
    created from ``bounds``, or from ``nx``/``ny`` (shifted so ``center`` falls
    at the true center to the nearest pixel), or as a square of side
    ``getGoodImageSize(1.0)`` (plus one if ``odd``) centered on ``ceil(center)``
    when ``center`` is given.

    NOTE(review): ``getGoodImageSize(1.0)`` presumably assumes ``self`` is already
    in image (unit pixel scale) coordinates -- confirm with callers.

    Raises:
        GalSimIncompatibleValuesError: conflicting size arguments, a dtype that
            differs from the given image, or ``add_to_image`` without a usable
            (defined) image.
        GalSimValueError: ``bounds`` given but undefined.
    """
    from jax_galsim.bounds import BoundsI
    from jax_galsim.image import Image

    # If image is given, check validity of nx,ny,bounds:
    if image is not None:
        if bounds is not None:
            raise _galsim.GalSimIncompatibleValuesError(
                "Cannot provide bounds if image is provided",
                bounds=bounds,
                image=image,
            )
        if nx is not None or ny is not None:
            raise _galsim.GalSimIncompatibleValuesError(
                "Cannot provide nx,ny if image is provided",
                nx=nx,
                ny=ny,
                image=image,
            )
        if dtype is not None and image.array.dtype != dtype:
            raise _galsim.GalSimIncompatibleValuesError(
                "Cannot specify dtype != image.array.dtype if image is provided",
                dtype=dtype,
                image=image,
            )

        # Resize the given image if necessary
        if not image.bounds.isDefined():
            # Can't add to image if need to resize
            if add_to_image:
                raise _galsim.GalSimIncompatibleValuesError(
                    "Cannot add_to_image if image bounds are not defined",
                    add_to_image=add_to_image,
                    image=image,
                )
            N = self.getGoodImageSize(1.0)
            if odd:
                N += 1
            # Bounds run from 1 to N on both axes.
            bounds = BoundsI(xmin=1, deltax=N, ymin=1, deltay=N)
            image.resize(bounds)
        # Else use the given image as is

    # Otherwise, make a new image
    else:
        # Can't add to image if none is provided.
        if add_to_image:
            raise _galsim.GalSimIncompatibleValuesError(
                "Cannot add_to_image if image is None",
                add_to_image=add_to_image,
                image=image,
            )
        # Use bounds or nx,ny if provided
        if bounds is not None:
            if nx is not None or ny is not None:
                raise _galsim.GalSimIncompatibleValuesError(
                    "Cannot set both bounds and (nx, ny)",
                    nx=nx,
                    ny=ny,
                    bounds=bounds,
                )
            if not bounds.isDefined():
                raise _galsim.GalSimValueError(
                    "Cannot use undefined bounds", bounds
                )
            image = Image(bounds=bounds, dtype=dtype)
        elif nx is not None or ny is not None:
            if nx is None or ny is None:
                raise _galsim.GalSimIncompatibleValuesError(
                    "Must set either both or neither of nx, ny", nx=nx, ny=ny
                )
            image = Image(nx, ny, dtype=dtype)
            if center is not None:
                # this code has to match the code in _get_new_bounds
                # for the same branch of the if statement block
                # if center, nx, and ny are given.
                image.shift(
                    PositionI(
                        jnp.floor(center.x + 0.5 - image.true_center.x),
                        jnp.floor(center.y + 0.5 - image.true_center.y),
                    )
                )
        else:
            N = self.getGoodImageSize(1.0)
            if odd:
                N += 1
            image = Image(N, N, dtype=dtype)
            if center is not None:
                # _adjust_offset relies on this ceil() placement for the
                # undefined-bounds case.
                image.setCenter(PositionI(jnp.ceil(center.x), jnp.ceil(center.y)))
    return image
def _local_wcs(self, wcs, image, offset, center, use_true_center, new_bounds):
    """Return the local WCS evaluated at the position where the object is drawn.

    A uniform WCS is position independent, so ``wcs.local()`` is returned
    directly.  Otherwise the object position is ``center`` if given, else the
    (true) center of the image bounds (``image.bounds``, or ``new_bounds`` when
    there is no image), plus ``offset``.

    Raises:
        GalSimIncompatibleValuesError: a non-uniform WCS is combined with
            bounds that are not yet defined (auto-sized image).
    """
    if wcs.isUniform():
        return wcs.local()

    bounds = new_bounds if image is None else image.bounds
    if not bounds.isDefined():
        raise _galsim.GalSimIncompatibleValuesError(
            "Cannot provide non-local wcs with automatically sized image",
            wcs=wcs,
            image=image,
            bounds=new_bounds,
        )

    if center is not None:
        anchor = center
    elif use_true_center:
        anchor = bounds.true_center
    else:
        anchor = bounds.center

    # Promote to PositionD (bounds.center may be a PositionI) and apply the offset,
    # which _parse_offset has already turned from None into PositionD(0, 0).
    image_pos = PositionD(anchor.x, anchor.y)
    image_pos += offset
    return wcs.local(image_pos=image_pos)
def _parse_offset(self, offset):
    """Coerce ``offset`` into a PositionD; ``None`` means no offset, i.e. (0, 0).

    Positions are converted via their ``.x``/``.y``.  Anything else is indexed as
    ``offset[0]``, ``offset[1]``, so invalid inputs raise Python's natural error.
    """
    if offset is None:
        return PositionD(0, 0)
    if isinstance(offset, Position):
        x_val, y_val = offset.x, offset.y
    else:
        x_val, y_val = offset[0], offset[1]
    return PositionD(x_val, y_val)
def _parse_center(self, center):
    """Coerce ``center`` into a PositionD, leaving ``None`` as ``None``.

    Mirrors ``_parse_offset``, except that a missing center stays unset.
    Non-Position inputs are indexed as ``center[0]``, ``center[1]``.
    """
    if center is None:
        return None
    if isinstance(center, Position):
        x_val, y_val = center.x, center.y
    else:
        x_val, y_val = center[0], center[1]
    return PositionD(x_val, y_val)
def _get_new_bounds(self, image, nx, ny, bounds, center):
    """Predict the bounds the drawn image will have, if they are known yet.

    Priority: the defined bounds of ``image``; then a 1-based ``nx`` x ``ny`` box
    (shifted toward ``center`` if given); then ``bounds`` if defined; otherwise an
    undefined ``BoundsI()``.
    """
    from jax_galsim.bounds import BoundsI

    if image is not None and image.bounds.isDefined():
        return image.bounds

    if nx is None or ny is None:
        # Without a full nx/ny pair, only an explicit, defined bounds fixes the size.
        if bounds is not None and bounds.isDefined():
            return bounds
        return BoundsI()

    grid = BoundsI(xmin=1, deltax=nx, ymin=1, deltay=ny)
    if center is None:
        return grid
    # Must stay in sync with the nx/ny + center branch of _setup_image.
    return grid.shift(
        PositionI(
            jnp.floor(center.x + 0.5 - grid.true_center.x),
            jnp.floor(center.y + 0.5 - grid.true_center.y),
        )
    )
def _adjust_offset(self, new_bounds, offset, center, use_true_center):
    """Return ``offset`` corrected for an explicit ``center`` or true-center drawing.

    Assumes ``self`` is already expressed in image coordinates.

    - With ``center``: add ``center - new_bounds.center``.  If the bounds are not
      defined yet, the image will be created even-sized around ``ceil(center)``,
      so only the fractional remainder ``center - ceil(center)`` is added.
    - Else with ``use_true_center``: subtract 0.5 along each axis whose size is
      even, because the draw routine puts the profile half a pixel up/right of the
      true center of an even-sized image.
    - Otherwise ``offset`` is returned unchanged.
    """
    if center is not None:
        if new_bounds.isDefined():
            offset += center - new_bounds.center
        else:
            offset += PositionD(
                center.x - jnp.ceil(center.x), center.y - jnp.ceil(center.y)
            )
        return offset

    if not use_true_center:
        return offset

    # numpyShape() is ordered (ny, nx).  The (n + 1) % 2 form is 1 for even n and
    # 0 for odd n, which avoids a Python branch on the size.
    shape = new_bounds.numpyShape()
    x_is_even = (shape[1] + 1) % 2
    y_is_even = (shape[0] + 1) % 2
    return PositionD(offset.x - 0.5 * x_is_even, offset.y - 0.5 * y_is_even)
def _determine_wcs(self, scale, wcs, image, default_wcs=None):
    """Resolve the WCS to draw with from ``scale``, ``wcs`` and ``image``.

    Precedence: an explicit ``wcs`` (mutually exclusive with ``scale``), then
    ``PixelScale(scale)``, then ``image.wcs``.  When ``wcs`` or ``scale`` is
    given explicitly, ``image.wcs`` is cleared (set to None).  If no WCS is found,
    or a local PixelScale has a non-positive scale, fall back to ``default_wcs``
    or, if that is None, ``PixelScale(self.nyquist_scale)``.

    Raises:
        GalSimIncompatibleValuesError: both ``wcs`` and ``scale`` given.
        TypeError: ``wcs`` is not a BaseWCS.
    """
    from jax_galsim.wcs import BaseWCS, PixelScale

    # Determine the correct wcs given the input scale, wcs and image.
    if wcs is not None:
        if scale is not None:
            raise _galsim.GalSimIncompatibleValuesError(
                "Cannot provide both wcs and scale", wcs=wcs, scale=scale
            )
        if not isinstance(wcs, BaseWCS):
            raise TypeError("wcs must be a BaseWCS instance")
        if image is not None:
            image.wcs = None
    elif scale is not None:
        wcs = PixelScale(scale)
        if image is not None:
            image.wcs = None
    elif image is not None and image.wcs is not None:
        wcs = image.wcs

    # If the input scale <= 0, or wcs is still None at this point, then use the Nyquist scale:
    if wcs is None:
        if default_wcs is None:
            wcs = PixelScale(self.nyquist_scale)
        else:
            wcs = default_wcs
    if wcs.isPixelScale() and wcs.isLocal():
        # The scale may be a traced value, so "scale <= 0" is decided with
        # lax.cond instead of a Python ``if``.
        # NOTE(review): lax.cond requires both branches to return the same pytree
        # structure; if ``default_wcs`` is not a PixelScale this presumably fails
        # to trace -- confirm.
        wcs = jax.lax.cond(
            wcs.scale <= 0,
            lambda wcs, nqs: (
                PixelScale(jnp.float_(nqs)) if default_wcs is None else default_wcs
            ),
            lambda wcs, nqs: PixelScale(jnp.float_(wcs.scale)),
            wcs,
            self.nyquist_scale,
        )
    return wcs
[docs]
@implements(
    _galsim.GSObject.drawImage,
    lax_description="""\
The JAX-GalSim version of ``drawImage``
- does not do extensive (any?) checking of the input settings.
- uses a default of ``n_photons=None`` instead of ``n_photons=0``
to indicate that the number of photons should be determined
from the flux and gain
- requires that the ``maxN`` option be a constant since PhotonArrays are allocated
with ``maxN`` photons when this option is used and arrays in JAX must have static sizes.
""",
)
def drawImage(
    self,
    image=None,
    nx=None,
    ny=None,
    bounds=None,
    scale=None,
    wcs=None,
    dtype=None,
    method="auto",
    area=1.0,
    exptime=1.0,
    gain=1.0,
    add_to_image=False,
    center=None,
    use_true_center=True,
    offset=None,
    n_photons=None,
    rng=None,
    max_extra_noise=0.0,
    poisson_flux=None,
    sensor=None,
    photon_ops=(),
    n_subsample=3,
    maxN=None,
    save_photons=False,
    bandpass=None,
    setup_only=False,
    surface_ops=None,
):
    # Pipeline: validate option combinations -> resolve wcs / offset / center ->
    # convert the profile to image coordinates with flux scaled by flux_scale ->
    # convolve by the pixel for "auto"/"fft"/"real_space" -> set up the image ->
    # draw by photon shooting (drawPhot), real space (drawReal, when the profile
    # is_analytic_x) or FFT (drawFFT) -> restore the image center/wcs.
    # NOTE(review): ``n_subsample`` and ``bandpass`` are not used in this body.
    from jax_galsim.box import Pixel
    from jax_galsim.convolve import Convolution, Convolve
    from jax_galsim.image import Image
    from jax_galsim.wcs import PixelScale

    # ``surface_ops`` is a deprecated alias for ``photon_ops``.
    if surface_ops is not None:
        from .deprecated import depr

        depr("surface_ops", 2.3, "photon_ops")
        photon_ops = surface_ops

    if image is not None and not isinstance(image, Image):
        raise TypeError("image is not an Image instance", image)

    if method == "phot" and save_photons and maxN is not None:
        raise GalSimIncompatibleValuesError(
            "Setting maxN is incompatible with save_photons=True"
        )

    if method not in ("auto", "fft", "real_space", "phot", "no_pixel", "sb"):
        raise GalSimValueError(
            "Invalid method name",
            method,
            ("auto", "fft", "real_space", "phot", "no_pixel", "sb"),
        )

    # Check that the user isn't convolving by a Pixel already. This is almost always an error.
    if method == "auto" and isinstance(self, Convolution):
        if any([isinstance(obj, Pixel) for obj in self.obj_list]):
            galsim_warn(
                "You called drawImage with ``method='auto'`` "
                "for an object that includes convolution by a Pixel. "
                "This is probably an error. Normally, you should let GalSim "
                "handle the Pixel convolution for you. If you want to handle the Pixel "
                "convolution yourself, you can use method=no_pixel. Or if you really meant "
                "for your profile to include the Pixel and also have GalSim convolve by "
                "an _additional_ Pixel, you can suppress this warning by using method=fft."
            )

    # Photon-shooting-only options are rejected for the other methods.
    if method != "phot":
        if n_photons is not None:
            raise GalSimIncompatibleValuesError(
                "n_photons is only relevant for method='phot'",
                method=method,
                sensor=sensor,
                n_photons=n_photons,
            )
        if poisson_flux is not None:
            raise GalSimIncompatibleValuesError(
                "poisson_flux is only relevant for method='phot'",
                method=method,
                sensor=sensor,
                poisson_flux=poisson_flux,
            )

    # Options that need photons (phot, or a sensor) are rejected otherwise.
    if method != "phot" and sensor is None:
        if rng is not None:
            raise GalSimIncompatibleValuesError(
                "rng is only relevant for method='phot' or when using a sensor",
                method=method,
                sensor=sensor,
                rng=rng,
            )
        if maxN is not None:
            raise GalSimIncompatibleValuesError(
                "maxN is only relevant for method='phot' or when using a sensor",
                method=method,
                sensor=sensor,
                maxN=maxN,
            )
        if save_photons:
            raise GalSimIncompatibleValuesError(
                "save_photons is only valid for method='phot' or when using a sensor",
                method=method,
                sensor=sensor,
                save_photons=save_photons,
            )

    # Figure out what wcs we are going to use.
    wcs = self._determine_wcs(scale, wcs, image)

    # Make sure offset and center are PositionD, converting from other formats (tuple, array,..)
    # Note: If None, offset is converted to PositionD(0,0), but center will remain None.
    offset = self._parse_offset(offset)
    center = self._parse_center(center)

    # Determine the bounds of the new image for use below (if it can be known yet)
    new_bounds = self._get_new_bounds(image, nx, ny, bounds, center)

    # Get the local WCS, accounting for the offset correctly.
    local_wcs = self._local_wcs(
        wcs, image, offset, center, use_true_center, new_bounds
    )

    # Account for area and exptime.
    flux_scale = area * exptime
    # For surface brightness normalization, also scale by the pixel area.
    if method == "sb":
        flux_scale /= local_wcs.pixelArea()
    # Only do the gain here if not photon shooting, since need the number of photons to
    # reflect that actual photons, not ADU.
    if method != "phot" and sensor is None:
        flux_scale /= gain

    # Determine the offset, and possibly fix the centering for even-sized images
    offset = self._adjust_offset(new_bounds, offset, center, use_true_center)

    # Convert the profile in world coordinates to the profile in image coordinates:
    prof = local_wcs.profileToImage(self, flux_ratio=flux_scale, offset=offset)
    local_wcs = local_wcs.shiftOrigin(offset)

    # If necessary, convolve by the pixel
    if method in ("auto", "fft", "real_space"):
        # real_space=None lets Convolve choose; fft forces False, real_space True.
        if method == "auto":
            real_space = None
        elif method == "fft":
            real_space = False
        else:
            real_space = True
        prof = Convolve(
            prof,
            Pixel(scale=1.0, gsparams=self.gsparams),
            real_space=real_space,
            gsparams=self.gsparams,
        )

    # Make sure image is setup correctly
    image = prof._setup_image(image, nx, ny, bounds, add_to_image, dtype, center)
    image_in = (
        image # For compatibility with normal galsim, we update image_in below.
    )
    image.wcs = wcs

    if setup_only:
        image.added_flux = 0.0
        return image

    # Temporarily move the origin to (0, 0) with a unit pixel scale for drawing;
    # the original center and wcs are restored below.  Unlike GalSim this is not
    # a view: ``image`` itself is modified and then restored.
    original_center = image.center
    wcs = image.wcs
    image.setCenter(0, 0)
    image.wcs = PixelScale(1.0)

    if method == "phot":
        added_photons, photons = prof.drawPhot(
            image,
            gain,
            add_to_image,
            n_photons,
            rng,
            max_extra_noise,
            poisson_flux,
            sensor,
            photon_ops,
            maxN,
            original_center,
            local_wcs,
        )
    else:
        if sensor is not None or photon_ops:
            raise NotImplementedError(
                "Sensor/photon_ops not yet implemented in drawImage for method != 'phot'."
            )

        if prof.is_analytic_x:
            added_photons = prof.drawReal(image, add_to_image)
        else:
            added_photons = prof.drawFFT(image, add_to_image)

    # Report the drawn flux with the flux_scale factor divided back out.
    image.added_flux = added_photons / flux_scale
    # Restore the original center and wcs
    image.shift(original_center)
    image.wcs = wcs

    # ``photons`` is only bound on the "phot" path; the checks above reject
    # save_photons for other methods unless a sensor is given, in which case the
    # NotImplementedError above fires first.
    if save_photons:
        image.photons = photons

    # Update image_in to satisfy GalSim API
    image_in._array = image._array
    image_in.added_flux = image.added_flux
    image_in._bounds = image._bounds
    image_in.wcs = image.wcs
    image_in._dtype = image._dtype
    if save_photons:
        image_in.photons = photons

    return image
[docs]
@implements(_galsim.GSObject.drawReal)
def drawReal(self, image, add_to_image=False):
    # Real-space drawing needs a plain PixelScale WCS on the target image.
    if image.wcs is None or not image.wcs.isPixelScale():
        raise _galsim.GalSimValueError(
            "drawReal requires an image with a PixelScale wcs", image
        )

    drawn = self._drawReal(image).subImage(image.bounds)

    float_into_int = jnp.issubdtype(
        drawn.array.dtype, jnp.floating
    ) and jnp.issubdtype(image.array.dtype, jnp.integer)
    if float_into_int:
        # Float-to-int conversion may round differently across platforms, so
        # round explicitly before storing into an integer image.
        drawn.array = jnp.around(drawn.array)

    if not add_to_image:
        image._array = drawn._array
    else:
        image._array = image._array.at[...].add(drawn._array)
    # Total flux added by this draw.
    return drawn.array.sum(dtype=float)
@implements(_galsim.GSObject._drawReal)
def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
    # Abstract hook: profiles that can be evaluated in real space override this.
    message = "%s does not implement drawReal" % self.__class__.__name__
    raise NotImplementedError(message)
[docs]
@implements(_galsim.GSObject.getGoodImageSize)
def getGoodImageSize(self, pixel_scale):
    # The image extent implied by stepk is 2*pi/stepk, expressed here in pixels.
    exact_size = 2.0 * jnp.pi / (pixel_scale * self.stepk)
    # Trim a tiny relative slop before ceil() so roundoff can't add a pixel.
    whole_size = jnp.ceil(exact_size * (1.0 - 1.0e-12)).astype(int)
    # Bump odd sizes up to the next even integer.
    return 2 * ((whole_size + 1) // 2)
[docs]
@implements(_galsim.GSObject.drawFFT_makeKImage)
def drawFFT_makeKImage(self, image):
    # Allocate the (empty) k-space image used by drawFFT and choose the real-space
    # wrap size N.  Returns ``(kimage, N)``; kimage spans x in [0, Nk/2] and
    # y in [-Nk/2, Nk/2] with pixel scale dk = 2*pi / (N * image.scale).
    # Raises GalSimFFTSizeError if the required Nk exceeds maximum_fft_size.
    from jax_galsim.bounds import BoundsI
    from jax_galsim.image import ImageCD, ImageCF

    # Before any computations, let's check if we actually have a choice based on the gsparams.
    if self.gsparams.maximum_fft_size == self.gsparams.minimum_fft_size:
        # Fixed FFT size: N and Nk are static, which keeps the array shapes
        # JIT-friendly; no data-dependent sizing is done on this path.
        with jax.ensure_compile_time_eval():
            Nk = self.gsparams.maximum_fft_size
            N = Nk
            dk = 2.0 * np.pi / (N * image.scale)
    else:
        # Start with what this profile thinks a good size would be given the image's pixel scale.
        N = self.getGoodImageSize(image.scale)

        # We must make something big enough to cover the target image size:
        image_N = jnp.max(
            jnp.array(
                [
                    jnp.max(
                        jnp.abs(
                            jnp.array(
                                [image.xmin, image.xmax, image.ymin, image.ymax]
                            )
                        )
                    )
                    * 2,
                    jnp.max(jnp.array(image.bounds.numpyShape())),
                ]
            )
        )
        N = jnp.max(jnp.array([N, image_N]))

        # Round up to a good size for making FFTs:
        N = image.good_fft_size(N)

        # Make sure we hit the minimum size specified in the gsparams.
        N = max(N, self.gsparams.minimum_fft_size)

        dk = 2.0 * jnp.pi / (N * image.scale)

        maxk = self.maxk
        if N * dk / 2 > maxk:
            Nk = N
        else:
            # There will be aliasing. Make a larger image and then wrap it.
            Nk = int(jnp.ceil(maxk / dk)) * 2

        if Nk > self.gsparams.maximum_fft_size:
            raise _galsim.GalSimFFTSizeError(
                "drawFFT requires an FFT that is too large.", Nk
            )

    bounds = BoundsI(
        xmin=0, deltax=Nk // 2 + 1, ymin=-Nk // 2, deltay=2 * (Nk // 2) + 1
    )
    # Double-precision complex for double / 32-bit integer images, single otherwise.
    if image.dtype in (np.complex128, np.float64, np.int32, np.uint32):
        kimage = ImageCD(bounds=bounds, scale=dk)
    else:
        kimage = ImageCF(bounds=bounds, scale=dk)
    return kimage, N
[docs]
@implements(_galsim.GSObject.drawFFT_finish)
def drawFFT_finish(self, image, kimage, wrap_size, add_to_image):
    # Turn a drawn k-space image into real space and write/add it into ``image``.
    # Returns the total flux added (sum over the part that overlaps image.bounds).
    from jax_galsim.bounds import BoundsI
    from jax_galsim.image import Image

    # Wrap the full image to the size we want for the FT.
    # Even if N == Nk, this is useful to make this portion properly Hermitian in the
    # N/2 column and N/2 row.
    bwrap = BoundsI(
        xmin=0,
        deltax=wrap_size // 2 + 1,
        ymin=-wrap_size // 2,
        deltay=2 * (wrap_size // 2),
    )
    kimage_wrap = kimage._wrap(bwrap, True, False, wrap_size)

    # Perform the fourier transform.
    breal = BoundsI(
        xmin=-wrap_size // 2,
        deltax=2 * (wrap_size // 2),
        ymin=-wrap_size // 2,
        deltay=2 * (wrap_size // 2),
    )
    # The k-image rows run from -N/2 upward; ifftshift moves ky=0 to row 0 as
    # irfft2 expects, and fftshift re-centers the real-space result.
    kimg_shift = jnp.fft.ifftshift(kimage_wrap.array, axes=(-2,))
    real_image_arr = jnp.fft.fftshift(
        jnp.fft.irfft2(kimg_shift, breal.numpyShape())
    )

    # irfft2 always yields floats.  Round *before* the array is stored with the
    # target dtype: an integer dtype cast would otherwise truncate toward zero,
    # and jax-galsim's rounding of float-to-int is platform dependent anyway.
    if jnp.issubdtype(image.array.dtype, jnp.integer):
        real_image_arr = jnp.around(real_image_arr)

    real_image = Image(
        bounds=breal, array=real_image_arr, dtype=image.dtype, wcs=image.wcs
    )

    # Add (a portion of) this to the original image.
    temp = real_image.subImage(image.bounds)
    if add_to_image:
        image._array = image._array.at[...].add(temp._array)
    else:
        image._array = temp._array

    return temp.array.sum(dtype=float)
[docs]
@implements(_galsim.GSObject.drawFFT)
def drawFFT(self, image, add_to_image=False):
    # FFT drawing also requires a PixelScale WCS on the target image.
    if image.wcs is None or not image.wcs.isPixelScale():
        raise _galsim.GalSimValueError(
            "drawFFT requires an image with a PixelScale wcs", image
        )
    # Three stages: allocate k-image -> draw profile in k-space -> inverse FFT.
    empty_kimage, wrap_size = self.drawFFT_makeKImage(image)
    filled_kimage = self._drawKImage(empty_kimage)
    return self.drawFFT_finish(image, filled_kimage, wrap_size, add_to_image)
[docs]
@implements(_galsim.GSObject.drawKImage)
def drawKImage(
    self,
    image=None,
    nx=None,
    ny=None,
    bounds=None,
    scale=None,
    add_to_image=False,
    recenter=True,
    bandpass=None,
    setup_only=False,
):
    # Draw the Fourier-space profile into a complex Image with pixel scale dk.
    # dk comes from ``scale`` (or image.scale), falling back to self.stepk when it
    # is missing or <= 0.  A new or undefined image is sized by _setup_image
    # (odd-sized) for a real-space pixel dx.  The input image is updated in place
    # for GalSim compatibility, and the image is also returned.
    # NOTE(review): ``bandpass`` is not used in this body.
    from jax_galsim.image import Image
    from jax_galsim.wcs import PixelScale

    # Make sure provided image is complex
    if image is not None:
        if not isinstance(image, Image):
            raise TypeError("Provided image must be galsim.Image", image)

        if not image.iscomplex:
            raise _galsim.GalSimValueError("Provided image must be complex", image)

    # Possibly get the scale from image.
    if image is not None and scale is None:
        # Grab the scale to use from the image.
        # This will raise a TypeError if image.wcs is not a PixelScale
        scale = image.scale

    # The input scale (via scale or image.scale) is really a dk value, so call it that for
    # clarity here, since we also need the real-space pixel scale, which we will call dx.
    if scale is None or scale <= 0:
        dk = self.stepk
    else:
        dk = scale
    if image is not None and image.bounds.isDefined():
        dx = np.pi / (max(image.array.shape) // 2 * dk)
    elif scale is None or scale <= 0:
        dx = self.nyquist_scale
    else:
        # Then dk = scale, which implies that we need to have dx smaller than nyquist_scale
        # by a factor of (dk/stepk)
        dx = self.nyquist_scale * dk / self.stepk

    # If the profile needs to be constructed from scratch, the _setup_image function will
    # do that, but only if the profile is in image coordinates for the real space image.
    # So make that profile.
    if image is None or not image.bounds.isDefined():
        real_prof = PixelScale(dx).profileToImage(self)
        dtype = np.complex128 if image is None else image.dtype
        image = real_prof._setup_image(
            image, nx, ny, bounds, add_to_image, dtype, center=None, odd=True
        )
    else:
        # Do some checks that setup_image would have done for us
        if bounds is not None:
            raise _galsim.GalSimIncompatibleValuesError(
                "Cannot provide bounds if image is provided",
                bounds=bounds,
                image=image,
            )
        if nx is not None or ny is not None:
            raise _galsim.GalSimIncompatibleValuesError(
                "Cannot provide nx,ny if image is provided",
                nx=nx,
                ny=ny,
                image=image,
            )

    # Can't both recenter a provided image and add to it.
    if recenter and image.center != PositionI(0, 0) and add_to_image:
        raise _galsim.GalSimIncompatibleValuesError(
            "Cannot use add_to_image=True unless image is centered at (0,0) or recenter=False",
            recenter=recenter,
            image=image,
            add_to_image=add_to_image,
        )

    # Set the center to 0,0 if appropriate
    # (shifting by -center is a no-op when the image is already centered).
    if recenter:
        image._shift(-image.center)

    # Set the wcs of the images to use the dk scale size
    image.scale = dk

    if setup_only:
        return image

    # For GalSim compatibility, we will attempt to update the input image
    image_in = image
    # Draw into a fresh image of the same geometry, then copy or add it.
    im2 = Image(bounds=image.bounds, dtype=image.dtype, scale=image.scale)
    im2 = self._drawKImage(im2)
    if not add_to_image:
        image._array = im2._array
    else:
        image._array = image._array.at[...].add(im2._array)
    image_in._array = image._array
    image_in._bounds = image._bounds
    image_in.wcs = image.wcs
    image_in._dtype = image._dtype
    return image
@implements(_galsim.GSObject._drawKImage)
def _drawKImage(
    self, image, jac=None
):  # pragma: no cover (all our classes override this)
    # Abstract hook for k-space drawing.
    message = "%s does not implement drawKImage" % self.__class__.__name__
    raise NotImplementedError(message)
@implements(_galsim.GSObject._calculate_nphotons)
def _calculate_nphotons(self, n_photons, poisson_flux, max_extra_noise, rng):
    # Thin wrapper around core.draw.calculate_n_photons.  Returns (n_photons, g),
    # where g is the per-photon flux factor computed by the helper.
    result_n, g, advanced_rng = calculate_n_photons(
        self.flux,
        self._flux_per_photon,
        self.max_sb,
        n_photons=n_photons,
        rng=rng,
        max_extra_noise=max_extra_noise,
        poisson_flux=poisson_flux,
    )
    # The helper hands back a deviate whose state is copied into the caller's
    # rng, so the caller observes the advanced random state.
    if rng is not None:
        rng._state = advanced_rng._state
    return result_n, g
[docs]
@implements(
    _galsim.GSObject.makePhot,
    lax_description="""\
The JAX-GalSim version of ``makePhot``
- does little to no error checking on the inputs
- uses a default of ``n_photons=None`` instead of ``n_photons=0``
to indicate that the number of photons should be determined
from the flux and gain
""",
)
def makePhot(
    self,
    n_photons=None,
    rng=None,
    max_extra_noise=0.0,
    poisson_flux=None,
    photon_ops=(),
    local_wcs=None,
    surface_ops=None,
):
    # Shoot photons for this profile, rescale their flux by g (from
    # _calculate_nphotons), then apply each photon op in order.  Returns the
    # PhotonArray.  Any GalSimError/NotImplementedError raised while shooting is
    # re-raised as GalSimNotImplementedError, chained to the original error.
    if surface_ops is not None:
        from .deprecated import depr

        depr("surface_ops", 2.3, "photon_ops")
        photon_ops = surface_ops

    if poisson_flux is None:
        # If n_photons is given, poisson_flux = False
        poisson_flux = n_photons is None

    if n_photons is not None:
        # n_photons is the length of an array so it is a python int and
        # and thus a constant wrt to JIT
        Ntot = int(n_photons + 0.5)
        _, g = self._calculate_nphotons(
            n_photons, poisson_flux, max_extra_noise, rng
        )
    else:
        # here Ntot can be a traced value
        # one thus must use the fixed_photon_array_size context manager
        # to ensure that the size of the photon array is fixed if using JIT
        Ntot, g = self._calculate_nphotons(0.0, poisson_flux, max_extra_noise, rng)

    try:
        photons = self.shoot(Ntot, rng)
    except (GalSimError, NotImplementedError) as e:
        raise GalSimNotImplementedError(
            "Unable to draw this GSObject with photon shooting. Perhaps it "
            "is a Deconvolve or is a compound including one or more "
            "Deconvolve objects.\nOriginal error: %r" % (e)
        ) from e

    # Outside of vmap, lax.cond executes only the selected branch, so the
    # common g == 1 case skips the rescaling work.  (Under vmap both branches
    # are evaluated and the result is selected.)
    photons = jax.lax.cond(
        g == 1.0,
        lambda photons, g: photons,
        lambda photons, g: photons.scaleFlux(g),
        photons,
        g,
    )

    for op in photon_ops:
        op.applyTo(photons, local_wcs, rng)

    return photons
[docs]
@implements(
    _galsim.GSObject.drawPhot,
    lax_description="""\
The JAX-GalSim version of ``drawPhot``
- does little to no error checking on the inputs
- uses a default of ``n_photons=None`` instead of ``n_photons=0``
to indicate that the number of photons should be determined
from the flux and gain
- requires that the ``maxN`` option must be a constant
""",
)
def drawPhot(
    self,
    image,
    gain=1.0,
    add_to_image=False,
    n_photons=None,
    rng=None,
    max_extra_noise=0.0,
    poisson_flux=None,
    sensor=None,
    photon_ops=(),
    maxN=None,
    orig_center=PositionI(0, 0),
    local_wcs=None,
    surface_ops=None,
):
    # Draw this profile onto ``image`` by photon shooting.
    #
    # The image is zeroed first unless ``add_to_image`` is True. Photons are
    # shot either all at once or, when ``maxN`` or the fixed photon-array size
    # is set, in batches of at most that many inside a jitted while loop.
    # Updated RNG state (when ``rng`` is a BaseDeviate) and photon-op state
    # (when ``photon_ops`` is a list) are written back to the caller's objects.
    #
    # Returns ``(added_flux, photons)`` where ``photons`` is the last batch shot.
    #
    # Raises GalSimValueError if ``image`` lacks a PixelScale wcs or the
    # resolved ``maxN`` is not positive, TypeError if ``sensor`` is not a
    # Sensor, and GalSimNotImplementedError for non-default (stateful) sensors.
    if surface_ops is not None:
        from .deprecated import depr

        depr("surface_ops", 2.3, "photon_ops")
        photon_ops = surface_ops

    # If n_photons is given and poisson_flux is None, poisson_flux = False
    if poisson_flux is None:
        poisson_flux = n_photons is None

    # Make sure the image is set up to have unit pixel scale and centered at 0,0.
    if image.wcs is None or not image.wcs._isPixelScale:
        raise GalSimValueError(
            "drawPhot requires an image with a PixelScale wcs", image
        )

    if sensor is None:
        sensor = Sensor()
    elif not isinstance(sensor, Sensor):
        raise TypeError("The sensor provided is not a Sensor instance")

    # TODO: how to update the sensor?
    # https://github.com/GalSim-developers/JAX-GalSim/issues/85
    # Sensor state cannot be propagated back yet, so reject non-default sensors
    # before doing any work; this leaves ``image`` and ``rng`` untouched.
    if sensor.__class__ is not Sensor:
        raise GalSimNotImplementedError(
            "Non-default sensors that carry state are not yet supported in jax-galsim."
        )

    # Both maxN and _JAX_GALSIM_PHOTON_ARRAY_SIZE (set by the
    # fixed_photon_array_size context manager) can fix the sizes of the
    # photon arrays for use with JIT.
    if maxN is not None and pa._JAX_GALSIM_PHOTON_ARRAY_SIZE is not None:
        # If both are set, use the smaller of the two.
        maxN = min(maxN, pa._JAX_GALSIM_PHOTON_ARRAY_SIZE)
    else:
        # Otherwise use whichever one is set (None if neither is).
        maxN = pa._JAX_GALSIM_PHOTON_ARRAY_SIZE or maxN
    if maxN is not None and maxN <= 0:
        # A non-positive batch size never reduces the number of photons left
        # in the batched while loop, so the loop would never terminate.
        raise GalSimValueError("maxN must be a positive integer", maxN)

    if n_photons is not None:
        # n_photons is the length of an array so it is a python int and
        # thus a constant wrt JIT.
        Ntot = int(n_photons + 0.5)
        _, g = self._calculate_nphotons(
            n_photons, poisson_flux, max_extra_noise, rng
        )
    else:
        # Here Ntot can be a traced value; one must use the
        # fixed_photon_array_size context manager or the maxN option to keep
        # the photon array size fixed when using JIT.
        Ntot, g = self._calculate_nphotons(0.0, poisson_flux, max_extra_noise, rng)

    # This cond saves computation for the common case of gain == 1.0.
    g = jax.lax.cond(
        gain != 1.0,
        lambda g, gain: g / gain,
        lambda g, gain: g,
        g,
        gain,
    )

    if not add_to_image:
        image.setZero()

    if maxN is None:
        # Neither maxN nor _JAX_GALSIM_PHOTON_ARRAY_SIZE is set:
        # draw all Ntot photons at once.
        _dfret = _draw_phot_while_loop_shoot(
            maxN=Ntot,
            thisN=Ntot,
            Ntot=Ntot,
            obj=self,
            rng=rng,
            g=g,
            image=image,
            photon_ops=photon_ops,
            sensor=sensor,
            orig_center=orig_center,
            local_wcs=local_wcs,
            resume=False,
            added_flux=0.0,
        )
    else:
        # A fixed batch size is set: draw maxN photons at a time in a
        # while loop until Ntot photons have been drawn.
        _dfret = _draw_phot_while_loop(
            photons=PhotonArray(maxN),
            rng=rng,
            obj=self,
            image=image,
            g=g,
            Ntot=Ntot,
            maxN=maxN,
            photon_ops=photon_ops,
            local_wcs=local_wcs,
            sensor=sensor,
            orig_center=orig_center,
        )

    # Propagate the RNG state back to a caller-supplied BaseDeviate. ``rng``
    # may also be None or a plain seed, which has no ``_state`` to update.
    if isinstance(rng, BaseDeviate):
        rng._state = _dfret.rng._state

    # Copy updated photon-op state back into the caller's container. Only a
    # list can be updated in place; tuples (including the default ``()``)
    # are immutable and previously raised TypeError here.
    if isinstance(photon_ops, list):
        photon_ops[:] = list(_dfret.photon_ops)

    image._array = _dfret.image._array
    return _dfret.added_flux, _dfret.photons
[docs]
@implements(_galsim.GSObject.shoot)
def shoot(self, n_photons, rng=None):
    # Shoot ``n_photons`` photons from this profile into a new PhotonArray.
    #
    # ``_shoot`` fills the array in place using a BaseDeviate built from
    # ``rng``. When ``rng`` is itself a BaseDeviate, its state is advanced so
    # the caller sees the consumed randomness. A plain seed (presumably
    # accepted by BaseDeviate, as in GalSim) has no ``_state`` to write back.
    # Previously this assignment raised AttributeError for such seeds.
    photons = pa.PhotonArray(n_photons)
    if photons.x.shape[0] > 0:
        _rng = BaseDeviate(rng)
        self._shoot(photons, _rng)
        if isinstance(rng, BaseDeviate):
            rng._state = _rng._state
    return photons
@implements(_galsim.GSObject._shoot)
def _shoot(self, photons, rng):
    # Base-class placeholder. Profiles that support photon shooting override
    # this to fill ``photons`` in place; ``shoot`` ignores any return value.
    cls_name = self.__class__.__name__
    raise NotImplementedError("%s does not implement shoot" % cls_name)
[docs]
@implements(_galsim.GSObject.applyTo)
def applyTo(self, photon_array, local_wcs=None, rng=None):
    # Use this profile as a photon operator: shoot a fresh PhotonArray of the
    # same length from the profile and convolve it into ``photon_array``.
    # The profile is first mapped through ``local_wcs.toImage`` when a wcs
    # is given. Wavelengths, pupil coordinates and times already allocated
    # on ``photon_array`` are carried over to the shot photons.
    # Note: dxdz and dydz are not carried over (GalSim does not do this here
    # either; reason unknown).

    def _pick(is_allocated, existing, fresh):
        # Select ``existing`` when the input array has allocated it,
        # otherwise keep the freshly created default ``fresh``.
        return jax.lax.cond(
            is_allocated,
            lambda a, b: a,
            lambda a, b: b,
            existing,
            fresh,
        )

    shot = pa.PhotonArray(len(photon_array))
    shot._wave = _pick(
        photon_array.hasAllocatedWavelengths(),
        photon_array._wave,
        shot._wave,
    )
    shot._pupil_u, shot._pupil_v = _pick(
        photon_array.hasAllocatedPupil(),
        (photon_array._pupil_u, photon_array._pupil_v),
        (shot._pupil_u, shot._pupil_v),
    )
    shot._time = _pick(
        photon_array.hasAllocatedTimes(),
        photon_array._time,
        shot._time,
    )

    profile = self if local_wcs is None else local_wcs.toImage(self)
    profile._shoot(shot, rng)
    photon_array.convolve(shot, rng)
[docs]
def tree_flatten(self):
    """Flatten this GSObject for JAX's pytree machinery.

    Returns a ``(children, aux_data)`` pair. ``children`` is a one-element
    tuple holding the parameter dictionary, which JAX traces. ``aux_data``
    is a dict holding the static ``gsparams``."""
    traced_leaves = (self.params,)
    static_data = {"gsparams": self.gsparams}
    return traced_leaves, static_data
[docs]
@classmethod
def tree_unflatten(cls, aux_data, children):
    """Rebuild an instance from the output of ``tree_flatten``.

    ``children[0]`` is the parameter dictionary. ``aux_data`` supplies the
    remaining static keyword arguments (``gsparams``). Both are passed to
    ``cls`` as keyword arguments."""
    params = children[0]
    return cls(**params, **aux_data)
_DrawPhotReturnTuple = namedtuple(
"_DrawPhotReturnTuple",
[
"photons",
"rng",
"added_flux",
"image",
"photon_ops",
"sensor",
"resume",
],
)
def _draw_phot_while_loop_shoot(
    *,
    maxN,
    thisN,
    Ntot,
    obj,
    rng,
    g,
    image,
    photon_ops,
    sensor,
    orig_center,
    local_wcs,
    resume,
    added_flux,
    Nleft=0,
    photons=None,
):
    """Shoot one batch of photons from ``obj`` and accumulate them into ``image``.

    ``maxN`` photons are shot, so the photon array always has the same size.
    Only ``thisN`` of them are marked to be kept via ``_num_keep``.

    Parameters
    ----------
    maxN : int
        Number of photons shot in this batch (the photon-array size).
    thisN : int or array
        Number of the shot photons to keep.
    Ntot : int or array
        Total number of photons drawn over all batches. Unless ``g == 1`` and
        ``thisN == Ntot``, the batch flux is scaled by ``g * thisN / Ntot``.
    obj : GSObject
        Profile to shoot.
    rng :
        Passed to ``obj.shoot`` and to every photon op.
    g : float or array
        Flux scale factor. In ``drawPhot`` this already includes the
        ``1 / gain`` factor.
    image : Image
        Image the photons are accumulated onto.
    photon_ops : sequence
        Applied in order via ``op.applyTo(photons, local_wcs, rng)``.
    sensor : Sensor
        Accumulates the photons onto the image.
    orig_center :
        Passed through to ``sensor.accumulate``.
    local_wcs :
        Passed through to each photon op.
    resume : bool or array
        Passed to ``sensor.accumulate`` for float32/float64 images and set to
        True afterwards. It is left unchanged for other image dtypes.
    added_flux : float or array
        Running total of accumulated flux; this batch's flux is added to it.
    Nleft, photons :
        Ignored (``photons`` is overwritten immediately). They are accepted
        so the while-loop carry dict can be passed in as ``**kwargs``.

    Returns
    -------
    _DrawPhotReturnTuple
        The photons of this batch and the updated rng, added_flux, image,
        photon_ops, sensor and resume.
    """
    try:
        photons = obj.shoot(maxN, rng)
    except (GalSimError, NotImplementedError) as e:
        raise GalSimNotImplementedError(
            "Unable to draw this GSObject with photon shooting. Perhaps it "
            "is a Deconvolve or is a compound including one or more "
            "Deconvolve objects.\nOriginal error: %r" % (e)
        )
    # we drew maxN, but only keep thisN of them
    photons._num_keep = thisN
    photons = jax.lax.cond(
        # weird way to say gain == 1 and thisN == Ntot
        jnp.abs(g - 1.0) + jnp.abs(thisN - Ntot) == 0,
        lambda photons, g, thisN, Ntot: photons,
        # the factor here is thisN / Ntot since we drew thisN photons, but use a total of Ntot photons
        lambda photons, g, thisN, Ntot: photons.scaleFlux(g * thisN / Ntot),
        photons,
        g,
        thisN,
        Ntot,
    )
    photons = jax.lax.cond(
        image.scale != 1.0,
        lambda photons, scale: photons.scaleXY(
            1.0 / scale
        ),  # Convert x,y to image coords if necessary
        lambda photons, scale: photons,
        photons,
        image.scale,
    )
    for op in photon_ops:
        op.applyTo(photons, local_wcs, rng)
    if image.dtype in (jnp.float32, jnp.float64):
        # Float images: accumulate directly into ``image``.
        added_flux += sensor.accumulate(photons, image, orig_center, resume=resume)
        resume = True  # Resume from this point if there are any further iterations.
    else:
        # Need a temporary
        # Other dtypes: accumulate into a temporary ImageD, then add it to
        # ``image``. The function-level import is presumably there to avoid
        # a circular import.
        from jax_galsim.image import ImageD

        im1 = ImageD(bounds=image.bounds)
        added_flux += sensor.accumulate(photons, im1, orig_center)
        image += im1
    return _DrawPhotReturnTuple(
        photons, rng, added_flux, image, photon_ops, sensor, resume
    )
@partial(jax.jit, static_argnames=("maxN",))
def _draw_phot_while_loop(
    *,
    photons,
    rng,
    obj,
    image,
    g,
    Ntot,
    maxN,
    photon_ops,
    local_wcs,
    sensor,
    orig_center,
):
    """Shoot photons in batches of at most ``maxN`` until ``Ntot`` are drawn.

    A ``jax.lax.while_loop`` carries a dict of keyword arguments for
    ``_draw_phot_while_loop_shoot``. ``maxN`` is static, so every batch uses
    the same photon-array size. The returned ``_DrawPhotReturnTuple`` holds
    the last batch of photons and the final rng, added flux, image, photon
    ops and sensor. Its ``resume`` field is always False.
    """

    def _keep_going(carry):
        # Continue while photons remain to be drawn.
        return carry["Nleft"] > 0

    def _draw_batch(carry):
        # Shoot at most maxN at a time.
        batch_n = jnp.minimum(maxN, carry["Nleft"])
        batch = _draw_phot_while_loop_shoot(maxN=maxN, thisN=batch_n, **carry)
        # obj, g, local_wcs, orig_center and Ntot are carried through unchanged.
        return {
            **carry,
            "photons": batch.photons,
            "rng": batch.rng,
            "added_flux": batch.added_flux,
            "Nleft": carry["Nleft"] - batch_n,
            "resume": batch.resume,
            "image": batch.image,
            "photon_ops": batch.photon_ops,
            "sensor": batch.sensor,
        }

    init_carry = {
        "photons": photons,
        "rng": BaseDeviate(rng),
        "added_flux": jnp.array(0),
        "obj": obj,
        "Nleft": jnp.array(Ntot),
        "resume": jnp.array(False),
        "image": image,
        "g": g,
        "photon_ops": photon_ops,
        "local_wcs": local_wcs,
        "sensor": sensor,
        "orig_center": orig_center,
        "Ntot": Ntot,
    }
    final = jax.lax.while_loop(_keep_going, _draw_batch, init_carry)
    return _DrawPhotReturnTuple(
        photons=final["photons"],
        rng=final["rng"],
        added_flux=final["added_flux"],
        image=final["image"],
        photon_ops=final["photon_ops"],
        sensor=final["sensor"],
        resume=False,
    )