Source code for jax_galsim.box

import galsim as _galsim
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.draw import draw_by_kValue, draw_by_xValue
from jax_galsim.core.utils import ensure_hashable, implements
from jax_galsim.gsobject import GSObject
from jax_galsim.random import UniformDeviate


@implements(_galsim.Box)
@register_pytree_node_class
class Box(GSObject):
    # Rectangular profile described by a width, a height and a total flux.
    # The four flags below are fixed class-level constants for this profile.
    _has_hard_edges = True
    _is_axisymmetric = False
    _is_analytic_x = True
    _is_analytic_k = True

    def __init__(self, width, height, flux=1.0, gsparams=None):
        # Every argument is forwarded by keyword to the GSObject constructor.
        super().__init__(width=width, height=height, flux=flux, gsparams=gsparams)

    @property
    def _minL(self):
        """Smaller of the two side lengths."""
        return jnp.minimum(self.width, self.height)

    @property
    def _maxL(self):
        """Larger of the two side lengths."""
        return jnp.maximum(self.width, self.height)

    @property
    @implements(_galsim.Box.width)
    def width(self):
        # Read straight from the parameter dictionary.
        return self.params["width"]

    @property
    @implements(_galsim.Box.height)
    def height(self):
        # Read straight from the parameter dictionary.
        return self.params["height"]

    def __hash__(self):
        """Hash a ``"galsim.Box"`` tag with width, height, flux and gsparams."""
        key = (
            "galsim.Box",
            ensure_hashable(self.width),
            ensure_hashable(self.height),
            ensure_hashable(self.flux),
            self.gsparams,
        )
        return hash(key)

    def __repr__(self):
        """Return a constructor-style string including flux and gsparams."""
        fields = (
            ensure_hashable(self.width),
            ensure_hashable(self.height),
            ensure_hashable(self.flux),
            self.gsparams,
        )
        return "galsim.Box(width=%r, height=%r, flux=%r, gsparams=%r)" % fields

    def __str__(self):
        """Return a short description; flux is shown only when it is not 1.0."""
        sides = (ensure_hashable(self.width), ensure_hashable(self.height))
        text = "galsim.Box(width=%s, height=%s" % sides
        if self.flux != 1.0:
            text += ", flux=%s" % (ensure_hashable(self.flux),)
        return text + ")"

    @property
    def _maxk(self):
        """Return ``2 / (maxk_threshold * min(width, height))``."""
        return 2.0 / (self.gsparams.maxk_threshold * self._minL)

    @property
    def _stepk(self):
        """Return ``pi / max(width, height)``."""
        return jnp.pi / self._maxL

    @property
    def _max_sb(self):
        """Return the flux divided by the rectangle's area."""
        return self.flux / (self.width * self.height)

    def _xValue(self, pos):
        """Evaluate the profile at ``pos``.

        The value is ``flux / (width * height)`` where both
        ``2 * |pos.x| < width`` and ``2 * |pos.y| < height`` hold,
        and ``0.0`` everywhere else.
        """
        surface_brightness = self.flux / (self.width * self.height)
        within_x = 2.0 * jnp.abs(pos.x) < self.width
        within_y = 2.0 * jnp.abs(pos.y) < self.height
        # Apply the y test first, then mask that result with the x test.
        y_masked = jnp.where(within_y, surface_brightness, 0.0)
        return jnp.where(within_x, y_masked, 0.0)

    def _kValue(self, kpos):
        """Evaluate the profile at ``kpos`` in Fourier space.

        ``jnp.sinc`` is the normalized sinc, ``sin(pi * x) / (pi * x)``, so the
        side lengths are divided by ``2 * pi`` before being multiplied by k.
        """
        width_over_2pi = self.width / (2.0 * jnp.pi)
        height_over_2pi = self.height / (2.0 * jnp.pi)
        sinc_x = jnp.sinc(kpos.x * width_over_2pi)
        sinc_y = jnp.sinc(kpos.y * height_over_2pi)
        return self.flux * sinc_x * sinc_y

    def _drawReal(self, image, jac=None, offset=(0.0, 0.0), flux_scaling=1.0):
        """Draw onto ``image`` through ``draw_by_xValue``.

        A missing ``jac`` is replaced by the 2x2 identity matrix, and
        ``offset`` is converted with ``jnp.asarray`` before the call.
        """
        if jac is None:
            jac = jnp.eye(2)
        return draw_by_xValue(self, image, jac, jnp.asarray(offset), flux_scaling)

    def _drawKImage(self, image, jac=None):
        """Draw onto ``image`` through ``draw_by_kValue``.

        A missing ``jac`` is replaced by the 2x2 identity matrix.
        """
        if jac is None:
            jac = jnp.eye(2)
        return draw_by_kValue(self, image, jac)

    @implements(_galsim.Box.withFlux)
    def withFlux(self, flux):
        # Same width, height and gsparams; only the flux is replaced.
        return Box(
            width=self.width,
            height=self.height,
            flux=flux,
            gsparams=self.gsparams,
        )

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from its flattened representation.

        ``children[0]`` must be a mapping with ``"width"``, ``"height"`` and
        ``"flux"`` entries; ``aux_data`` is expanded into keyword arguments.
        """
        params = children[0]
        return cls(
            params["width"],
            params["height"],
            flux=params["flux"],
            **aux_data,
        )

    @implements(_galsim.Box._shoot)
    def _shoot(self, photons, rng):
        deviate = UniformDeviate(rng)

        # Unlike galsim, the photon attributes are reassigned here instead of
        # being filled in place. The x draw must stay ahead of the y draw so
        # that the deviate is consumed in the same order.
        photons.x = (deviate.generate(photons.x) - 0.5) * self.width
        photons.y = (deviate.generate(photons.y) - 0.5) * self.height
        # The flux is shared equally between all photons.
        photons.flux = self.flux / photons.size()
@implements(_galsim.Pixel)
@register_pytree_node_class
class Pixel(Box):
    # A Box whose width and height are both set to a single ``scale``.

    def __init__(self, scale, flux=1.0, gsparams=None):
        # ``scale`` is passed to the Box constructor as both side lengths.
        super(Pixel, self).__init__(
            width=scale,
            height=scale,
            flux=flux,
            gsparams=gsparams,
        )

    @property
    @implements(_galsim.Pixel.scale)
    def scale(self):
        # Width and height are equal, so the width serves as the scale.
        return self.width

    def __repr__(self):
        """Return a constructor-style string including flux and gsparams."""
        fields = (
            ensure_hashable(self.scale),
            ensure_hashable(self.flux),
            self.gsparams,
        )
        return "galsim.Pixel(scale=%r, flux=%r, gsparams=%r)" % fields

    def __str__(self):
        """Return a short description; flux is shown only when it is not 1.0."""
        text = "galsim.Pixel(scale=%s" % (ensure_hashable(self.scale),)
        if self.flux != 1.0:
            text += ", flux=%s" % (ensure_hashable(self.flux),)
        return text + ")"

    @implements(_galsim.Pixel.withFlux)
    def withFlux(self, flux):
        # Same scale and gsparams; only the flux is replaced.
        return Pixel(scale=self.scale, flux=flux, gsparams=self.gsparams)

    def tree_flatten(self):
        """Flatten into ``((scale, flux), {"gsparams": gsparams})``."""
        return (self.scale, self.flux), {"gsparams": self.gsparams}

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        """Rebuild an instance from its flattened representation.

        ``children`` is expanded into positional arguments and ``aux_data``
        into keyword arguments of the constructor.
        """
        return cls(*children, **aux_data)