import galsim as _galsim
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class
from jax_galsim.core.utils import implements
from jax_galsim.gsobject import GSObject
from jax_galsim.gsparams import GSParams
from jax_galsim.position import PositionD
from jax_galsim.random import BaseDeviate
@implements(
    _galsim.Add, lax_description="Does not support ``ChromaticObject`` at this point."
)
def Add(*args, **kwargs):
    # Thin alias for ``Sum``: every positional and keyword argument is
    # forwarded unchanged, so ``Add`` accepts exactly the call forms ``Sum``
    # accepts and raises the same errors.
    # No docstring on purpose -- ``implements`` presumably builds the public
    # documentation from ``galsim.Add`` (TODO confirm in core.utils).
    return Sum(*args, **kwargs)
@implements(
    _galsim.Sum, lax_description="Does not support ``ChromaticObject`` at this point."
)
@register_pytree_node_class
class Sum(GSObject):
    # A sum of GSObjects.  The summands live in ``self._params["obj_list"]``;
    # that dict is the only traced pytree child, while ``gsparams`` and
    # ``propagate_gsparams`` travel as static auxiliary data (see tree_flatten).

    def __init__(self, *args, gsparams=None, propagate_gsparams=True):
        """Build a Sum from GSObjects.

        Accepted forms: ``Sum(obj)``, ``Sum([obj, ...])``, ``Sum((obj, ...))``
        and ``Sum(obj1, obj2, ...)``.  Nested Sums are flattened into a
        single list of summands.

        Parameters:
            gsparams:            explicit GSParams; if None, the most
                                 restrictive combination of the summands'
                                 gsparams is used.
            propagate_gsparams:  if True, every summand is rebuilt with the
                                 Sum's gsparams.

        Raises:
            TypeError: if no objects are given (including an empty list or
                       tuple), if a single argument is neither a GSObject
                       nor a list/tuple, or if any summand is not a GSObject.
        """
        self._propagate_gsparams = propagate_gsparams

        # Normalise the single-argument forms into one sequence of objects;
        # with several positional arguments ``args`` already is that sequence.
        if len(args) == 1:
            first = args[0]
            if isinstance(first, GSObject):
                args = [first]
            elif isinstance(first, (list, tuple)):
                args = first
            else:
                raise TypeError(
                    "Single input argument must be a GSObject or list of them."
                )
        # Checked after unpacking so that an empty list/tuple is rejected the
        # same way as a call with no arguments; an empty Sum would otherwise
        # only fail later (e.g. ``_drawReal`` indexes ``obj_list[0]``).
        if len(args) == 0:
            raise TypeError("At least one GSObject must be provided.")

        # Consolidate Sums of Sums: Sum(Sum(a, b), c) stores [a, b, c].
        flat_args = []
        for obj in args:
            if isinstance(obj, Sum):
                flat_args.extend(obj.params["obj_list"])
            else:
                flat_args.append(obj)

        for obj in flat_args:
            if not isinstance(obj, GSObject):
                # ``(obj,)`` keeps the message intact even when obj is a tuple.
                raise TypeError("Arguments to Sum must be GSObjects, not %s" % (obj,))

        if gsparams is None:
            # None given: take the most restrictive combination of the summands'.
            self._gsparams = GSParams.combine([obj.gsparams for obj in flat_args])
        else:
            # Explicitly given gsparams always win.
            self._gsparams = GSParams.check(gsparams)

        if self._propagate_gsparams:
            flat_args = [obj.withGSParams(self._gsparams) for obj in flat_args]

        # Kept in a dict so it is the pytree child and can be inspected later.
        self._params = {"obj_list": flat_args}

    @property
    @implements(_galsim.Sum.obj_list)
    def obj_list(self):
        return self._params["obj_list"]

    @property
    @implements(_galsim.Sum.flux)
    def flux(self):
        # Total flux is the plain sum of the summands' fluxes.
        return jnp.sum(jnp.array([obj.flux for obj in self.obj_list]))

    @implements(_galsim.Sum.withGSParams)
    def withGSParams(self, gsparams=None, **kwargs):
        if gsparams == self.gsparams:
            return self
        # Rebuild through the constructor so that propagate_gsparams is
        # honoured for the new gsparams.
        return self.__class__(
            self.params["obj_list"],
            gsparams=GSParams.check(gsparams, self.gsparams, **kwargs),
            propagate_gsparams=self._propagate_gsparams,
        )

    def __hash__(self):
        return hash(
            (
                "galsim.Sum",
                tuple(self.obj_list),
                self.gsparams,
                self._propagate_gsparams,
            )
        )

    def __repr__(self):
        return "galsim.Sum(%r, gsparams=%r, propagate_gsparams=%r)" % (
            self.obj_list,
            self.gsparams,
            self._propagate_gsparams,
        )

    def __str__(self):
        return "(" + " + ".join(str(obj) for obj in self.obj_list) + ")"

    @property
    def _maxk(self):
        """Largest maxk of the summands (the sum needs the finest k sampling)."""
        return jnp.max(jnp.array([obj.maxk for obj in self.obj_list]))

    @property
    def _stepk(self):
        """Smallest stepk of the summands."""
        return jnp.min(jnp.array([obj.stepk for obj in self.obj_list]))

    @property
    def _has_hard_edges(self):
        """True if any summand has hard edges."""
        return bool(np.any([obj.has_hard_edges for obj in self.obj_list]))

    @property
    def _is_axisymmetric(self):
        """True only if every summand is axisymmetric."""
        return bool(np.all([obj.is_axisymmetric for obj in self.obj_list]))

    @property
    def _is_analytic_x(self):
        """True only if every summand is analytic in real space."""
        return bool(np.all([obj.is_analytic_x for obj in self.obj_list]))

    @property
    def _is_analytic_k(self):
        """True only if every summand is analytic in Fourier space."""
        return bool(np.all([obj.is_analytic_k for obj in self.obj_list]))

    @property
    def _centroid(self):
        """Flux-weighted mean of the summands' centroids.

        A zero total flux divides by zero here, which under jnp arithmetic
        yields non-finite coordinates rather than raising.
        """
        total_flux = self.flux
        cen_x = jnp.sum(
            jnp.array([obj.centroid.x * obj.flux for obj in self.obj_list])
        )
        cen_y = jnp.sum(
            jnp.array([obj.centroid.y * obj.flux for obj in self.obj_list])
        )
        return PositionD(cen_x / total_flux, cen_y / total_flux)

    @property
    def _max_sb(self):
        """Sum of the summands' max_sb values."""
        return jnp.sum(jnp.array([obj.max_sb for obj in self.obj_list]))

    def _xValue(self, pos):
        # Stack per-summand values and sum over the summand axis (axis 0).
        return jnp.sum(jnp.array([obj._xValue(pos) for obj in self.obj_list]), axis=0)

    def _kValue(self, pos):
        # Stack per-summand values and sum over the summand axis (axis 0).
        return jnp.sum(jnp.array([obj._kValue(pos) for obj in self.obj_list]), axis=0)

    def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
        # Each summand's _drawReal returns an image; the results are
        # accumulated with ``+=``.  The running result (not the caller's image)
        # is what gets passed to the later summands.
        # NOTE(review): this assumes _drawReal output depends only on the
        # image geometry (bounds/wcs/dtype), not its pixel values -- confirm.
        image = self.obj_list[0]._drawReal(image, jac, offset, flux_scaling)
        for obj in self.obj_list[1:]:
            image += obj._drawReal(image, jac, offset, flux_scaling)
        return image

    def _drawKImage(self, image, jac=None):
        # Same accumulation scheme (and same assumption) as _drawReal.
        image = self.obj_list[0]._drawKImage(image, jac)
        for obj in self.obj_list[1:]:
            image += obj._drawKImage(image, jac)
        return image

    @property
    def _positive_flux(self):
        """Sum of the summands' positive fluxes."""
        return jnp.sum(jnp.array([obj.positive_flux for obj in self.obj_list]))

    @property
    def _negative_flux(self):
        """Sum of the summands' negative fluxes."""
        return jnp.sum(jnp.array([obj.negative_flux for obj in self.obj_list]))

    @property
    def _flux_per_photon(self):
        return self._calculate_flux_per_photon()

    @implements(_galsim.Sum._shoot)
    def _shoot(self, photons, rng):
        tot_flux = self.positive_flux + self.negative_flux
        fluxes = jnp.array(
            [obj.positive_flux + obj.negative_flux for obj in self.obj_list]
        )
        # For a sum of objects we use a slightly different approach than GalSim
        # (as of version 2.5).  GalSim draws the number of photons per summand
        # from a binomial distribution.  Here we keep fixed-size photon arrays
        # instead.  That draws more photons overall, but the code stays JIT
        # compilable and a bit simpler.
        #
        # It works as follows:
        #
        # - for each photon, draw a summand index from a categorical
        #   distribution with probabilities proportional to the summands'
        #   total absolute fluxes;
        # - shoot a full-size photon array from every summand and rescale its
        #   fluxes (see the comment in the loop below);
        # - copy into ``photons`` only the entries whose categorical index
        #   selects that summand, via the private
        #   ``PhotonArray._assign_from_categorical_index`` method.
        #
        # A nice side effect is that the photons come out pre-shuffled, so we
        # do not have to mark them as correlated.
        rng = BaseDeviate(rng)
        key = rng._state.split_one()
        cat_inds = jax.random.choice(
            key,
            len(self.obj_list),
            shape=(len(photons),),
            replace=True,
            p=fluxes / tot_flux,
        )
        for i, obj in enumerate(self.obj_list):
            pa = obj.shoot(photons.size(), rng=rng)
            # Rescale the photon fluxes.  In GalSim, photons end up with a flux of
            #
            #   fluxes[i] / thisN * tot_flux / photons.size() * thisN / fluxes[i]
            #   = tot_flux / photons.size()
            #
            # Our photons start with a flux of
            #
            #   fluxes[i] / photons.size()
            #
            # so we scale them by tot_flux / fluxes[i].
            pa.scaleFlux(tot_flux / fluxes[i])
            photons._assign_from_categorical_index(cat_inds, i, pa)

    def tree_flatten(self):
        """Flatten this Sum into traced children and static auxiliary data.

        The children hold the ``params`` dict (containing ``obj_list``).
        ``gsparams`` and ``propagate_gsparams`` are returned as aux data.
        """
        # Children nodes of the PyTree that need tracing.
        children = (self.params,)
        # Auxiliary static data that doesn't need to be traced.
        aux_data = {
            "gsparams": self.gsparams,
            "propagate_gsparams": self._propagate_gsparams,
        }
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Recreate an instance from its flattened representation.

        Goes through ``__init__``, so the summand validation and gsparams
        handling run again on the reconstructed ``obj_list``.
        """
        return cls(children[0]["obj_list"], **aux_data)