Skip to content

Add get_image and get_image_label #71

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/astro_image_display_api/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,3 +797,28 @@ def test_save_overwrite(self, tmp_path):

# Using overwrite should save successfully
self.image.save(filename, overwrite=True)

def test_get_image_labels(self, data):
# the test viewer begins with a default empty image
assert len(self.image.get_image_labels()) == 1
assert self.image.get_image_labels()[0] is None
assert isinstance(self.image.get_image_labels(), tuple)

self.image.load_image(data, image_label="test")
assert len(self.image.get_image_labels()) == 2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting; logically I would have thought of this as being 1.

Is this too complicated: None only counts as an image label if it has data associated with it

So the idea would be that if, like in this example, the only loaded data was for an image label you set, the get_image_labels returns just the image label you have used.

Doing this would return two image labels:

# Assume we are starting with a clean slate

assert self.image.get_image_labels() is None  # or maybe == []

# Load an image without a label, which means the label is `None`
self.image.load_image(data)

# Load a second image with an explicit label
self.image.load_image(data, label="test")

assert len(self.image.get_image_labels()) == 2  # None and "test"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tracking this in #74

assert self.image.get_image_labels()[-1] == "test"

def test_get_image(self, data):
self.image.load_image(data, image_label="test")

# currently the type is not specified in the API
assert self.image.get_image() is not None
assert self.image.get_image(image_label="test") is not None

retrieved_image = self.image.get_image(image_label="test")

self.image.load_image(retrieved_image, image_label="another test")
assert self.image.get_image(image_label="another test") is not None

with pytest.raises(ValueError, match="[Ii]mage label.*not found"):
self.image.get_image(image_label="not a valid label")
15 changes: 15 additions & 0 deletions src/astro_image_display_api/image_viewer_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class ViewportInfo:
stretch: BaseStretch | None = None
cuts: BaseInterval | tuple[numbers.Real, numbers.Real] | None = None
colormap: str | None = None
data: ArrayLike | NDData | CCDData | None = None


@dataclass
Expand Down Expand Up @@ -332,6 +333,17 @@ def load_image(
# working with the new image.
self._wcs = self._images[image_label].wcs

def get_image(self, image_label: str | None = None):
image_label = self._resolve_image_label(image_label)
if image_label not in self._images:
raise ValueError(
f"Image label '{image_label}' not found. Please load an image first."
)
return self._images[image_label].data

def get_image_labels(self):
return tuple(self._images.keys())

def _determine_largest_dimension(self, shape: tuple[int, int]) -> int:
"""
Determine which index is the largest dimension.
Expand Down Expand Up @@ -401,6 +413,7 @@ def _initialize_image_viewport_stretch_cuts(
def _load_fits(self, file: str | os.PathLike, image_label: str | None) -> None:
ccd = CCDData.read(file)
self._images[image_label].wcs = ccd.wcs
self._images[image_label].data = ccd
self._initialize_image_viewport_stretch_cuts(ccd.data, image_label)

def _load_array(self, array: ArrayLike, image_label: str | None) -> None:
Expand All @@ -416,6 +429,7 @@ def _load_array(self, array: ArrayLike, image_label: str | None) -> None:
self._images[image_label].largest_dimension = self._determine_largest_dimension(
array.shape
)
self._images[image_label].data = array
self._initialize_image_viewport_stretch_cuts(array, image_label)

def _load_nddata(self, data: NDData, image_label: str | None) -> None:
Expand All @@ -428,6 +442,7 @@ def _load_nddata(self, data: NDData, image_label: str | None) -> None:
The NDData object to load.
"""
self._images[image_label].wcs = data.wcs
self._images[image_label].data = data
self._images[image_label].largest_dimension = self._determine_largest_dimension(
data.data.shape
)
Expand Down
79 changes: 65 additions & 14 deletions src/astro_image_display_api/interface_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,50 @@ def load_image(self, data: Any, image_label: str | None = None) -> None:
raise NotImplementedError

# Setting and getting image properties
@abstractmethod
def get_image(
self,
image_label: str | None = None,
) -> Any:
"""
Parameters
----------
image_label : optional
The label of the image to set the cuts for. If not given and there is
only one image loaded, that image is returned.

Returns
-------
image_data : Any
The data of the loaded image. The exact type of the data is not specified,
and different backends may return different types. A return type compatible
with `astropy.nddata.NDData` is preferred, but not required. It is expected
that the returned data can be re-loaded into the viewer using
`load_image`, however.

Raises
------
ValueError
If the ``image_label`` is not provided when there are multiple images
loaded, or if the ``image_label`` does not correspond to a loaded image.

"""
raise NotImplementedError

@abstractmethod
def get_image_labels(
self,
) -> tuple[str]:
"""
Get the labels of the loaded images.

Returns
-------
image_labels: tuple of str
The labels of the loaded images.
"""
raise NotImplementedError

@abstractmethod
def set_cuts(
self,
Expand Down Expand Up @@ -75,8 +119,8 @@ def set_cuts(
`astropy.visualization.BaseInterval` object.

ValueError
If the ``image_label`` is not provided when there are multiple images loaded,
or if the ``image_label`` does not correspond to a loaded image.
If the ``image_label`` is not provided when there are multiple images
loaded, or if the ``image_label`` does not correspond to a loaded image.

Notes
-----
Expand Down Expand Up @@ -105,8 +149,8 @@ def get_cuts(self, image_label: str | None = None) -> BaseInterval:
Raises
------
ValueError
If the ``image_label`` is not provided when there are multiple images loaded,
or if the ``image_label`` does not correspond to a loaded image.
If the ``image_label`` is not provided when there are multiple images
loaded, or if the ``image_label`` does not correspond to a loaded image.

Notes
-----
Expand All @@ -132,7 +176,8 @@ def set_stretch(self, stretch: BaseStretch, image_label: str | None = None) -> N
Raises
------
TypeError
If the ``stretch`` is not a valid `~astropy.visualization.BaseStretch` object.
If the ``stretch`` is not a valid `~astropy.visualization.BaseStretch`
object.

ValueError
If the ``image_label`` is not provided when there are multiple images loaded
Expand Down Expand Up @@ -291,7 +336,8 @@ def load_catalog(
name will be generated.
catalog_style : dict, optional
A dictionary that specifies the style of the markers used to
represent the catalog. See `~astro_image_display_api.interface_definition.ImageViewerInterface.set_catalog_style`
represent the catalog. See
`~astro_image_display_api.interface_definition.ImageViewerInterface.set_catalog_style`
for details.

Raises
Expand Down Expand Up @@ -497,15 +543,17 @@ def set_viewport(
Raises
------
TypeError
If the ``center`` is not a `~astropy.coordinates.SkyCoord` object or a tuple of floats, or if
the ``fov`` is not a angular `~astropy.units.Quantity` or a float, or if there is no WCS
and the center or field of view require a WCS to be applied.
If the ``center`` is not a `~astropy.coordinates.SkyCoord` object or a tuple
of floats, or if the ``fov`` is not a angular `~astropy.units.Quantity` or a
float, or if there is no WCS and the center or field of view require a WCS
to be applied.

ValueError
If ``image_label`` is not provided when there are multiple images loaded.

`astropy.units.UnitTypeError`
If the ``fov`` is a `~astropy.units.Quantity` but does not have an angular unit.
If the ``fov`` is a `~astropy.units.Quantity` but does not have an angular
unit.

Notes
-----
Expand All @@ -524,9 +572,11 @@ def get_viewport(
Parameters
----------
sky_or_pixel : str, optional
If 'sky', the center will be returned as a `~astropy.coordinates.SkyCoord` object.
If 'pixel', the center will be returned as a tuple of pixel coordinates.
If `None`, the default behavior is to return the center as a `~astropy.coordinates.SkyCoord` if
If 'sky', the center will be returned as a `~astropy.coordinates.SkyCoord`
object. If 'pixel', the center will be returned as a tuple of pixel
coordinates.
If `None`, the default behavior is to return the center as a
`~astropy.coordinates.SkyCoord` if
possible, or as a tuple of floats if the image is in pixel coordinates and
has no WCS information.
image_label : str, optional
Expand All @@ -539,7 +589,8 @@ def get_viewport(
dict
A dictionary containing the current viewport settings.
The keys are 'center', 'fov', and 'image_label'.
- 'center' is an `~astropy.coordinates.SkyCoord` object or a tuple of floats.
- 'center' is an `~astropy.coordinates.SkyCoord` object or a tuple of
floats.
- 'fov' is an `~astropy.units.Quantity` object or a float.
- 'image_label' is a string representing the label of the image.

Expand Down