from collections import namedtuple
import jax
import jax.numpy as jnp
from jax_galsim.random import PoissonDeviate
[docs]
def draw_by_xValue(
    gsobject, image, jacobian=jnp.eye(2), offset=jnp.zeros(2), flux_scaling=1.0
):
    """Evaluate a GSObject's real-space profile on the pixel grid of ``image``.

    Each pixel center (from ``image.get_pixel_centers()``) is scaled by
    ``image.scale``, shifted by ``-offset``, and mapped through the inverse of
    ``jacobian``. ``gsobject._xValue`` is then evaluated at the resulting
    position. The values are multiplied by
    ``flux_scaling * image.scale**2 / |det(jacobian)|``.

    Parameters:
        gsobject:      Object providing ``_xValue(PositionD)``.
        image:         Target image; supplies the pixel grid, scale, dtype,
                       bounds and wcs.
        jacobian:      2x2 matrix whose inverse is applied to the coordinates.
                       [default: identity]
        offset:        Length-2 shift subtracted from the scaled coordinates.
                       [default: zeros]
        flux_scaling:  Extra multiplicative factor on the drawn values.
                       [default: 1.0]

    Returns:
        A new ``Image`` with the same bounds and wcs as ``image``. If the
        drawn values are floating point and ``image`` has an integer dtype,
        the values are rounded with ``jnp.around`` (no cast is done here).
    """
    # putting the import here to avoid circular imports
    from jax_galsim import Image, PositionD

    pixel_scale = image.scale

    # Surface brightness -> flux per pixel (compare SBProfile.draw()).
    flux_scaling *= pixel_scale**2

    # Pixel-center coordinates in physical units, shifted by the offset.
    positions = jnp.stack(image.get_pixel_centers(), axis=-1) * pixel_scale - offset

    # Map through the inverse Jacobian; log|det(J^-1)| rescales the flux.
    inv_jac = jnp.linalg.inv(jacobian)
    _, log_abs_det = jnp.linalg.slogdet(inv_jac)
    positions = jnp.dot(positions, inv_jac.T)
    flux_scaling *= jnp.exp(log_abs_det)

    def _x_value_at(x, y):
        return gsobject._xValue(PositionD(x, y))

    values = jax.vmap(_x_value_at)(positions[..., 0], positions[..., 1])
    values = values * flux_scaling

    # Float-to-int conversion is platform dependent in jax-galsim, so round
    # explicitly whenever the destination image holds integers.
    drawn_is_float = jnp.issubdtype(values.dtype, jnp.floating)
    target_is_int = jnp.issubdtype(image.dtype, jnp.integer)
    if drawn_is_float and target_is_int:
        values = jnp.around(values)

    return Image(array=values, bounds=image.bounds, wcs=image.wcs, _check_bounds=False)
[docs]
def draw_by_kValue(gsobject, image, jacobian=jnp.eye(2)):
    """Evaluate a GSObject's Fourier-space profile on the pixel grid of ``image``.

    Each pixel center (from ``image.get_pixel_centers()``) is scaled by
    ``image.scale`` and right-multiplied by ``jacobian``. ``gsobject._kValue``
    is then evaluated at the resulting position.

    Parameters:
        gsobject:  Object providing ``_kValue(PositionD)``.
        image:     Target image; supplies the pixel grid, scale, dtype,
                   bounds and wcs.
        jacobian:  2x2 matrix applied as ``coords @ jacobian``.
                   [default: identity]

    Returns:
        A new ``Image`` whose array is cast to ``image.dtype``, with the same
        bounds and wcs as ``image``.
    """
    # putting the import here to avoid circular imports
    from jax_galsim import Image, PositionD

    k_grid = jnp.dot(
        jnp.stack(image.get_pixel_centers(), axis=-1) * image.scale, jacobian
    )

    def _k_value_at(kx, ky):
        return gsobject._kValue(PositionD(kx, ky))

    values = jax.vmap(_k_value_at)(k_grid[..., 0], k_grid[..., 1]).astype(image.dtype)
    return Image(array=values, bounds=image.bounds, wcs=image.wcs, _check_bounds=False)
[docs]
def apply_kImage_phases(offset, image, jacobian=jnp.eye(2)):
    """Multiply a k-space image by the phase factor of a real-space offset.

    Every pixel is multiplied by ``exp(-i * (kx * offset.x + ky * offset.y))``.
    Here ``(kx, ky)`` are the pixel-center coordinates scaled by
    ``image.scale`` and right-multiplied by ``jacobian``.

    Parameters:
        offset:    Object exposing ``.x`` and ``.y`` (the shift).
        image:     The k-space image to be phased.
        jacobian:  2x2 matrix applied as ``coords @ jacobian``.
                   [default: identity]

    Returns:
        A new ``Image`` holding ``image.array * phase`` with the same bounds
        and wcs as ``image``.
    """
    # putting the import here to avoid circular imports
    from jax_galsim import Image, PositionD

    shift_x = offset.x
    shift_y = offset.y

    k_grid = jnp.stack(image.get_pixel_centers(), axis=-1) * image.scale
    k_grid = jnp.dot(k_grid, jacobian)

    def _phase_at(kx, ky):
        kpos = PositionD(kx, ky)
        # There is no jax counterpart of C++ std::polar, so e^{i*angle} is
        # assembled from cos and sin.
        angle = -(kpos.x * shift_x + kpos.y * shift_y)
        return jnp.cos(angle) + 1j * jnp.sin(angle)

    phases = jax.vmap(_phase_at)(k_grid[..., 0], k_grid[..., 1])

    return Image(
        array=image.array * phases,
        bounds=image.bounds,
        wcs=image.wcs,
        _check_bounds=False,
    )
_NPhotonsData = namedtuple(
"NPhotonsData",
[
"n_photons",
"flux",
"flux_per_photon", # also called eta_factor below
"max_sb",
"rng",
"poisson_flux",
"max_extra_noise",
],
)
[docs]
def calculate_mean_n_photons(
    flux,
    flux_per_photon,
    max_sb,
):
    """Calculate the mean number of photons to shoot for photon shooting.

    This routine can be used to group objects together by the typical number
    of photons they will shoot when drawing objects in bulk.

    The count is computed with ``n_photons=0`` (automatic), Poisson flux
    variation disabled, ``max_extra_noise=0`` and no random number generator.
    The result is therefore deterministic.

    Parameters:
        flux:             The flux of the GSObject (e.g., ``obj.flux``).
        flux_per_photon:  The flux per photon (e.g., ``obj._flux_per_photon``).
        max_sb:           The maximum surface brightness of the object
                          (e.g., ``obj.max_sb``).

    Returns:
        The number of photons computed under the settings above.
    """
    deterministic_inputs = _NPhotonsData(
        n_photons=0.0,
        flux=flux,
        flux_per_photon=flux_per_photon,
        max_sb=max_sb,
        rng=None,
        poisson_flux=False,
        max_extra_noise=0.0,
    )
    n_photons, _, _ = _sample_zero(deterministic_inputs)
    return n_photons
[docs]
@jax.jit
def calculate_n_photons(
    flux,
    flux_per_photon,
    max_sb,
    n_photons=0,
    rng=None,
    max_extra_noise=0.0,
    poisson_flux=True,
):
    """Calculate the number of photons to shoot for an object when photon shooting according to the
    code in ``galsim.GSObject._calculate_nphotons``. See the notes section below for more details.

    Parameters:
        flux:             The flux of the GSObject (e.g., ``obj.flux``).
        flux_per_photon:  The flux per photon (e.g., ``obj._flux_per_photon``).
        max_sb:           The maximum surface brightness of the object (e.g., ``obj.max_sb``).
        n_photons:        If provided, the number of photons to use for photon shooting.
                          If not provided (i.e. ``n_photons = 0``), use as many photons as
                          necessary to result in an image with the correct Poisson shot
                          noise for the object's flux. For positive definite profiles, this
                          is equivalent to ``n_photons = flux``. However, some profiles need
                          more than this because some of the shot photons are negative
                          (usually due to interpolants). [default: 0]
        rng:              If provided, a random number generator to use for the Poisson
                          flux variation, which may be any kind of `BaseDeviate` object. If
                          ``None``, ``PoissonDeviate`` is constructed with ``None`` (see
                          ``jax_galsim.random`` for how it then seeds itself).
                          [default: None]
        max_extra_noise:  If provided, the allowed extra noise in each pixel when photon
                          shooting. This is only relevant if ``n_photons=0``, so the number of
                          photons is being automatically calculated. In that case, if the image
                          noise is dominated by the sky background, then you can get away with
                          using fewer shot photons than the full ``n_photons = flux``.
                          Essentially each shot photon can have a ``flux > 1``, which increases
                          the noise in each pixel. The ``max_extra_noise`` parameter specifies
                          how much extra noise per pixel is allowed because of this approximation.
                          A typical value for this might be ``max_extra_noise = sky_level / 100``
                          where ``sky_level`` is the flux per pixel due to the sky. Note that
                          this uses a "variance" definition of noise, not a "sigma" definition.
                          [default: 0.]
        poisson_flux:     Whether to allow total object flux scaling to vary according to
                          Poisson statistics for ``n_photons`` samples when photon shooting.
                          [default: True. Unlike ``GSObject.drawImage``, this default does
                          not change when ``n_photons`` is given.]

    Returns:
        A tuple of ``(n_photons, g, rng)``. ``n_photons`` is the number of photons and ``g``
        is the flux ratio. ``rng`` is the random number generator that was passed in (or
        ``None``); its state is updated from the sampling. As in GalSim, a zero ``flux``
        always gives ``n_photons = 0`` and ``g = 1.0``, even when ``n_photons`` was given.

    Notes:
        It is easiest to look at the original code from ``GSObject._calculate_nphotons``
        to understand what this function does:

        .. code-block:: python

            # Calculate how many photons to shoot and what flux_ratio (called g) each one should
            # have in order to produce an image with the right S/N and total flux.
            #
            # This routine is normally called by drawPhot.
            #
            # Returns:
            #   n_photons, g

            # For profiles that are positive definite, then N = flux. Easy.
            #
            # However, some profiles shoot some of their photons with negative flux. This means that
            # we need a few more photons to get the right S/N = sqrt(flux). Take eta to be the
            # fraction of shot photons that have negative flux.
            #
            # S^2 = (N+ - N-)^2 = (N+ + N- - 2N-)^2 = (Ntot - 2N-)^2 = Ntot^2(1 - 2 eta)^2
            # N^2 = Var(S) = (N+ + N-) = Ntot
            #
            # So flux = (S/N)^2 = Ntot (1-2eta)^2
            # Ntot = flux / (1-2eta)^2
            #
            # However, if each photon has a flux of 1, then S = (1-2eta) Ntot = flux / (1-2eta).
            # So in fact, each photon needs to carry a flux of g = 1-2eta to get the right
            # total flux.
            #
            # That's all the easy case. The trickier case is when we are sky-background dominated.
            # Then we can usually get away with fewer shot photons than the above. In particular,
            # if the noise from the photon shooting is much less than the sky noise, then we can
            # use fewer shot photons and essentially have each photon have a flux > 1. This is ok
            # as long as the additional noise due to this approximation is "much less than" the
            # noise we'll be adding to the image for the sky noise.
            #
            # Let's still have Ntot photons, but now each with a flux of g. And let's look at the
            # noise we get in the brightest pixel that has a nominal total flux of Imax.
            #
            # The number of photons hitting this pixel will be Imax/flux * Ntot.
            # The variance of this number is the same thing (Poisson counting).
            # So the noise in that pixel is:
            #
            # N^2 = Imax/flux * Ntot * g^2
            #
            # And the signal in that pixel will be:
            #
            # S = Imax/flux * (N+ - N-) * g which has to equal Imax, so
            # g = flux / Ntot(1-2eta)
            # N^2 = Imax/Ntot * flux / (1-2eta)^2
            #
            # As expected, we see that lowering Ntot will increase the noise in that (and every
            # other) pixel.
            # The input max_extra_noise parameter is the maximum value of spurious noise we want
            # to allow.
            #
            # So setting N^2 = Imax + nu, we get
            #
            # Ntot = flux / (1-2eta)^2 / (1 + nu/Imax)
            # g = (1 - 2eta) * (1 + nu/Imax)
            #
            # Returns the total flux placed inside the image bounds by photon shooting.
            flux = self.flux
            if flux == 0.0:
                return 0, 1.0

            # The _flux_per_photon property is (1-2eta)
            # This factor will already be accounted for by the shoot function, so don't include
            # that as part of our scaling here. There may be other adjustments though, so g=1 here.
            eta_factor = self._flux_per_photon
            mod_flux = flux / (eta_factor * eta_factor)
            g = 1.0

            # If requested, let the target flux value vary as a Poisson deviate
            if poisson_flux:
                # If we have both positive and negative photons, then the mix of these
                # already gives us some variation in the flux value from the variance
                # of how many are positive and how many are negative.
                # The number of negative photons varies as a binomial distribution.
                # <F-> = eta * Ntot * g
                # <F+> = (1-eta) * Ntot * g
                # <F+ - F-> = (1-2eta) * Ntot * g = flux
                # Var(F-) = eta * (1-eta) * Ntot * g^2
                # F+ = Ntot * g - F- is not an independent variable, so
                # Var(F+ - F-) = Var(Ntot*g - 2*F-)
                #              = 4 * Var(F-)
                #              = 4 * eta * (1-eta) * Ntot * g^2
                #              = 4 * eta * (1-eta) * flux
                # We want the variance to be equal to flux, so we need an extra:
                # delta Var = (1 - 4*eta + 4*eta^2) * flux
                #           = (1-2eta)^2 * flux
                absflux = abs(flux)
                mean = eta_factor * eta_factor * absflux
                pd = PoissonDeviate(rng, mean)
                pd_val = pd() - mean + absflux
                ratio = pd_val / absflux
                g *= ratio
                mod_flux *= ratio

            if n_photons == 0.0:
                n_photons = abs(mod_flux)
                if max_extra_noise > 0.0:
                    gfactor = 1.0 + max_extra_noise / abs(self.max_sb)
                    n_photons /= gfactor
                    g *= gfactor

            # Make n_photons an integer.
            iN = int(n_photons + 0.5)

            return iN, g
    """
    n_photons_data = _NPhotonsData(
        n_photons=n_photons,
        poisson_flux=poisson_flux,
        max_extra_noise=max_extra_noise,
        rng=rng,
        flux=flux,
        flux_per_photon=flux_per_photon,
        max_sb=max_sb,
    )

    # GalSim returns (0, 1.0) for a zero-flux object *before* it looks at
    # n_photons. ``_sample_zero`` already handles flux == 0 that way, so zero-flux
    # objects are routed there even when an explicit n_photons was requested.
    # Otherwise ``_sample_nonzero`` would divide by |flux| = 0 when forming the
    # Poisson flux ratio, returning g = NaN together with the requested count.
    use_zero_path = jnp.logical_or(
        n_photons_data.n_photons == 0.0, n_photons_data.flux == 0.0
    )
    _n_photons, g, _rng = jax.lax.cond(
        use_zero_path,
        _sample_zero,
        _sample_nonzero,
        n_photons_data,
    )
    # lax.cond hands back a fresh pytree, so copy the advanced state onto the
    # caller's deviate object.
    if rng is not None:
        rng._state = _rng._state
    return _n_photons, g, rng
@jax.jit
def _sample_zero(n_photons_data):
    """Compute ``(n_photons, g, rng)`` when the photon count is determined automatically.

    A zero flux yields ``(0, 1.0, rng)``. Any other flux is delegated to
    ``_calculate_n_photons_flux_nonzero``. If ``n_photons_data.rng`` is not
    ``None``, its ``_state`` is overwritten with the state returned by the
    branch, and that same object is returned.
    """

    def _zero_flux(flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng):
        # Nothing to shoot: no photons and a unit flux ratio.
        return 0, 1.0, rng

    def _nonzero_flux(flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng):
        return _calculate_n_photons_flux_nonzero(
            flux, eta_factor, max_sb, poisson_flux, max_extra_noise, rng
        )

    npd = n_photons_data
    count, ratio, new_rng = jax.lax.cond(
        npd.flux == 0.0,
        _zero_flux,
        _nonzero_flux,
        npd.flux,
        npd.flux_per_photon,
        npd.max_sb,
        npd.poisson_flux,
        npd.max_extra_noise,
        npd.rng,
    )
    if npd.rng is not None:
        npd.rng._state = new_rng._state
    return count, ratio, npd.rng
@jax.jit
def _sample_nonzero(n_photons_data):
    """Compute ``(n_photons, g, rng)`` when an explicit photon count was requested.

    The count is ``n_photons`` rounded via ``jnp.int_(n_photons + 0.5)``.
    ``g`` is the Poisson flux ratio from ``_sample_poisson_flux`` when
    ``poisson_flux`` is set, otherwise ``1.0``. If ``n_photons_data.rng`` is
    not ``None``, its ``_state`` is overwritten with the state returned by the
    branch, and that same object is returned.
    """

    def _with_poisson(npd):
        return _sample_poisson_flux(npd.flux, npd.flux_per_photon, npd.rng)

    def _without_poisson(npd):
        return 1.0, npd.rng

    g, new_rng = jax.lax.cond(
        n_photons_data.poisson_flux, _with_poisson, _without_poisson, n_photons_data
    )
    if n_photons_data.rng is not None:
        n_photons_data.rng._state = new_rng._state

    rounded_count = jnp.int_(n_photons_data.n_photons + 0.5)
    return rounded_count, g, n_photons_data.rng
@jax.jit
def _sample_poisson_flux(flux, eta_factor, rng):
    """Draw the Poisson flux ratio used to vary the total flux.

    A ``PoissonDeviate`` with mean ``eta_factor**2 * |flux|`` is sampled. The
    deviate is recentred so that it averages ``|flux|`` and is then divided by
    ``|flux|``. Note: ``flux == 0`` gives a division by zero.

    Returns:
        ``(ratio, rng)``, where ``rng`` is the same object that was passed in.
    """
    abs_flux = jnp.abs(flux)
    poisson_mean = eta_factor * eta_factor * abs_flux
    deviate = PoissonDeviate(rng, poisson_mean)
    return (deviate() - poisson_mean + abs_flux) / abs_flux, rng
def _adjust_flux_g_poisson(poisson_flux, flux, mod_flux, eta_factor, rng, g):
    """Scale ``mod_flux`` and ``g`` by a freshly sampled Poisson flux ratio.

    ``poisson_flux`` is accepted only so the signature matches the other branch
    of the ``jax.lax.cond`` that calls this helper; it is not used.

    Returns:
        ``(|mod_flux * ratio|, g * ratio, rng)``
    """
    ratio, rng_out = _sample_poisson_flux(flux, eta_factor, rng)
    return jnp.abs(mod_flux * ratio), g * ratio, rng_out
def _scale_extra_noise(max_extra_noise, mod_flux, g, max_sb):
gfactor = 1.0 + max_extra_noise / jnp.abs(max_sb)
mod_flux /= gfactor
g *= gfactor
return mod_flux, g
def _calculate_n_photons_flux_nonzero(
    flux, flux_per_photon, max_sb, poisson_flux, max_extra_noise, rng
):
    """Compute ``(n_photons, g, rng)`` for a non-zero flux with an automatic photon count.

    This is the traced counterpart of the ``n_photons == 0`` path of GalSim's
    ``GSObject._calculate_nphotons``. It proceeds in three steps:

    1. Start from ``mod_flux = flux / flux_per_photon**2`` and ``g = 1``.
    2. Optionally scale both by a Poisson flux ratio.
    3. Take ``|mod_flux|`` and, if ``max_extra_noise > 0``, divide it by
       ``1 + max_extra_noise / |max_sb|`` while multiplying ``g`` by the same
       factor.

    Parameters:
        flux:             Object flux (assumed non-zero by the caller).
        flux_per_photon:  ``1 - 2 eta`` (called ``eta_factor`` below).
        max_sb:           Maximum surface brightness of the object.
        poisson_flux:     Whether to apply the Poisson flux variation.
        max_extra_noise:  Allowed extra per-pixel variance.
        rng:              Random deviate, or ``None``.

    Returns:
        ``(n_photons, g, rng)``, where ``n_photons`` is ``ceil(|mod_flux|)`` as
        an integer. If ``rng`` is not ``None``, its ``_state`` has been updated
        in place and the same object is returned.
    """
    # For profiles that are positive definite, then N = flux. Easy.
    #
    # However, some profiles shoot some of their photons with negative flux. This means that
    # we need a few more photons to get the right S/N = sqrt(flux). Take eta to be the
    # fraction of shot photons that have negative flux.
    #
    # S^2 = (N+ - N-)^2 = (N+ + N- - 2N-)^2 = (Ntot - 2N-)^2 = Ntot^2(1 - 2 eta)^2
    # N^2 = Var(S) = (N+ + N-) = Ntot
    #
    # So flux = (S/N)^2 = Ntot (1-2eta)^2
    # Ntot = flux / (1-2eta)^2
    #
    # However, if each photon has a flux of 1, then S = (1-2eta) Ntot = flux / (1-2eta).
    # So in fact, each photon needs to carry a flux of g = 1-2eta to get the right
    # total flux.
    #
    # When sky-background dominated, fewer photons (each with flux g > 1) are acceptable
    # as long as the extra noise stays below max_extra_noise (nu). With Imax the flux in
    # the brightest pixel, setting N^2 = Imax + nu gives
    #
    # Ntot = flux / (1-2eta)^2 / (1 + nu/Imax)
    # g = (1 - 2eta) * (1 + nu/Imax)
    #
    # The _flux_per_photon property is (1-2eta)
    # This factor will already be accounted for by the shoot function, so don't include
    # that as part of our scaling here. There may be other adjustments though, so g=1 here.
    eta_factor = flux_per_photon
    mod_flux = flux / (eta_factor * eta_factor)
    g = 1.0

    def _apply_poisson(poisson_flux, flux, mod_flux, eta_factor, rng, g):
        return _adjust_flux_g_poisson(poisson_flux, flux, mod_flux, eta_factor, rng, g)

    def _skip_poisson(poisson_flux, flux, mod_flux, eta_factor, rng, g):
        return mod_flux, g, rng

    # If requested, let the target flux value vary as a Poisson deviate
    mod_flux, g, _rng = jax.lax.cond(
        poisson_flux,
        _apply_poisson,
        _skip_poisson,
        poisson_flux,
        flux,
        mod_flux,
        eta_factor,
        rng,
        g,
    )
    # lax.cond hands back a fresh pytree; copy the advanced state onto the
    # caller's deviate object.
    if rng is not None:
        rng._state = _rng._state

    # GalSim uses ``n_photons = abs(mod_flux)`` whether or not the flux was
    # Poisson-varied. ``_adjust_flux_g_poisson`` already takes the absolute
    # value, but the non-Poisson branch does not, so without this a negative
    # flux would produce a negative photon count.
    mod_flux = jnp.abs(mod_flux)

    def _apply_extra_noise(max_extra_noise, mod_flux, g, max_sb):
        return _scale_extra_noise(max_extra_noise, mod_flux, g, max_sb)

    def _skip_extra_noise(max_extra_noise, mod_flux, g, max_sb):
        return mod_flux, g

    mod_flux, g = jax.lax.cond(
        max_extra_noise > 0.0,
        _apply_extra_noise,
        _skip_extra_noise,
        max_extra_noise,
        mod_flux,
        g,
        max_sb,
    )

    # NOTE(review): GalSim (and ``_sample_nonzero`` in this module) rounds with
    # ``int(n_photons + 0.5)``, whereas this rounds up; confirm ceil is intended.
    return jnp.ceil(mod_flux).astype(int), g, rng