Source code for jax_galsim.convolve

import galsim as _galsim
import jax.numpy as jnp
from galsim.errors import galsim_warn
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import implements
from jax_galsim.gsobject import GSObject
from jax_galsim.gsparams import GSParams
from jax_galsim.photon_array import PhotonArray


@implements(
    _galsim.Convolve,
    lax_description="""Does not support ChromaticConvolutions""",
)
def Convolve(*args, **kwargs):
    # Factory for Convolution: accepts the components either as separate
    # positional arguments or as one list/tuple, and forwards every keyword
    # argument untouched to the Convolution constructor.
    if not args:
        raise TypeError("At least one GSObject must be provided.")

    if len(args) == 1:
        (only,) = args
        if isinstance(only, GSObject):
            # A lone profile becomes a one-element list.
            components = [only]
        elif isinstance(only, (list, tuple)):
            # A single sequence argument is taken as the list of profiles.
            components = only
        else:
            raise TypeError(
                "Single input argument must be a GSObject, "
                + "or a (possibly mixed) list of them."
            )
    else:
        # Several positional arguments already form the list of objects.
        components = args

    return Convolution(*components, **kwargs)
@implements(
    _galsim.Convolution,
    lax_description="""Only supports 'fft' convolution.""",
)
@register_pytree_node_class
class Convolution(GSObject):
    # Convolution of one or more GSObjects, evaluated in Fourier space only.
    #
    # Accepted keyword arguments (anything else raises TypeError):
    #   real_space          -- None (auto) or a boolean; a final value of True
    #                          raises NotImplementedError.
    #   gsparams            -- explicit GSParams; when None the components'
    #                          gsparams are combined via GSParams.combine.
    #   propagate_gsparams  -- when True (default) the resulting gsparams are
    #                          applied to every component.

    def __init__(self, *args, **kwargs):
        # First check for number of arguments != 0
        if len(args) == 0:
            raise TypeError("At least one GSObject must be provided.")
        elif len(args) == 1:
            if isinstance(args[0], GSObject):
                args = [args[0]]
            elif isinstance(args[0], (list, tuple)):
                args = args[0]
            else:
                raise TypeError(
                    "Single input argument must be a GSObject or list of them."
                )
        # else args is already the list of objects

        real_space = kwargs.pop("real_space", None)
        gsparams = kwargs.pop("gsparams", None)
        self._propagate_gsparams = kwargs.pop("propagate_gsparams", True)

        # Make sure there is nothing left in the dict.
        if kwargs:
            raise TypeError(
                "Convolution constructor got unexpected keyword argument(s): %s"
                % kwargs.keys()
            )

        # Check whether to perform real space convolution...
        # Start by checking if all objects have a hard edge.
        hard_edge = True
        for obj in args:
            if not isinstance(obj, GSObject):
                raise TypeError(
                    "Arguments to Convolution must be GSObjects, not %s" % obj
                )
            if not obj.has_hard_edges:
                hard_edge = False

        if real_space is None:
            # The automatic determination is to use real_space if 2 items, both with hard edges.
            # NOTE(review): with real_space=None and only hard-edged inputs
            # (1 or 2 objects) this selects real_space=True, which reaches the
            # NotImplementedError below unless a component is not analytic in
            # real space. Confirm whether a DFT fallback was intended instead.
            if len(args) <= 2:
                real_space = hard_edge
            else:
                real_space = False
        elif bool(real_space) != real_space:
            raise TypeError("real_space must be a boolean")

        # Warn if doing DFT convolution for objects with hard edges
        if not real_space and hard_edge:
            if len(args) == 2:
                galsim_warn(
                    "Doing convolution of 2 objects, both with hard edges. "
                    "This might be more accurate with `real_space=True`, "
                    "but this functionality has not yet been implemented in JAX-Galsim."
                )
            else:
                galsim_warn(
                    "Doing convolution where all objects have hard edges. "
                    "There might be some inaccuracies due to ringing in k-space."
                )

        if real_space:
            # Can't do real space if nobj > 2
            if len(args) > 2:
                galsim_warn(
                    "Real-space convolution of more than 2 objects is not implemented. "
                    "Switching to DFT method."
                )
                real_space = False
            # Also can't do real space if any object is not analytic, so check for that.
            else:
                for obj in args:
                    if not obj.is_analytic_x:
                        galsim_warn(
                            "A component to be convolved is not analytic in real space. "
                            "Cannot use real space convolution. Switching to DFT method."
                        )
                        real_space = False
                        break

        # Save the construction parameters (as they are at this point) as attributes so they
        # can be inspected later if necessary.
        if bool(real_space):
            raise NotImplementedError("Real-space convolutions are not implemented")
        self._real_space = bool(real_space)

        # Figure out what gsparams to use
        if gsparams is None:
            # If none is given, take the most restrictive combination from the obj_list.
            self._gsparams = GSParams.combine([obj.gsparams for obj in args])
        else:
            # If something explicitly given, then use that.
            self._gsparams = GSParams.check(gsparams)

        # Apply gsparams to all in obj_list.
        if self._propagate_gsparams:
            self._obj_list = [obj.withGSParams(self._gsparams) for obj in args]
        else:
            self._obj_list = args

        # Save the list of objects
        self._params = {"obj_list": self._obj_list}

    @property
    @implements(_galsim.Convolution.obj_list)
    def obj_list(self):
        return self._obj_list

    @property
    @implements(_galsim.Convolution.real_space)
    def real_space(self):
        return self._real_space

    @implements(_galsim.Convolution.withGSParams)
    def withGSParams(self, gsparams=None, **kwargs):
        if gsparams == self.gsparams:
            return self
        from copy import copy

        ret = copy(self)
        ret._gsparams = GSParams.check(gsparams, self.gsparams, **kwargs)
        if self._propagate_gsparams:
            ret._obj_list = [obj.withGSParams(ret._gsparams) for obj in self.obj_list]
            # copy() is shallow (unless GSObject customizes copying), so
            # without this the copy would keep the `_params` dict built in
            # __init__, which still references the old components while
            # `_obj_list` holds the new ones. tree_flatten reads `self.params`
            # (presumably backed by `_params`), so rebuild it for the copy
            # instead of mutating the dict shared with `self`.
            ret._params = {"obj_list": ret._obj_list}
        return ret

    def __eq__(self, other):
        return self is other or (
            isinstance(other, Convolution)
            and self.obj_list == other.obj_list
            and self.real_space == other.real_space
            and self.gsparams == other.gsparams
            and self._propagate_gsparams == other._propagate_gsparams
        )

    def __hash__(self):
        # Hashes the same fields that __eq__ compares.
        return hash(
            (
                "galsim.Convolution",
                tuple(self.obj_list),
                self.real_space,
                self.gsparams,
                self._propagate_gsparams,
            )
        )

    def __repr__(self):
        return (
            "galsim.Convolution(%r, real_space=%r, gsparams=%r, propagate_gsparams=%r)"
            % (
                self.obj_list,
                self.real_space,
                self.gsparams,
                self._propagate_gsparams,
            )
        )

    def __str__(self):
        str_list = [str(obj) for obj in self.obj_list]
        s = "galsim.Convolve(%s" % (", ".join(str_list))
        if self.real_space:
            s += ", real_space=True"
        s += ")"
        return s

    @property
    def _maxk(self):
        # The product in k-space is limited by the smallest component maxk.
        maxk_list = [obj.maxk for obj in self.obj_list]
        return jnp.min(jnp.array(maxk_list))

    @property
    def _stepk(self):
        # This is approximate.  stepk ~ 2pi/R
        # Assume R_final^2 = Sum(R_i^2)
        # So 1/stepk^2 = Sum(1/stepk_i^2)
        inv_stepksq_list = [obj.stepk ** (-2) for obj in self.obj_list]
        return 1.0 / jnp.sqrt(jnp.sum(jnp.array(inv_stepksq_list)))

    @property
    def _has_hard_edges(self):
        return len(self.obj_list) == 1 and self.obj_list[0].has_hard_edges

    @property
    def _is_axisymmetric(self):
        axi_list = [obj.is_axisymmetric for obj in self.obj_list]
        return jnp.all(jnp.array(axi_list)).item()

    @property
    def _is_analytic_x(self):
        if len(self.obj_list) == 1:
            return self.obj_list[0].is_analytic_x
        elif self.real_space and len(self.obj_list) == 2:
            ax_list = [obj.is_analytic_x for obj in self.obj_list]
            # Was `jnp.arry`, which does not exist and raised AttributeError.
            return jnp.all(jnp.array(ax_list)).item()  # return a bool
        else:
            return False

    @property
    def _is_analytic_k(self):
        ak_list = [obj.is_analytic_k for obj in self.obj_list]
        return jnp.all(jnp.array(ak_list)).item()  # return a bool

    @property
    def _centroid(self):
        cen_list = [obj.centroid for obj in self.obj_list]
        return sum(
            cen_list[1:], cen_list[0]
        )  # gives a Position object with x=sum_i x_i and y=sum_i y_i

    @property
    def _flux(self):
        # The flux of a convolution is the product of the component fluxes.
        flux_list = [obj.flux for obj in self.obj_list]
        return jnp.prod(jnp.array(flux_list))

    def _calc_pn(self):
        # Fold the components' positive/negative fluxes pairwise:
        # (+)(+) and (-)(-) contribute to the positive part, mixed signs to
        # the negative part.
        pos_list = [obj.positive_flux for obj in self.obj_list]
        neg_list = [obj.negative_flux for obj in self.obj_list]
        p_net = pos_list[0]
        n_net = neg_list[0]
        for p, n in zip(pos_list[1:], neg_list[1:]):
            p_net, n_net = p_net * p + n_net * n, p_net * n + n_net * p
        return p_net, n_net

    @property
    def _positive_flux(self):
        return self._calc_pn()[0]

    @property
    def _negative_flux(self):
        return self._calc_pn()[1]

    @property
    def _flux_per_photon(self):
        return self._calculate_flux_per_photon()

    @property
    def _max_sb(self):
        # This one is probably the least accurate of all the estimates of maxSB.
        # The calculation is based on the exact value for Gaussians.
        #     maxSB = flux / 2pi sigma^2
        # When convolving multiple Gaussians together, the sigma^2 values add:
        #     sigma_final^2 = Sum_i sigma_i^2
        # from which we can calculate
        #     maxSB = flux_final / 2pi sigma_final^2
        # or
        #     maxSB = flux_final / Sum_i (flux_i / maxSB_i)
        #
        # For non-Gaussians, this procedure will tend to produce an over-estimate of the
        # true maximum SB.  Non-Gaussian profiles tend to have peakier parts which get smoothed
        # more than the Gaussian does.  So this is likely to be too high, which is acceptable.
        area_list = [obj.flux / obj.max_sb for obj in self.obj_list]
        return self.flux / jnp.sum(jnp.array(area_list))

    def _xValue(self, pos):
        raise NotImplementedError("Real-space convolutions are not implemented")

    def _kValue(self, kpos):
        # The k-space value of a convolution is the product of the
        # components' k-space values.
        kv_list = [
            obj.kValue(kpos) for obj in self.obj_list
        ]  # In GalSim one uses obj.kValue
        return jnp.prod(jnp.array(kv_list))

    def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
        raise NotImplementedError("Real-space convolutions are not implemented")

    @implements(_galsim.Convolution._shoot)
    def _shoot(self, photons, rng):
        self.obj_list[0]._shoot(photons, rng)
        # It may be necessary to shuffle when convolving because we do not have a
        # guarantee that the convolvee's photons are uncorrelated, e.g., they might
        # both have their negative ones at the end.
        # However, this decision is now made by the convolve method.
        for obj in self.obj_list[1:]:
            p1 = PhotonArray(len(photons))
            obj._shoot(p1, rng)
            photons.convolve(p1, rng)

    def _drawKImage(self, image, jac=None):
        # Draw the first component, then multiply in every other component's
        # k-space image.
        image = self.obj_list[0]._drawKImage(image, jac)
        if len(self.obj_list) > 1:
            for obj in self.obj_list[1:]:
                image *= obj._drawKImage(image, jac)
        return image

    def tree_flatten(self):
        """This function flattens the GSObject into a list of children
        nodes that will be traced by JAX and auxiliary static data."""
        # Define the children nodes of the PyTree that need tracing
        children = (self.params,)
        # Define auxiliary static data that doesn't need to be traced
        aux_data = {
            "gsparams": self.gsparams,
            "propagate_gsparams": self._propagate_gsparams,
            "real_space": self._real_space,
        }
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Recreates an instance of the class from flatten representation"""
        return cls(children[0]["obj_list"], **aux_data)
@implements(
    _galsim.convolve.Deconvolve,
    lax_description="Does not support ChromaticDeconvolution",
)
def Deconvolve(obj, gsparams=None, propagate_gsparams=True):
    # Factory for Deconvolution: rejects anything that is not a GSObject and
    # otherwise forwards the arguments unchanged.
    if not isinstance(obj, GSObject):
        raise TypeError(
            "Argument to Deconvolve must be either a GSObject or a ChromaticObject."
        )
    return Deconvolution(
        obj, gsparams=gsparams, propagate_gsparams=propagate_gsparams
    )
@implements(_galsim.convolve.Deconvolution)
@register_pytree_node_class
class Deconvolution(GSObject):
    # Profile whose k-space value is the reciprocal of the wrapped profile's,
    # clamped where the wrapped value is too small to invert reliably.

    _has_hard_edges = False
    _is_analytic_x = False

    def __init__(self, obj, gsparams=None, propagate_gsparams=True):
        if not isinstance(obj, GSObject):
            raise TypeError("Argument to Deconvolution must be a GSObject.")

        self._gsparams = GSParams.check(gsparams, obj.gsparams)
        self._propagate_gsparams = propagate_gsparams
        # Keep the wrapped profile as an attribute so it can be inspected
        # later; it is re-parameterized only when propagation is requested.
        self._orig_obj = (
            obj.withGSParams(self._gsparams) if self._propagate_gsparams else obj
        )

    @property
    def _min_acc_kvalue(self):
        # Threshold below which the wrapped k-value is not inverted.
        return self._orig_obj.flux * self.gsparams.kvalue_accuracy

    @property
    def _inv_min_acc_kvalue(self):
        # Value substituted for the inverse when below the threshold.
        return 1.0 / self._min_acc_kvalue

    @property
    @implements(_galsim.Deconvolution.orig_obj)
    def orig_obj(self):
        return self._orig_obj

    @property
    def _noise(self):
        # Noise is never propagated; warn when the wrapped profile carried some.
        if self.orig_obj.noise is not None:
            galsim_warn("Unable to propagate noise in galsim.Deconvolution")
        return None

    @implements(_galsim.Deconvolution.withGSParams)
    def withGSParams(self, gsparams=None, **kwargs):
        if gsparams == self.gsparams:
            return self
        import copy

        clone = copy.copy(self)
        clone._gsparams = GSParams.check(gsparams, self.gsparams, **kwargs)
        if self._propagate_gsparams:
            clone._orig_obj = self._orig_obj.withGSParams(clone._gsparams)
        return clone

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, Deconvolution):
            return False
        return (
            self.orig_obj == other.orig_obj
            and self.gsparams == other.gsparams
            and self._propagate_gsparams == other._propagate_gsparams
        )

    def __hash__(self):
        # Hashes the same fields that __eq__ compares.
        key = (
            "galsim.Deconvolution",
            self.orig_obj,
            self.gsparams,
            self._propagate_gsparams,
        )
        return hash(key)

    def __repr__(self):
        fields = (self.orig_obj, self.gsparams, self._propagate_gsparams)
        return "galsim.Deconvolution(%r, gsparams=%r, propagate_gsparams=%r)" % fields

    def __str__(self):
        return "galsim.Deconvolve(%s)" % self.orig_obj

    def _prepareDraw(self):
        # Delegate to the wrapped profile.
        self.orig_obj._prepareDraw()

    @property
    def _maxk(self):
        return self.orig_obj.maxk

    @property
    def _stepk(self):
        return self.orig_obj.stepk

    @property
    def _is_axisymmetric(self):
        return self.orig_obj.is_axisymmetric

    @property
    def _is_analytic_k(self):
        return self.orig_obj.is_analytic_k

    @property
    def _centroid(self):
        # Negated centroid of the wrapped profile.
        return -self.orig_obj.centroid

    @property
    def _flux(self):
        # Reciprocal of the wrapped profile's flux.
        return 1.0 / self.orig_obj.flux

    @property
    def _max_sb(self):
        # The only way to really give this any meaning is to consider it in the context
        # of being part of a larger convolution with other components.  The calculation
        # of maxSB for Convolve is
        #     maxSB = flux_final / Sum_i (flux_i / maxSB_i)
        #
        # A deconvolution will contribute a -sigma^2 to the sum, so a logical choice for
        # maxSB is to have flux / maxSB = -flux_adaptee / maxSB_adaptee, so its contribution
        # to the Sum_i 2pi sigma^2 is to subtract its adaptee's value of sigma^2.
        #
        #     maxSB = -flux * maxSB_adaptee / flux_adaptee
        #           = -maxSB_adaptee / flux_adaptee^2
        #
        return -self.orig_obj.max_sb / self.orig_obj.flux**2

    def _kValue(self, pos):
        # Really, for very low original kvalues, this gets very high, which can be unstable
        # in the presence of noise.  So if the original value is less than min_acc_kvalue,
        # we instead just return 1/min_acc_kvalue rather than the real inverse.
        kval = self.orig_obj._kValue(pos)
        too_small = jnp.abs(kval) < self._min_acc_kvalue
        return jnp.where(too_small, self._inv_min_acc_kvalue, 1.0 / kval)

    def _drawKImage(self, image, jac=None):
        image = self.orig_obj._drawKImage(image, jac)

        # Invert pixel by pixel, substituting the clamp value wherever the
        # wrapped profile's magnitude does not exceed the threshold.
        invertible = jnp.abs(image.array) > self._min_acc_kvalue
        inverted = jnp.where(
            invertible,
            1.0 / image.array,
            self._inv_min_acc_kvalue,
        )
        image._array = image._array.at[...].set(inverted)

        kx, ky = image.get_pixel_centers()
        _jac = jnp.eye(2) if jac is None else jac
        # N.B. The jacobian is transposed in k space.  This is not a typo.
        kx_t = _jac[0, 0] * kx + _jac[1, 0] * ky
        ky_t = _jac[0, 1] * kx + _jac[1, 1] * ky
        ksq = (kx_t**2 + ky_t**2) * image.scale**2

        # Set to zero outside of nominal maxk so as not to amplify high frequencies.
        truncated = jnp.where(
            ksq > self.maxk**2,
            0.0,
            image.array,
        )
        image._array = image._array.at[...].set(truncated)
        return image

    def tree_flatten(self):
        """This function flattens the GSObject into a list of children
        nodes that will be traced by JAX and auxiliary static data."""
        # The wrapped profile is the only child handed to JAX for tracing.
        children = (self._orig_obj,)
        # Everything else is static auxiliary data.
        aux_data = {
            "gsparams": self.gsparams,
            "propagate_gsparams": self._propagate_gsparams,
        }
        return (children, aux_data)

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Recreates an instance of the class from flatten representation"""
        (orig_obj,) = children
        return cls(orig_obj, **aux_data)