Skip to content

Commit

Permalink
Better handled Spectrum1D images across classes
Browse files Browse the repository at this point in the history
  • Loading branch information
ojustino committed Sep 30, 2022
1 parent 78e68b7 commit ecce188
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 27 deletions.
17 changes: 11 additions & 6 deletions specreduce/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Background:
Parameters
----------
image : `~astropy.nddata.NDData` or array-like
image : `~astropy.nddata.NDData`-like or array-like
image with 2-D spectral image data
traces : List
list of trace objects (or integers to define FlatTraces) to
Expand Down Expand Up @@ -59,7 +59,7 @@ def __post_init__(self):
Parameters
----------
image : `~astropy.nddata.NDData` or array-like
image : `~astropy.nddata.NDData`-like or array-like
image with 2-D spectral image data
traces : List
list of trace objects (or integers to define FlatTraces) to
Expand All @@ -85,6 +85,11 @@ def _to_trace(trace):
raise ValueError('trace_object.trace_pos must be >= 1')
return trace

if isinstance(self.image, NDData):
# NOTE: should the NDData structure instead be preserved?
# (NDData includes Spectrum1D under its umbrella)
self.image = self.image.data

bkg_wimage = np.zeros_like(self.image, dtype=np.float64)
for trace in self.traces:
trace = _to_trace(trace)
Expand Down Expand Up @@ -132,7 +137,7 @@ def two_sided(cls, image, trace_object, separation, **kwargs):
Parameters
----------
image : nddata-compatible image
image : `~astropy.nddata.NDData`-like or array-like
image with 2-D spectral image data
trace_object: Trace
estimated trace of the spectrum to center the background traces
Expand Down Expand Up @@ -165,7 +170,7 @@ def one_sided(cls, image, trace_object, separation, **kwargs):
Parameters
----------
image : nddata-compatible image
image : `~astropy.nddata.NDData`-like or array-like
image with 2-D spectral image data
trace_object: Trace
estimated trace of the spectrum to center the background traces
Expand All @@ -192,7 +197,7 @@ def bkg_image(self, image=None):
Parameters
----------
image : nddata-compatible image or None
image : `~astropy.nddata.NDData`-like, array-like, or None
image with 2-D spectral image data. If None, will use ``image`` passed
to extract the background.
Expand All @@ -211,7 +216,7 @@ def sub_image(self, image=None):
Parameters
----------
image : nddata-compatible image or None
image : `~astropy.nddata.NDData`-like, array-like, or None
image with 2-D spectral image data. If None, will use ``image`` passed
to extract the background.
Expand Down
45 changes: 27 additions & 18 deletions specreduce/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,15 @@ class BoxcarExtract(SpecreduceOperation):
Parameters
----------
image : nddata-compatible image
image : `~astropy.nddata.NDData`-like or array-like, required
image with 2-D spectral image data
trace_object : Trace
trace_object : Trace, required
trace object
width : float
width : float, optional
width of extraction aperture in pixels
disp_axis : int
disp_axis : int, optional
dispersion axis
crossdisp_axis : int
crossdisp_axis : int, optional
cross-dispersion axis
Returns
Expand All @@ -150,15 +150,15 @@ def __call__(self, image=None, trace_object=None, width=None,
Parameters
----------
image : nddata-compatible image
image : `~astropy.nddata.NDData`-like or array-like, required
image with 2-D spectral image data
trace_object : Trace
trace_object : Trace, required
trace object
width : float
width : float, optional
width of extraction aperture in pixels [default: 5]
disp_axis : int
disp_axis : int, optional
dispersion axis [default: 1]
crossdisp_axis : int
crossdisp_axis : int, optional
cross-dispersion axis [default: 0]
Expand All @@ -174,25 +174,33 @@ def __call__(self, image=None, trace_object=None, width=None,
disp_axis = disp_axis if disp_axis is not None else self.disp_axis
crossdisp_axis = crossdisp_axis if crossdisp_axis is not None else self.crossdisp_axis

# handle image processing based on its type
if isinstance(image, Spectrum1D):
img = image.data
unit = image.unit
else:
img = image
unit = getattr(image, 'unit', u.DN)

# TODO: this check can be removed if/when implemented as a check in FlatTrace
if isinstance(trace_object, FlatTrace):
if trace_object.trace_pos < 1:
raise ValueError('trace_object.trace_pos must be >= 1')

# weight image to use for extraction
wimage = _ap_weight_image(
wimg = _ap_weight_image(
trace_object,
width,
disp_axis,
crossdisp_axis,
image.shape)
img.shape)

# extract
ext1d = np.sum(image * wimage, axis=crossdisp_axis)
ext1d = np.sum(img * wimg, axis=crossdisp_axis) * unit

# TODO: add wavelenght units, uncertainty and mask to spectrum1D object
spec = Spectrum1D(spectral_axis=np.arange(len(ext1d)) * u.pixel,
flux=ext1d * getattr(image, 'unit', u.DN))
# TODO: add wavelength units, uncertainty and mask to Spectrum1D object
pixels = np.arange(ext1d.shape[crossdisp_axis]) * u.pixel
spec = Spectrum1D(spectral_axis=pixels, flux=ext1d)

return spec

Expand All @@ -206,7 +214,7 @@ class HorneExtract(SpecreduceOperation):
Parameters
----------
image : `~astropy.nddata.NDData` or array-like, required
image : `~astropy.nddata.NDData`-like or array-like, required
The input 2D spectrum from which to extract a source. An
NDData object must specify uncertainty and a mask. An array
requires use of the ``variance``, ``mask``, & ``unit`` arguments.
Expand Down Expand Up @@ -269,7 +277,7 @@ def __call__(self, image=None, trace_object=None,
Parameters
----------
image : `~astropy.nddata.NDData` or array-like, required
image : `~astropy.nddata.NDData`-like or array-like, required
The input 2D spectrum from which to extract a source. An
NDData object must specify uncertainty and a mask. An array
requires use of the ``variance``, ``mask``, & ``unit`` arguments.
Expand Down Expand Up @@ -322,6 +330,7 @@ def __call__(self, image=None, trace_object=None,

# handle image and associated data based on image's type
if isinstance(image, NDData):
# (NDData includes Spectrum1D under its umbrella)
img = np.ma.array(image.data, mask=image.mask)
unit = image.unit if image.unit is not None else u.Unit()

Expand Down
18 changes: 15 additions & 3 deletions specreduce/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from astropy.nddata import CCDData, NDData
from astropy.stats import gaussian_sigma_to_fwhm
from scipy.interpolate import UnivariateSpline
from specutils import Spectrum1D
import numpy as np

__all__ = ['Trace', 'FlatTrace', 'ArrayTrace', 'KosmosTrace']
Expand All @@ -20,15 +21,15 @@ class Trace:
Parameters
----------
image : `~astropy.nddata.CCDData`
image : `~astropy.nddata.NDData`-like or array-like, required
Image to be traced
Properties
----------
shape : tuple
Shape of the array describing the trace
"""
image: CCDData
image: NDData

def __post_init__(self):
self.trace_pos = self.image.shape[0] / 2
Expand All @@ -37,6 +38,11 @@ def __post_init__(self):
def __getitem__(self, i):
return self.trace[i]

def _parse_image(self):
if isinstance(self.image, Spectrum1D):
# NOTE: should the Spectrum1D structure instead be preserved?
self.image = self.image.data

@property
def shape(self):
return self.trace.shape
Expand Down Expand Up @@ -95,6 +101,8 @@ class FlatTrace(Trace):
trace_pos: float

def __post_init__(self):
super()._parse_image()

self.set_position(self.trace_pos)

def set_position(self, trace_pos):
Expand Down Expand Up @@ -124,6 +132,8 @@ class ArrayTrace(Trace):
trace: np.ndarray

def __post_init__(self):
super()._parse_image()

nx = self.image.shape[1]
nt = len(self.trace)
if nt != nx:
Expand Down Expand Up @@ -158,7 +168,7 @@ class KosmosTrace(Trace):
Parameters
----------
image : `~astropy.nddata.NDData` or array-like, required
image : `~astropy.nddata.NDData`-like or array-like, required
The image over which to run the trace. Assumes cross-dispersion
(spatial) direction is axis 0 and dispersion (wavelength)
direction is axis 1.
Expand Down Expand Up @@ -200,6 +210,8 @@ class KosmosTrace(Trace):
_disp_axis = 1

def __post_init__(self):
super()._parse_image()

# handle multiple image types and mask uncaught invalid values
if isinstance(self.image, NDData):
img = np.ma.masked_invalid(np.ma.masked_array(self.image.data,
Expand Down

0 comments on commit ecce188

Please sign in to comment.