from contextlib import contextmanager
import galsim as _galsim
import jax
import jax.numpy as jnp
import jax.random as jrng
from jax.tree_util import register_pytree_node_class
from jax_galsim.core.utils import (
cast_numpy_array_to_native_byte_order,
cast_to_python_int,
implements,
)
from jax_galsim.errors import (
GalSimIncompatibleValuesError,
GalSimRangeError,
GalSimUndefinedBoundsError,
GalSimValueError,
)
from jax_galsim.random import BaseDeviate, UniformDeviate
from ._pyfits import pyfits
# Module-level override for the allocated length of photon arrays.
# ``None`` means no override; ``fixed_photon_array_size`` sets and restores it.
_JAX_GALSIM_PHOTON_ARRAY_SIZE = None
[docs]
@contextmanager
def fixed_photon_array_size(size):
    """Temporarily set the module-level fixed photon-array size to ``size``.

    The value that was in effect on entry is put back on exit, including when
    the managed block raises, so these contexts can be nested.
    """
    global _JAX_GALSIM_PHOTON_ARRAY_SIZE
    # Remember the current setting and install the new one in a single step.
    previous, _JAX_GALSIM_PHOTON_ARRAY_SIZE = _JAX_GALSIM_PHOTON_ARRAY_SIZE, size
    try:
        yield
    finally:
        _JAX_GALSIM_PHOTON_ARRAY_SIZE = previous
[docs]
@implements(
_galsim.PhotonArray,
lax_description="""\
JAX-GalSim PhotonArrays have significant differences from the original GalSim.
- They always copy input data and operations on them always copy.
- They (usually) do not do any type/size checking on input data.
- They do not support indexed assignement directly on the attributes.
- The additional properties ``dxdz``, ``dydz``, ``wavelength``, ``pupil_u``, ``pupil_v``,
and ``time`` are set to arrays of NaNs by default. They are thus always allocated.
However, the methods like `hasAllocatedAngles` etc. return false if the arrays
are all NaNs.
Further, a context manager ``fixed_photon_array_size`` is provided to temporarily
set a fixed size for photon arrays.
- This functionality is useful when applying JIT to operations that vary the
number of photons drawn using Poisson statistics.
- When using this context manager, the attribute ``_nokeep`` stores a boolean mask
indicating which photons are to be kept.
- The attribute ``_num_keep`` stores the number of photons to be kept. If you set
this attribute, the ``_nokeep`` mask is updated by sorting ``_nokeep`` so that things
to be kept are at the start, the first ``_num_keep`` photons are marked to be kept,
and finally the array is sorted back to its original order.
- You may get an error if you ask for more photons than the fixed size, but not always,
especially in JITed code.
- Operations on photon arrays with fixed sizes but different `_num_keep` values are not
defined and will not raise an error.
- The ``.flux`` property scales ``._flux`` by the ratio of the fixed size to the number
of kept photons and sets non-kept photons to zero flux. Setting ``.flux`` to ``._flux``
will break things badly.
- Profiles should always draw the full number of photons given by ``.size()`` or ``len()``
so that they use fixed array sizes and things are JIT compatible.
**The ``_nokeep``, ``_num_keep``, and associated methods are private and should not be set by hand
unless you know what you are doing!**
""",
)
@register_pytree_node_class
class PhotonArray:
def __init__(
    self,
    N,
    x=None,
    y=None,
    flux=None,
    dxdz=None,
    dydz=None,
    wavelength=None,
    pupil_u=None,
    pupil_v=None,
    time=None,
    _nokeep=None,
):
    """Allocate the photon arrays and optionally fill them.

    Parameters:
        N:          Number of photons requested. When a fixed array size is
                    active, the arrays are allocated with that size instead and
                    entries at index >= N are marked as not kept.
        x, y, flux, dxdz, dydz, wavelength, pupil_u, pupil_v, time:
                    Optional initial values, assigned through the matching
                    property setters in this order.
        _nokeep:    Optional explicit mask used instead of the default one.

    Raises:
        GalSimValueError: if ``N`` can be compared concretely and exceeds the
        active fixed array size.
    """
    self._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or N

    if _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None:
        try:
            # A traced ``N`` cannot be turned into a Python bool; that failure
            # is deliberately swallowed and treated as "no error".
            too_large = (N > _JAX_GALSIM_PHOTON_ARRAY_SIZE) or False
        except Exception:
            too_large = False

        if too_large:
            raise GalSimValueError(
                f"The given photon array size {N} is larger than "
                f"the allowed total size {_JAX_GALSIM_PHOTON_ARRAY_SIZE}."
            )

    # The mask must be in place before ``flux`` is assigned below, because the
    # flux setter scales by the number of kept photons.
    self._nokeep = (jnp.arange(self._Ntot) >= N) if _nokeep is None else _nokeep

    # x, y and flux start as zeros; every optional array is allocated as well,
    # filled with NaN to mark it as "not set".
    for attr in ("_x", "_y", "_flux"):
        setattr(self, attr, jnp.zeros(self._Ntot, dtype=float))
    for attr in ("_dxdz", "_dydz", "_wave", "_pupil_u", "_pupil_v", "_time"):
        setattr(self, attr, jnp.full(self._Ntot, jnp.nan, dtype=float))
    self._is_corr = jnp.array(False)

    # Route supplied values through the public setters, in declaration order.
    supplied = (
        ("x", x),
        ("y", y),
        ("flux", flux),
        ("dxdz", dxdz),
        ("dydz", dydz),
        ("wavelength", wavelength),
        ("pupil_u", pupil_u),
        ("pupil_v", pupil_v),
        ("time", time),
    )
    for prop_name, prop_value in supplied:
        if prop_value is not None:
            setattr(self, prop_name, prop_value)
[docs]
@classmethod
@implements(
    _galsim.PhotonArray.fromArrays,
    lax_description="JAX-GalSim does not do input type/size checking.",
)
def fromArrays(
    cls,
    x,
    y,
    flux,
    dxdz=None,
    dydz=None,
    wavelength=None,
    pupil_u=None,
    pupil_v=None,
    time=None,
    is_corr=False,
):
    # Public constructor from existing arrays: every argument is handed on
    # unchanged to ``_fromArrays``, which does the actual work.
    return cls._fromArrays(
        x,
        y,
        flux,
        dxdz=dxdz,
        dydz=dydz,
        wavelength=wavelength,
        pupil_u=pupil_u,
        pupil_v=pupil_v,
        time=time,
        is_corr=is_corr,
    )
@classmethod
@implements(_galsim.PhotonArray._fromArrays)
def _fromArrays(
    cls,
    x,
    y,
    flux,
    dxdz=None,
    dydz=None,
    wavelength=None,
    pupil_u=None,
    pupil_v=None,
    time=None,
    is_corr=False,
):
    # Build an instance directly from arrays, bypassing ``__init__``.
    # Raises GalSimValueError when a fixed array size is active and ``x`` does
    # not have exactly that length.
    n_given = x.shape[0]
    if (
        _JAX_GALSIM_PHOTON_ARRAY_SIZE is not None
        and n_given != _JAX_GALSIM_PHOTON_ARRAY_SIZE
    ):
        raise GalSimValueError(
            "The given arrays do not match the expected total size",
            n_given,
            _JAX_GALSIM_PHOTON_ARRAY_SIZE,
        )

    ret = cls.__new__(cls)
    ret._Ntot = _JAX_GALSIM_PHOTON_ARRAY_SIZE or n_given
    ret._x, ret._y, ret._flux = x.copy(), y.copy(), flux.copy()
    ret._nokeep = jnp.arange(ret._Ntot) >= n_given

    def _copy_or_nans(arr):
        # Optional arrays that were not supplied are allocated as all-NaN.
        if arr is None:
            return jnp.full(ret._Ntot, jnp.nan, dtype=float)
        return arr.copy()

    ret._dxdz = _copy_or_nans(dxdz)
    ret._dydz = _copy_or_nans(dydz)
    ret._wave = _copy_or_nans(wavelength)
    ret._pupil_u = _copy_or_nans(pupil_u)
    ret._pupil_v = _copy_or_nans(pupil_v)
    ret._time = _copy_or_nans(time)

    # NOTE(review): ``setCorrelated`` is not defined in the part of the class
    # visible here -- confirm it exists elsewhere.
    ret.setCorrelated(is_corr)
    return ret
[docs]
def tree_flatten(self):
    """Return ``(children, aux_data)`` describing this instance as a pytree.

    ``children`` is a pair: a tuple with x, y, flux and the keep mask, and a
    dict with the optional arrays plus the correlation flag. ``aux_data``
    carries the total array length.
    """
    required = (self._x, self._y, self._flux, self._nokeep)
    optional = dict(
        dxdz=self._dxdz,
        dydz=self._dydz,
        wavelength=self._wave,
        pupil_u=self._pupil_u,
        pupil_v=self._pupil_v,
        time=self._time,
        is_corr=self._is_corr,
    )
    return ((required, optional), (self._Ntot,))
[docs]
@classmethod
def tree_unflatten(cls, aux_data, children):
    """Rebuild an instance from the output of ``tree_flatten``.

    ``__init__`` is bypassed; the stored values are attached as they are.
    """
    required = children[0]
    optional = children[1]

    ret = cls.__new__(cls)
    ret._Ntot = aux_data[0]
    ret._nokeep = required[3]
    ret._x = required[0]
    ret._y = required[1]
    ret._flux = required[2]
    # Private attribute name -> key used in the flattened dict.
    attr_to_key = (
        ("_dxdz", "dxdz"),
        ("_dydz", "dydz"),
        ("_wave", "wavelength"),
        ("_pupil_u", "pupil_u"),
        ("_pupil_v", "pupil_v"),
        ("_time", "time"),
        ("_is_corr", "is_corr"),
    )
    for attr, key in attr_to_key:
        setattr(ret, attr, optional[key])
    return ret
[docs]
# Returns the total allocated length ``_Ntot``; entries masked by ``_nokeep`` are counted too.
@implements(_galsim.PhotonArray.size)
def size(self):
return self._Ntot
# ``len(photons)`` gives the same value as ``size()``: the total allocated length.
def __len__(self):
return self._Ntot
@property
def _num_keep(self):
    """Number of photons whose ``_nokeep`` entry is False."""
    kept = ~self._nokeep
    return jnp.sum(kept).astype(int)

@_num_keep.setter
def _num_keep(self, num_keep):
    """Keep only the first ``num_keep`` of the currently kept photons.

    The arrays are sorted so kept photons come first, the mask is rebuilt so
    that exactly the leading ``num_keep`` slots are kept, and the original
    ordering is then restored.
    """
    order = jnp.argsort(self._nokeep)
    self._sort_by_nokeep(sinds=order)
    self._nokeep = jnp.arange(self._Ntot) >= num_keep
    # Scatter everything back to where it was before the sort.
    self._set_self_at_inds(order)
# ``x`` property: the getter returns the stored ``_x`` array.
@property
@implements(
_galsim.photon_array.PhotonArray.x,
lax_description="JAX-GalSim PhotonArray properties do not support assignment at indices.",
)
def x(self):
return self._x
# The setter writes ``value`` into every slot of the existing array, so the stored
# array keeps its length and ``_x`` is rebound to the result.
@x.setter
def x(self, value):
self._x = self._x.at[:].set(value)
# ``y`` property: the getter returns the stored ``_y`` array.
@property
@implements(
_galsim.photon_array.PhotonArray.y,
lax_description="JAX-GalSim PhotonArray properties do not support assignment at indices.",
)
def y(self):
return self._y
# The setter writes ``value`` into every slot of the existing array, so the stored
# array keeps its length and ``_y`` is rebound to the result.
@y.setter
def y(self, value):
self._y = self._y.at[:].set(value)
@property
@implements(
    _galsim.photon_array.PhotonArray.flux,
    lax_description="JAX-GalSim PhotonArray properties do not support assignment at indices.",
)
def flux(self):
    # Photons that are not kept report zero flux; kept photons are scaled up
    # by (total size / number kept).
    masked_flux = jnp.where(self._nokeep, 0.0, self._flux)
    flux_ratio = self._Ntot / self._num_keep

    def _as_is(flux, ratio):
        return flux

    def _rescaled(flux, ratio):
        return flux * ratio

    # ``lax.cond`` skips the multiplication when no photon is masked.
    return jax.lax.cond(
        self._Ntot == self._num_keep,
        _as_is,
        _rescaled,
        masked_flux,
        flux_ratio,
    )

@flux.setter
def flux(self, value):
    # Store the value divided by the same ratio the getter multiplies by, so
    # reading ``flux`` back gives ``value`` for kept photons. ``_nokeep`` must
    # already hold its final value here so ``_num_keep`` is right.
    flux_ratio = self._Ntot / self._num_keep
    self._flux = self._flux.at[:].set(value / flux_ratio)
@property
@implements(
    _galsim.photon_array.PhotonArray.dxdz,
    lax_description="JAX-GalSim PhotonArray properties do not support assignment at indices.",
)
def dxdz(self):
    return self._dxdz

@dxdz.setter
def dxdz(self, value):
    # Overwrite every entry, keeping the stored array's length.
    updated = self._dxdz.at[:].set(value)
    self._dxdz = updated
    # When dxdz now has finite values while dydz is still entirely non-finite,
    # dydz is replaced by zeros.
    self._dydz = _zero_if_needed_on_set(updated, self._dydz)
@property
@implements(
    _galsim.photon_array.PhotonArray.dydz,
    lax_description="JAX-GalSim PhotonArray properties do not support assignment at indices.",
)
def dydz(self):
    return self._dydz

@dydz.setter
def dydz(self, value):
    # Overwrite every entry, keeping the stored array's length.
    updated = self._dydz.at[:].set(value)
    self._dydz = updated
    # When dydz now has finite values while dxdz is still entirely non-finite,
    # dxdz is replaced by zeros.
    self._dxdz = _zero_if_needed_on_set(updated, self._dxdz)
# ``wavelength`` property: the getter returns the stored ``_wave`` array.
@property
@implements(
_galsim.photon_array.PhotonArray.wavelength,
lax_description="JAX-GalSim PhotonArray properties do not support assignment at indices.",
)
def wavelength(self):
return self._wave
# The setter writes ``value`` into every slot of the existing array, so the stored
# array keeps its length and ``_wave`` is rebound to the result.
@wavelength.setter
def wavelength(self, value):
self._wave = self._wave.at[:].set(value)
@property
@implements(
    _galsim.photon_array.PhotonArray.pupil_u,
    lax_description="JAX-GalSim PhotonArray properties do not support assignment at indices.",
)
def pupil_u(self):
    return self._pupil_u

@pupil_u.setter
def pupil_u(self, value):
    # Overwrite every entry, keeping the stored array's length.
    updated = self._pupil_u.at[:].set(value)
    self._pupil_u = updated
    # When pupil_u now has finite values while pupil_v is still entirely
    # non-finite, pupil_v is replaced by zeros.
    self._pupil_v = _zero_if_needed_on_set(updated, self._pupil_v)
@property
@implements(
    _galsim.photon_array.PhotonArray.pupil_v,
    lax_description="JAX-GalSim PhotonArray properties do not support assignment at indices.",
)
def pupil_v(self):
    return self._pupil_v

@pupil_v.setter
def pupil_v(self, value):
    # Overwrite every entry, keeping the stored array's length.
    updated = self._pupil_v.at[:].set(value)
    self._pupil_v = updated
    # When pupil_v now has finite values while pupil_u is still entirely
    # non-finite, pupil_u is replaced by zeros.
    self._pupil_u = _zero_if_needed_on_set(updated, self._pupil_u)
# ``time`` property: the getter returns the stored ``_time`` array.
@property
@implements(
_galsim.photon_array.PhotonArray.time,
lax_description="JAX-GalSim PhotonArray properties do not support assignment at indices.",
)
def time(self):
return self._time
# The setter writes ``value`` into every slot of the existing array, so the stored
# array keeps its length and ``_time`` is rebound to the result.
@time.setter
def time(self, value):
self._time = self._time.at[:].set(value)
[docs]
# True when at least one entry of ``dxdz`` or ``dydz`` is finite. The arrays always
# exist in this class, so "allocated" here means "holds something other than NaN/inf".
# The result is whatever ``jnp.any`` returns, not a Python bool.
@implements(_galsim.photon_array.PhotonArray.hasAllocatedAngles)
def hasAllocatedAngles(self):
return jnp.any(jnp.isfinite(self.dxdz) | jnp.isfinite(self.dydz))
[docs]
# Intentionally empty: ``__init__`` and ``_fromArrays`` always create ``_dxdz`` and
# ``_dydz`` (NaN-filled when not supplied), so there is nothing to allocate.
@implements(
_galsim.photon_array.PhotonArray.allocateAngles,
lax_description="This is a no-op for JAX-Galsim.",
)
def allocateAngles(self):
pass
[docs]
# True when at least one entry of ``wavelength`` is finite (i.e. not NaN/inf).
# The result is whatever ``jnp.any`` returns, not a Python bool.
@implements(_galsim.photon_array.PhotonArray.hasAllocatedWavelengths)
def hasAllocatedWavelengths(self):
return jnp.any(jnp.isfinite(self.wavelength))
[docs]
# Intentionally empty: ``__init__`` and ``_fromArrays`` always create ``_wave``
# (NaN-filled when not supplied), so there is nothing to allocate.
@implements(
_galsim.photon_array.PhotonArray.allocateWavelengths,
lax_description="This is a no-op for JAX-Galsim.",
)
def allocateWavelengths(self):
pass
[docs]
# True when at least one entry of ``pupil_u`` or ``pupil_v`` is finite (i.e. not NaN/inf).
# The result is whatever ``jnp.any`` returns, not a Python bool.
@implements(_galsim.photon_array.PhotonArray.hasAllocatedPupil)
def hasAllocatedPupil(self):
return jnp.any(jnp.isfinite(self.pupil_u) | jnp.isfinite(self.pupil_v))
[docs]
# Intentionally empty: ``__init__`` and ``_fromArrays`` always create ``_pupil_u`` and
# ``_pupil_v`` (NaN-filled when not supplied), so there is nothing to allocate.
@implements(
_galsim.photon_array.PhotonArray.allocatePupil,
lax_description="This is a no-op for JAX-Galsim.",
)
def allocatePupil(self):
pass
[docs]
# True when at least one entry of ``time`` is finite (i.e. not NaN/inf).
# The result is whatever ``jnp.any`` returns, not a Python bool.
@implements(_galsim.photon_array.PhotonArray.hasAllocatedTimes)
def hasAllocatedTimes(self):
return jnp.any(jnp.isfinite(self.time))
[docs]
# Nothing is allocated here: ``__init__`` and ``_fromArrays`` always create ``_time``.
# NOTE(review): this returns ``True`` whereas allocateAngles / allocateWavelengths /
# allocatePupil return ``None`` -- confirm whether any caller relies on the value
# before making them consistent.
@implements(
_galsim.photon_array.PhotonArray.allocateTimes,
lax_description="This is a no-op for JAX-Galsim.",
)
def allocateTimes(self):
return True
[docs]
# Sums the public ``flux`` property, so photons masked by ``_nokeep`` contribute zero
# and kept photons contribute their rescaled flux.
@implements(_galsim.photon_array.PhotonArray.getTotalFlux)
def getTotalFlux(self):
return self.flux.sum()
[docs]
@implements(_galsim.photon_array.PhotonArray.setTotalFlux)
def setTotalFlux(self, flux):
    # Rescale all fluxes by the factor that turns the current total into
    # ``flux``. Returns ``self``.
    current_total = self.getTotalFlux()
    self.scaleFlux(flux / current_total)
    return self
[docs]
# Multiplies the stored (internal) fluxes by ``scale`` and returns ``self``.
# The factor is applied to ``_flux`` directly, which scales the public ``flux`` by
# the same amount because the getter only multiplies by a ratio.
@implements(_galsim.photon_array.PhotonArray.scaleFlux)
def scaleFlux(self, scale):
self._flux *= scale
return self
[docs]
# Multiplies both position arrays by ``scale`` and returns ``self``.
@implements(_galsim.photon_array.PhotonArray.scaleXY)
def scaleXY(self, scale):
self._x *= scale
self._y *= scale
return self
def _sort_by_nokeep(self, sinds=None):
    """Reorder every per-photon array with the index array ``sinds``.

    When ``sinds`` is omitted it is the argsort of ``_nokeep``, which places
    kept photons (mask False) ahead of the others. Returns ``self``.
    """
    if sinds is None:
        sinds = jnp.argsort(self._nokeep)

    per_photon_attrs = (
        "_x",
        "_y",
        "_flux",
        "_nokeep",
        "_dxdz",
        "_dydz",
        "_wave",
        "_pupil_u",
        "_pupil_v",
        "_time",
    )
    for attr in per_photon_attrs:
        setattr(self, attr, getattr(self, attr).at[sinds].get())
    return self
def _set_self_at_inds(self, sinds):
    """Scatter each per-photon array into positions ``sinds`` of itself.

    Element ``i`` of every array is written to slot ``sinds[i]``; with the
    same ``sinds`` this reverses the gather done in ``_sort_by_nokeep``.
    Returns ``self``.
    """
    per_photon_attrs = (
        "_x",
        "_y",
        "_flux",
        "_nokeep",
        "_dxdz",
        "_dydz",
        "_wave",
        "_pupil_u",
        "_pupil_v",
        "_time",
    )
    for attr in per_photon_attrs:
        current = getattr(self, attr)
        setattr(self, attr, current.at[sinds].set(current))
    return self
[docs]
@implements(_galsim.PhotonArray.assignAt)
def assignAt(self, istart, rhs):
    # Deprecated: copies all of ``rhs`` into this array starting at ``istart``.
    # Raises GalSimValueError when ``rhs`` would run past the end.
    from .deprecated import depr

    depr(
        "PhotonArray.assignAt",
        2.5,
        "copyFrom(rhs, slice(istart, istart+rhs.size()))",
    )

    istop = istart + rhs.size()
    if istop > self.size():
        raise GalSimValueError(
            "The given rhs does not fit into this array starting at %d" % istart,
            rhs,
        )

    return self._copyFrom(rhs, slice(istart, istop), slice(None))
[docs]
@implements(
    _galsim.PhotonArray.copyFrom,
    lax_description="The JAX version of PhotonArray.copyFrom does not raise for out of bounds indices.",
)
def copyFrom(
    self,
    rhs,
    target_indices=slice(None),
    source_indices=slice(None),
    do_xy=True,
    do_flux=True,
    do_other=True,
):
    # Public wrapper: all arguments are passed on unchanged to ``_copyFrom``.
    return self._copyFrom(
        rhs,
        target_indices,
        source_indices,
        do_xy=do_xy,
        do_flux=do_flux,
        do_other=do_other,
    )
# Copy the photons of ``rhs`` selected by ``source_indices`` into the slots of this
# array selected by ``target_indices``. ``do_xy`` / ``do_flux`` / ``do_other`` choose
# which groups are copied: positions, fluxes, and the optional
# angle / wavelength / pupil / time arrays. Returns ``self``.
@implements(_galsim.photon_array.PhotonArray._copyFrom)
def _copyFrom(
self,
rhs,
target_indices,
source_indices,
do_xy=True,
do_flux=True,
do_other=True,
):
# Aliases for notational convenience.
s1 = target_indices
s2 = source_indices
# Helper: when ``cond_val`` is true, write ``arr2[s2]`` into ``arr1[s1]``;
# otherwise return ``arr1`` unchanged. It closes over ``s1`` and ``s2``.
@jax.jit
def _cond_set_indices(arr1, arr2, cond_val):
return jax.lax.cond(
cond_val,
lambda arr1, arr2: arr1.at[s1].set(arr2.at[s2].get()),
lambda arr1, arr2: arr1,
arr1,
arr2,
)
# The flux ratio (total size / number kept) has to be captured both before and
# after the keep mask changes, so stored fluxes can be converted between the two.
old_flux_ratio = self._Ntot / self._num_keep
# The keep mask is copied whenever anything at all is copied.
if do_xy or do_flux or do_other:
self._nokeep = self._nokeep.at[s1].set(rhs._nokeep.at[s2].get())
new_flux_ratio = self._Ntot / self._num_keep
if do_xy:
self._x = self._x.at[s1].set(rhs.x.at[s2].get())
self._y = self._y.at[s1].set(rhs.y.at[s2].get())
if do_flux:
# we first scale the existing fluxes to account for the change in num_keep
self._flux = (
self._flux
# this factor gets us back to true flux
* old_flux_ratio
# this factor gets us back to the internal units
/ new_flux_ratio
)
# next we assign the RHS fluxes accounting for the change in num_keep from the
# RHS to the new flux_ratio
self._flux = self._flux.at[s1].set(
rhs._flux.at[s2].get()
# these factors conserve the flux of the assigned photons
# gets us to the true flux of the photon
* (rhs._Ntot / rhs._num_keep)
# scale it back down to account for scaling later
# this factor has to be computed after _nokeep is set above
# so that _num_keep is the right value
/ new_flux_ratio
)
if do_other:
# Each optional array is copied only when ``rhs`` holds finite values for it.
self._dxdz = _cond_set_indices(
self._dxdz, rhs.dxdz, rhs.hasAllocatedAngles()
)
self._dydz = _cond_set_indices(
self._dydz, rhs.dydz, rhs.hasAllocatedAngles()
)
self._wave = _cond_set_indices(
self._wave, rhs.wavelength, rhs.hasAllocatedWavelengths()
)
self._pupil_u = _cond_set_indices(
self._pupil_u, rhs.pupil_u, rhs.hasAllocatedPupil()
)
self._pupil_v = _cond_set_indices(
self._pupil_v, rhs.pupil_v, rhs.hasAllocatedPupil()
)
self._time = _cond_set_indices(
self._time, rhs.time, rhs.hasAllocatedTimes()
)
return self
def _assign_from_categorical_index(self, cat_inds, cat_ind_to_assign, rhs):
    """Take photons from ``rhs`` wherever ``cat_inds == cat_ind_to_assign``.

    Selection is elementwise: slot ``i`` of this array receives slot ``i`` of
    ``rhs`` where the category matches and is left alone otherwise. Stored
    fluxes are converted so that the public flux of every photon is preserved
    after the keep mask changes. Returns ``self``.
    """
    take_rhs = cat_ind_to_assign == cat_inds

    # The ratio must be sampled on both sides of the mask update.
    ratio_before = self._Ntot / self._num_keep
    self._nokeep = jnp.where(take_rhs, rhs._nokeep, self._nokeep)
    ratio_after = self._Ntot / self._num_keep
    ratio_rhs = rhs._Ntot / rhs._num_keep

    self._flux = jnp.where(
        take_rhs,
        rhs._flux * ratio_rhs / ratio_after,
        self._flux * ratio_before / ratio_after,
    )

    # Everything else is a plain elementwise selection.
    plain_attrs = (
        "_x",
        "_y",
        "_dxdz",
        "_dydz",
        "_wave",
        "_pupil_u",
        "_pupil_v",
        "_time",
    )
    for attr in plain_attrs:
        setattr(
            self,
            attr,
            jnp.where(take_rhs, getattr(rhs, attr), getattr(self, attr)),
        )
    return self
[docs]
# Convolve this photon array with ``rhs`` in place: positions of ``rhs`` are added to
# this array's positions and the fluxes are multiplied (times ``size()``). Optional
# arrays that only ``rhs`` has set are taken over from it. Returns ``self``.
# Raises GalSimIncompatibleValuesError when the two arrays differ in size.
# Both arrays are sorted by ``_nokeep`` for the duration of the call and put back
# into their original order before returning.
@implements(_galsim.photon_array.PhotonArray.convolve)
def convolve(self, rhs, rng=None):
if rhs.size() != self.size():
raise GalSimIncompatibleValuesError(
"PhotonArray.convolve with unequal size arrays", self_pa=self, rhs=rhs
)
# We need to make sure that the arrays are sorted by _nokeep before convolving.
# We sort them back to their original order after convolving.
self_sinds = jnp.argsort(self._nokeep)
rhs_sinds = jnp.argsort(rhs._nokeep)
self._sort_by_nokeep(sinds=self_sinds)
rhs._sort_by_nokeep(sinds=rhs_sinds)
# When two photon arrays are convolved, you basically perturb the positions of one
# by adding the positions of the other. For example, if you have a delta function
# and want to convolve with a Gaussian, then the photon arrays are an array of zeros
# for the delta function and an array of Gaussian draws for the Gaussian. The convolution
# is then implemented by adding the positions of the two arrays.
# The edge case here is if the photons in an array are correlated. For example, if
# you draw photons from a sum of two profiles, you could have the photons from one
# of the components only at the start of the array and the photons from the other
# component only at the end of the array like this
#
# [A, A, A, ..., A, B, B, B, ..., B]
#
# where A and B represent which component the photon came from. If you convolve two
# photon arrays where both arrays have internal correlations in the ordering of the
# photons, then you need to randomly sort one of the arrays before the convolution.
# Otherwise you won't properly be adding a random draw from one profile to the other.
# The indexing and PRNG code snippets below handle this case of convolving two internally
# correlated photon arrays.
# These are indices that randomly sort the RHS's photons.
rng = BaseDeviate(rng)
rsinds = jrng.choice(
rng._state.split_one(),
self._Ntot,
shape=(self.size(),),
replace=False,
)
# these indices do not randomly sort the RHS's photons
nrsinds = jnp.arange(self.size())
# Now we randomly sort if both arrays are internally correlated.
# However there is a catch. The RHS may not be keeping all of its photons
# (i.e., rhs._nokeep is True for some photons). In this case, we additionally
# sort the random indices by the value of rhs._nokeep so that the photons to be
# kept are still at the front of the array but are in a new random order.
sinds = jax.lax.cond(
self._is_corr & rhs._is_corr,
lambda nrsinds, rsinds: rsinds.at[
jnp.argsort(rhs._nokeep.at[rsinds].get())
].get(),
lambda nrsinds, rsinds: nrsinds,
nrsinds,
rsinds,
)
# For each optional array group: take the RHS values (reordered by ``sinds``) only
# when the RHS has them set and this array does not; otherwise keep our own.
self.dxdz, self.dydz = jax.lax.cond(
rhs.hasAllocatedAngles() & (~self.hasAllocatedAngles()),
lambda self_dxdz, rhs_dxdz, self_dydz, rhs_dydz, sinds: (
rhs_dxdz.at[sinds].get(),
rhs_dydz.at[sinds].get(),
),
lambda self_dxdz, rhs_dxdz, self_dydz, rhs_dydz, sinds: (
self_dxdz,
self_dydz,
),
self.dxdz,
rhs.dxdz,
self.dydz,
rhs.dydz,
sinds,
)
self.wavelength = jax.lax.cond(
rhs.hasAllocatedWavelengths() & (~self.hasAllocatedWavelengths()),
lambda self_wave, rhs_wave, sinds: rhs_wave.at[sinds].get(),
lambda self_wave, rhs_wave, sinds: self_wave,
self.wavelength,
rhs.wavelength,
sinds,
)
self.pupil_u, self.pupil_v = jax.lax.cond(
rhs.hasAllocatedPupil() & (~self.hasAllocatedPupil()),
lambda self_pupil_u, rhs_pupil_u, self_pupil_v, rhs_pupil_v, sinds: (
rhs_pupil_u.at[sinds].get(),
rhs_pupil_v.at[sinds].get(),
),
lambda self_pupil_u, rhs_pupil_u, self_pupil_v, rhs_pupil_v, sinds: (
self_pupil_u,
self_pupil_v,
),
self.pupil_u,
rhs.pupil_u,
self.pupil_v,
rhs.pupil_v,
sinds,
)
self.time = jax.lax.cond(
rhs.hasAllocatedTimes() & (~self.hasAllocatedTimes()),
lambda self_time, rhs_time, sinds: rhs_time.at[sinds].get(),
lambda self_time, rhs_time, sinds: self_time,
self.time,
rhs.time,
sinds,
)
# The result is marked correlated if either input was.
self._is_corr = self._is_corr | rhs._is_corr
# The actual convolution: add positions, multiply the stored fluxes.
self._x = self._x + rhs._x.at[sinds].get()
self._y = self._y + rhs._y.at[sinds].get()
self._flux = self._flux * rhs._flux.at[sinds].get() * self.size()
# sort the arrays back to their original order
self._set_self_at_inds(self_sinds)
rhs._set_self_at_inds(rhs_sinds)
return self
def __repr__(self):
    """Return a constructor-style string listing every array that is set.

    x, y, flux and ``_nokeep`` are always included; the optional arrays
    appear only when the matching ``hasAllocated*`` check is true.
    """
    import numpy as np

    def _as_list(arr):
        return np.array(arr).tolist()

    pieces = [
        "galsim.PhotonArray(%r, x=array(%r), y=array(%r), flux=array(%r)"
        % (
            cast_to_python_int(self.size()),
            _as_list(self.x),
            _as_list(self.y),
            _as_list(self.flux),
        )
    ]
    if self.hasAllocatedAngles():
        pieces.append(
            ", dxdz=array(%r), dydz=array(%r)"
            % (_as_list(self.dxdz), _as_list(self.dydz))
        )
    if self.hasAllocatedWavelengths():
        pieces.append(", wavelength=array(%r)" % (_as_list(self.wavelength)))
    if self.hasAllocatedPupil():
        pieces.append(
            ", pupil_u=array(%r), pupil_v=array(%r)"
            % (_as_list(self.pupil_u), _as_list(self.pupil_v))
        )
    if self.hasAllocatedTimes():
        pieces.append(", time=array(%r)" % _as_list(self.time))
    pieces.append(", _nokeep=array(%r)" % _as_list(self._nokeep))
    pieces.append(")")
    return "".join(pieces)
# Short form: only the total size is shown, converted via ``cast_to_python_int``.
def __str__(self):
return "galsim.PhotonArray(%r)" % cast_to_python_int(self.size())
# Setting ``__hash__`` to None makes instances unhashable; the class defines ``__eq__``.
__hash__ = None
# Equal when ``other`` is the same object, or is a PhotonArray whose x, y, public flux,
# ``_nokeep`` and optional arrays all match. NaNs in the optional arrays compare equal
# (``equal_nan=True``). ``_is_corr`` and ``_Ntot`` are not compared explicitly.
# The ``and`` chain short-circuits, so the value returned is the first falsy operand
# or the last ``jnp.array_equal`` result rather than always a Python bool.
def __eq__(self, other):
return self is other or (
isinstance(other, PhotonArray)
and jnp.array_equal(self.x, other.x)
and jnp.array_equal(self.y, other.y)
and jnp.array_equal(self.flux, other.flux)
and jnp.array_equal(self._nokeep, other._nokeep)
and jnp.array_equal(self.dxdz, other.dxdz, equal_nan=True)
and jnp.array_equal(self.dydz, other.dydz, equal_nan=True)
and jnp.array_equal(self.wavelength, other.wavelength, equal_nan=True)
and jnp.array_equal(self.pupil_u, other.pupil_u, equal_nan=True)
and jnp.array_equal(self.pupil_v, other.pupil_v, equal_nan=True)
and jnp.array_equal(self.time, other.time, equal_nan=True)
)
# Logical negation of ``__eq__``; ``not`` turns the result into a Python bool.
def __ne__(self, other):
return not self == other
[docs]
@implements(
    _galsim.PhotonArray.addTo,
    lax_description="The JAX equivalent of galsim.PhotonArray.addTo may not raise for undefined bounds.",
)
def addTo(self, image):
    # Adds the flux of every kept photon to the image pixel it falls in and
    # returns the total flux that landed inside the image.
    # Raises GalSimUndefinedBoundsError when the image bounds are undefined.
    if not image.bounds.isDefined():
        raise GalSimUndefinedBoundsError(
            "Attempting to PhotonArray::addTo an Image with undefined Bounds"
        )

    # Same quantity as ``self.flux``. It is spelled out here on purpose so the
    # call below cannot be changed to the raw ``self._flux`` by accident.
    effective_flux = (
        jnp.where(self._nokeep, 0.0, self._flux) * self._Ntot / self._num_keep
    )

    bounds = image.bounds
    new_pixels, added_flux = _add_photons_to_image(
        self._x,
        self._y,
        effective_flux,
        bounds.xmin,
        bounds.ymin,
        image._array,
    )
    image._array = image.array.at[...].set(new_pixels)
    return added_flux
[docs]
@classmethod
@implements(_galsim.photon_array.PhotonArray.makeFromImage)
def makeFromImage(cls, image, max_flux=1.0, rng=None):
    # Turns an image into photons: every pixel yields at least one photon and
    # enough of them that none carries more than ``max_flux`` in absolute
    # value. Positions are drawn uniformly inside the source pixel.
    # Raises GalSimRangeError when ``max_flux`` is not positive.
    if max_flux <= 0:
        raise GalSimRangeError("max_flux must be positive", max_flux, 0.0)

    pixels = image.array
    # Photons per pixel, and the flux each of them carries.
    counts = jnp.clip(jnp.ceil(jnp.abs(pixels) / max_flux), 1).astype(int)
    flux_each = (pixels / counts).ravel()
    counts = counts.ravel()

    # Flat pixel index of every photon, then its (row, column).
    pixel_of_photon = jnp.repeat(jnp.arange(pixels.size), counts)
    rows, cols = jnp.unravel_index(pixel_of_photon, pixels.shape)

    # Lower edge of each pixel column / row in image coordinates.
    bounds = image.bounds
    x_lower = jnp.arange(bounds.xmin, bounds.xmax + 2) - 0.5
    y_lower = jnp.arange(bounds.ymin, bounds.ymax + 2) - 0.5

    # Offsets within the pixel are drawn for x first and then for y.
    ud = UniformDeviate(rng)
    photons = cls(counts.sum())
    photons.x = ud.generate(photons.x) + x_lower[cols]
    photons.y = ud.generate(photons.y) + y_lower[rows]
    photons.flux = flux_each[pixel_of_photon]

    if image.scale is not None:
        photons.scaleXY(image.scale)
    return photons
[docs]
@implements(_galsim.photon_array.PhotonArray.write)
def write(self, file_name):
    # Writes the photons to ``file_name`` as a FITS binary table. The id, x,
    # y, flux and ``_nokeep`` columns are always present; optional columns are
    # added only when the matching ``hasAllocated*`` check is true.
    import numpy as np

    from jax_galsim import fits

    def _double_col(name, values):
        return pyfits.Column(name=name, format="D", array=np.array(values))

    cols = [
        pyfits.Column(name="id", format="J", array=range(self.size())),
        _double_col("x", self.x),
        _double_col("y", self.y),
        _double_col("flux", self.flux),
        pyfits.Column(name="_nokeep", format="L", array=np.array(self._nokeep)),
    ]
    if self.hasAllocatedAngles():
        cols.append(_double_col("dxdz", self.dxdz))
        cols.append(_double_col("dydz", self.dydz))
    if self.hasAllocatedWavelengths():
        cols.append(_double_col("wavelength", self.wavelength))
    if self.hasAllocatedPupil():
        cols.append(_double_col("pupil_u", self.pupil_u))
        cols.append(_double_col("pupil_v", self.pupil_v))
    if self.hasAllocatedTimes():
        cols.append(_double_col("time", self.time))

    table = pyfits.BinTableHDU.from_columns(pyfits.ColDefs(cols))
    fits.writeFile(file_name, table)
[docs]
@classmethod
@implements(_galsim.photon_array.PhotonArray.read)
def read(cls, file_name):
    # Reads a photon array from the FITS binary table in HDU 1 of
    # ``file_name`` (the layout produced by ``write``).
    #
    # The keep mask is handed to the constructor instead of being assigned
    # afterwards. The flux setter divides by (total size / number kept), so the
    # mask has to be in place before the flux is stored; otherwise a file
    # holding masked photons comes back with its flux inflated by that ratio.
    def _native(column):
        return jnp.array(cast_numpy_array_to_native_byte_order(column))

    # Everything is converted while the file is still open, so nothing depends
    # on the table data outliving the file handle.
    with pyfits.open(file_name) as fits:
        data = fits[1].data
        N = len(data)
        names = data.columns.names

        # Files without a ``_nokeep`` column get the constructor's default mask.
        nokeep = _native(data["_nokeep"]) if "_nokeep" in names else None

        photons = cls(
            N,
            x=_native(data["x"]),
            y=_native(data["y"]),
            flux=_native(data["flux"]),
            _nokeep=nokeep,
        )
        if "dxdz" in names:
            photons.dxdz = _native(data["dxdz"])
            photons.dydz = _native(data["dydz"])
        if "wavelength" in names:
            photons.wavelength = _native(data["wavelength"])
        if "pupil_u" in names:
            photons.pupil_u = _native(data["pupil_u"])
            photons.pupil_v = _native(data["pupil_v"])
        if "time" in names:
            photons.time = _native(data["time"])
    return photons
@jax.jit
def _add_photons_to_image(x, y, flux, xmin, ymin, arr):
xinds = jnp.floor(x - xmin + 0.5).astype(int)
yinds = jnp.floor(y - ymin + 0.5).astype(int)
# the jax documentation says that they drop out of bounds indices,
# but the galsim unit tests reveal that without the check below,
# the indices are not dropped.
# I think maybe it is only indices beyond the end of the array that are
# dropped and negative indices wrap around
good = (xinds >= 0) & (xinds < arr.shape[1]) & (yinds >= 0) & (yinds < arr.shape[0])
_flux = jnp.where(good, flux, 0.0)
# jax-galsim's rounding of float-to-int is platform dependent
# so we explicitly round to ints if needed
if jnp.issubdtype(arr.dtype, jnp.integer):
_arr = arr.astype(float).at[yinds, xinds].add(_flux.astype(float))
_arr = jnp.around(_arr).astype(arr.dtype)
else:
_arr = arr.at[yinds, xinds].add(_flux.astype(arr.dtype))
return _arr, _flux.sum()
def _zero_if_needed_on_set(arr_to_test, arr_to_zero):
return jax.lax.cond(
jnp.any(jnp.isfinite(arr_to_test)) & jnp.all(~jnp.isfinite(arr_to_zero)),
lambda atz: jnp.zeros_like(atz),
lambda atz: atz,
arr_to_zero,
)