# Source code for jax_galsim.deltafunction
import galsim as _galsim
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class
from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.utils import ensure_hashable, implements
from jax_galsim.gsobject import GSObject
# [docs]  (Sphinx "view source" link residue; not part of the module)
# JAX counterpart of ``galsim.DeltaFunction``: a point-like profile that is
# non-zero only at the origin in real space and constant in Fourier space.
# No class docstring is written here because ``implements`` is handed the
# GalSim class to mirror (presumably it supplies the documentation -- confirm
# in jax_galsim.core.utils).
@implements(_galsim.DeltaFunction)
@register_pytree_node_class
class DeltaFunction(GSObject):
    # Optional constructor parameters and the type expected for each.
    _opt_params = {"flux": float}

    # Some arbitrary very large number to use when we need infinity.
    _mock_inf = 1.0e300

    # Static traits of the profile (presumably consulted by the GSObject
    # machinery -- confirm against jax_galsim.gsobject).
    _has_hard_edges = False
    _is_axisymmetric = True
    _is_analytic_x = False
    _is_analytic_k = True

    def __init__(self, flux=1.0, gsparams=None):
        """Build the profile; both arguments are forwarded to the parent constructor."""
        super().__init__(flux=flux, gsparams=gsparams)

    def __hash__(self):
        """Hash a tuple of a class tag, the hashable form of the flux, and gsparams."""
        return hash(("galsim.DeltaFunction", ensure_hashable(self.flux), self.gsparams))

    def __repr__(self):
        """Return a ``galsim.DeltaFunction(...)`` string listing flux and gsparams."""
        return "galsim.DeltaFunction(flux=%r, gsparams=%r)" % (
            ensure_hashable(self.flux),
            self.gsparams,
        )

    def __str__(self):
        """Return a short description; the flux is shown only when it differs from 1.0."""
        s = "galsim.DeltaFunction("
        if self.flux != 1.0:
            s += "flux=%s" % self.flux
        s += ")"
        return s

    @property
    def _maxk(self):
        """Return the mock-infinity constant as the maximum k."""
        return DeltaFunction._mock_inf

    @property
    def _stepk(self):
        """Return the mock-infinity constant as the k step."""
        return DeltaFunction._mock_inf

    @property
    def _max_sb(self):
        """Return the mock-infinity constant as the maximum surface brightness."""
        return DeltaFunction._mock_inf

    def _xValue(self, pos):
        """Return the mock-infinity constant when ``pos`` is exactly (0, 0), else 0.0.

        NOTE(review): ``jax.lax.cond`` needs a scalar predicate, so ``pos.x`` and
        ``pos.y`` are assumed to be scalars here -- confirm against callers.
        """
        return jax.lax.cond(
            jnp.array(pos.x == 0.0, dtype=bool) & jnp.array(pos.y == 0.0, dtype=bool),
            # The branch callables accept (and ignore) any operands.
            lambda *a: DeltaFunction._mock_inf,
            lambda *a: 0.0,
        )

    def _kValue(self, kpos):
        """Return the flux as a complex value with the broadcast shape of ``kpos.x``."""
        # this is a wasteful and fancy way to get the shape to broadcast to
        # to match the input kpos
        return self.flux + kpos.x * (0.0 + 0.0j)

    # Assigns 0.0 to the photons' x and y and an equal share of the total flux
    # (flux / photons.size()) to their flux.  ``rng`` is accepted for interface
    # compatibility but is not used by this method.
    @implements(_galsim.DeltaFunction._shoot)
    def _shoot(self, photons, rng):
        flux_per_photon = self.flux / photons.size()
        photons.x = 0.0
        photons.y = 0.0
        photons.flux = flux_per_photon

    def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
        """Draw onto ``image`` via ``draw_by_xValue``.

        ``jac`` defaults to the 2x2 identity when None; ``offset`` is converted
        to a JAX array before being passed on together with ``flux_scaling``.
        """
        _jac = jnp.eye(2) if jac is None else jac
        return draw_by_xValue(self, image, _jac, jnp.asarray(offset), flux_scaling)

    def _drawKImage(self, image, jac=None):
        """Draw onto ``image`` via ``draw_by_kValue``; ``jac`` defaults to the 2x2 identity."""
        _jac = jnp.eye(2) if jac is None else jac
        return draw_by_kValue(self, image, _jac)

    # Returns a new DeltaFunction carrying the requested flux and this
    # object's gsparams; the current instance is left untouched.
    @implements(_galsim.DeltaFunction.withFlux)
    def withFlux(self, flux):
        return DeltaFunction(flux=flux, gsparams=self.gsparams)