diff --git a/src/astro_image_display_api/api_test.py b/src/astro_image_display_api/api_test.py index 63d3c1c..be70d69 100644 --- a/src/astro_image_display_api/api_test.py +++ b/src/astro_image_display_api/api_test.py @@ -797,3 +797,28 @@ def test_save_overwrite(self, tmp_path): # Using overwrite should save successfully self.image.save(filename, overwrite=True) + + def test_get_image_labels(self, data): + # the test viewer begins with a default empty image + assert len(self.image.get_image_labels()) == 1 + assert self.image.get_image_labels()[0] is None + assert isinstance(self.image.get_image_labels(), tuple) + + self.image.load_image(data, image_label="test") + assert len(self.image.get_image_labels()) == 2 + assert self.image.get_image_labels()[-1] == "test" + + def test_get_image(self, data): + self.image.load_image(data, image_label="test") + + # currently the type is not specified in the API + assert self.image.get_image() is not None + assert self.image.get_image(image_label="test") is not None + + retrieved_image = self.image.get_image(image_label="test") + + self.image.load_image(retrieved_image, image_label="another test") + assert self.image.get_image(image_label="another test") is not None + + with pytest.raises(ValueError, match="[Ii]mage label.*not found"): + self.image.get_image(image_label="not a valid label") diff --git a/src/astro_image_display_api/image_viewer_logic.py b/src/astro_image_display_api/image_viewer_logic.py index af4f390..fca90a6 100644 --- a/src/astro_image_display_api/image_viewer_logic.py +++ b/src/astro_image_display_api/image_viewer_logic.py @@ -50,6 +50,7 @@ class ViewportInfo: stretch: BaseStretch | None = None cuts: BaseInterval | tuple[numbers.Real, numbers.Real] | None = None colormap: str | None = None + data: ArrayLike | NDData | CCDData | None = None @dataclass @@ -332,6 +333,17 @@ def load_image( # working with the new image. self._wcs = self._images[image_label].wcs + def get_image(self, image_label: str | None = None): + image_label = self._resolve_image_label(image_label) + if image_label not in self._images: + raise ValueError( + f"Image label '{image_label}' not found. Please load an image first." + ) + return self._images[image_label].data + + def get_image_labels(self): + return tuple(self._images.keys()) + def _determine_largest_dimension(self, shape: tuple[int, int]) -> int: """ Determine which index is the largest dimension. @@ -401,6 +413,7 @@ def _initialize_image_viewport_stretch_cuts( def _load_fits(self, file: str | os.PathLike, image_label: str | None) -> None: ccd = CCDData.read(file) self._images[image_label].wcs = ccd.wcs + self._images[image_label].data = ccd self._initialize_image_viewport_stretch_cuts(ccd.data, image_label) def _load_array(self, array: ArrayLike, image_label: str | None) -> None: @@ -416,6 +429,7 @@ def _load_array(self, array: ArrayLike, image_label: str | None) -> None: self._images[image_label].largest_dimension = self._determine_largest_dimension( array.shape ) + self._images[image_label].data = array self._initialize_image_viewport_stretch_cuts(array, image_label) def _load_nddata(self, data: NDData, image_label: str | None) -> None: @@ -428,6 +442,7 @@ def _load_nddata(self, data: NDData, image_label: str | None) -> None: The NDData object to load. """ self._images[image_label].wcs = data.wcs + self._images[image_label].data = data self._images[image_label].largest_dimension = self._determine_largest_dimension( data.data.shape ) diff --git a/src/astro_image_display_api/interface_definition.py b/src/astro_image_display_api/interface_definition.py index bb23828..0315a3d 100644 --- a/src/astro_image_display_api/interface_definition.py +++ b/src/astro_image_display_api/interface_definition.py @@ -48,6 +48,50 @@ def load_image(self, data: Any, image_label: str | None = None) -> None: raise NotImplementedError # Setting and getting image properties + @abstractmethod + def get_image( + self, + image_label: str | None = None, + ) -> Any: + """ + Parameters + ---------- + image_label : optional + The label of the image to set the cuts for. If not given and there is + only one image loaded, that image is returned. + + Returns + ------- + image_data : Any + The data of the loaded image. The exact type of the data is not specified, + and different backends may return different types. A return type compatible + with `astropy.nddata.NDData` is preferred, but not required. It is expected + that the returned data can be re-loaded into the viewer using + `load_image`, however. + + Raises + ------ + ValueError + If the ``image_label`` is not provided when there are multiple images + loaded, or if the ``image_label`` does not correspond to a loaded image. + + """ + raise NotImplementedError + + @abstractmethod + def get_image_labels( + self, + ) -> tuple[str]: + """ + Get the labels of the loaded images. + + Returns + ------- + image_labels: tuple of str + The labels of the loaded images. + """ + raise NotImplementedError + @abstractmethod def set_cuts( self, @@ -75,8 +119,8 @@ def set_cuts( `astropy.visualization.BaseInterval` object. ValueError - If the ``image_label`` is not provided when there are multiple images loaded, - or if the ``image_label`` does not correspond to a loaded image. + If the ``image_label`` is not provided when there are multiple images + loaded, or if the ``image_label`` does not correspond to a loaded image. Notes ----- @@ -105,8 +149,8 @@ def get_cuts(self, image_label: str | None = None) -> BaseInterval: Raises ------ ValueError - If the ``image_label`` is not provided when there are multiple images loaded, - or if the ``image_label`` does not correspond to a loaded image. + If the ``image_label`` is not provided when there are multiple images + loaded, or if the ``image_label`` does not correspond to a loaded image. Notes ----- @@ -132,7 +176,8 @@ def set_stretch(self, stretch: BaseStretch, image_label: str | None = None) -> N Raises ------ TypeError - If the ``stretch`` is not a valid `~astropy.visualization.BaseStretch` object. + If the ``stretch`` is not a valid `~astropy.visualization.BaseStretch` + object. ValueError If the ``image_label`` is not provided when there are multiple images loaded @@ -291,7 +336,8 @@ def load_catalog( name will be generated. catalog_style : dict, optional A dictionary that specifies the style of the markers used to - represent the catalog. See `~astro_image_display_api.interface_definition.ImageViewerInterface.set_catalog_style` + represent the catalog. See + `~astro_image_display_api.interface_definition.ImageViewerInterface.set_catalog_style` for details. Raises @@ -497,15 +543,17 @@ def set_viewport( Raises ------ TypeError - If the ``center`` is not a `~astropy.coordinates.SkyCoord` object or a tuple of floats, or if - the ``fov`` is not a angular `~astropy.units.Quantity` or a float, or if there is no WCS - and the center or field of view require a WCS to be applied. + If the ``center`` is not a `~astropy.coordinates.SkyCoord` object or a tuple + of floats, or if the ``fov`` is not a angular `~astropy.units.Quantity` or a + float, or if there is no WCS and the center or field of view require a WCS + to be applied. ValueError If ``image_label`` is not provided when there are multiple images loaded. `astropy.units.UnitTypeError` - If the ``fov`` is a `~astropy.units.Quantity` but does not have an angular unit. + If the ``fov`` is a `~astropy.units.Quantity` but does not have an angular + unit. Notes ----- @@ -524,9 +572,11 @@ def get_viewport( Parameters ---------- sky_or_pixel : str, optional - If 'sky', the center will be returned as a `~astropy.coordinates.SkyCoord` object. - If 'pixel', the center will be returned as a tuple of pixel coordinates. - If `None`, the default behavior is to return the center as a `~astropy.coordinates.SkyCoord` if + If 'sky', the center will be returned as a `~astropy.coordinates.SkyCoord` + object. If 'pixel', the center will be returned as a tuple of pixel + coordinates. + If `None`, the default behavior is to return the center as a + `~astropy.coordinates.SkyCoord` if possible, or as a tuple of floats if the image is in pixel coordinates and has no WCS information. image_label : str, optional @@ -539,7 +589,8 @@ def get_viewport( dict A dictionary containing the current viewport settings. The keys are 'center', 'fov', and 'image_label'. - - 'center' is an `~astropy.coordinates.SkyCoord` object or a tuple of floats. + - 'center' is an `~astropy.coordinates.SkyCoord` object or a tuple of + floats. - 'fov' is an `~astropy.units.Quantity` object or a float. - 'image_label' is a string representing the label of the image.