From a9ea1bd4a65b5466eed6c5dec5979d6c5370942e Mon Sep 17 00:00:00 2001 From: Erik Tollerud Date: Fri, 27 Jun 2025 11:14:55 -0700 Subject: [PATCH 1/2] change image shape to a method --- .../interface_definition.py | 34 +++++++++++++++---- 1 file changed, 28 insertions(+), 6 deletions(-) diff --git a/src/astro_image_display_api/interface_definition.py b/src/astro_image_display_api/interface_definition.py index c071f3d..32349e2 100644 --- a/src/astro_image_display_api/interface_definition.py +++ b/src/astro_image_display_api/interface_definition.py @@ -2,6 +2,7 @@ import os from abc import abstractmethod from typing import Any, Protocol, runtime_checkable +from collections import namedtuple from astropy.coordinates import SkyCoord from astropy.table import Table @@ -10,17 +11,13 @@ __all__ = [ "ImageViewerInterface", + "ImageShape", ] +ImageShape = namedtuple("ImageShape", ["width", "height"]) @runtime_checkable class ImageViewerInterface(Protocol): - # These are attributes, not methods. The type annotations are there - # to make sure Protocol knows they are attributes. Python does not - # do any checking at all of these types. - image_width: int - image_height: int - # The methods, grouped loosely by purpose # Method for loading image data @@ -201,6 +198,31 @@ def get_colormap(self, image_label: str | None = None) -> str: """ raise NotImplementedError + @abstractmethod + def get_shape(self, image_label: str | None = None) -> ImageShape: + """ + Get the shape (i.e., width and height in pixels) of the image. + + Parameters + ---------- + image_label : optional + The label of the image to get the shape for. If not given and there is + only one image loaded, the shape for that image is returned. If there are + multiple images and no label is provided, an error is raised. + + Returns + ------- + shape : `ImageShape`, a 2-tuple of ints + A named tuple containing the width and height of the image in pixels. + + Raises + ------ + ValueError + If the `image_label` is not provided when there are multiple images loaded, + or if the `image_label` does not correspond to a loaded image. + """ + raise NotImplementedError + # Saving contents of the view and accessing the view @abstractmethod def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None: From b4683cdfe2788f2340929831a40825fff5ab8fb0 Mon Sep 17 00:00:00 2001 From: Erik Tollerud Date: Fri, 27 Jun 2025 14:07:12 -0700 Subject: [PATCH 2/2] update tests for change to get_shape --- src/astro_image_display_api/api_test.py | 25 +++++------- .../image_viewer_logic.py | 40 +++++++++++++++++-- 2 files changed, 48 insertions(+), 17 deletions(-) diff --git a/src/astro_image_display_api/api_test.py b/src/astro_image_display_api/api_test.py index 589a0c2..79a64de 100644 --- a/src/astro_image_display_api/api_test.py +++ b/src/astro_image_display_api/api_test.py @@ -43,14 +43,17 @@ def wcs(self): return w @pytest.fixture - def catalog(self, wcs: WCS) -> Table: + def catalog(self, wcs: WCS, data: np.ndarray) -> Table: """ A catalog fixture that returns an empty table with the expected columns. """ + self.image.load_image(NDData(data=data, wcs=wcs)) + rng = np.random.default_rng(45328975) - x = rng.uniform(0, self.image.image_width, size=10) - y = rng.uniform(0, self.image.image_height, size=10) + + x = rng.uniform(0, self.image.get_shape().width, size=10) + y = rng.uniform(0, self.image.get_shape().height, size=10) coord = wcs.pixel_to_world(x, y) cat = Table( @@ -70,7 +73,7 @@ def setup(self): Subclasses MUST define ``image_widget_class`` -- doing so as a class variable does the trick. """ - self.image = self.image_widget_class(image_width=250, image_height=100) + self.image = self.image_widget_class() def _assert_empty_catalog_table(self, table): assert isinstance(table, Table) @@ -81,16 +84,10 @@ def _get_catalog_names_as_set(self): marks = self.image.get_catalog_names() return set(marks) - def test_width_height(self): - assert self.image.image_width == 250 - assert self.image.image_height == 100 - - width = 200 - height = 300 - self.image.image_width = width - self.image.image_height = height - assert self.image.image_width == width - assert self.image.image_height == height + def test_width_height(self, data: np.ndarray): + self.image.load_image(NDData(data=data)) + assert self.image.get_shape().width == 150 + assert self.image.get_shape().height == 100 @pytest.mark.parametrize("load_type", ["fits", "nddata", "array"]) def test_load(self, data, tmp_path, load_type): diff --git a/src/astro_image_display_api/image_viewer_logic.py b/src/astro_image_display_api/image_viewer_logic.py index 1bc10d8..029dfa3 100644 --- a/src/astro_image_display_api/image_viewer_logic.py +++ b/src/astro_image_display_api/image_viewer_logic.py @@ -22,7 +22,7 @@ from astropy.wcs.utils import proj_plane_pixel_scales from numpy.typing import ArrayLike -from .interface_definition import ImageViewerInterface +from .interface_definition import ImageViewerInterface, ImageShape @dataclass @@ -48,6 +48,8 @@ class ViewportInfo: stretch: BaseStretch | None = None cuts: BaseInterval | tuple[numbers.Real, numbers.Real] | None = None colormap: str | None = None + image_width: int | None = None + image_height: int | None = None @dataclass @@ -60,8 +62,6 @@ class ImageViewerLogic: # These are attributes, not methods. The type annotations are there # to make sure Protocol knows they are attributes. Python does not # do any checking at all of these types. - image_width: int = 0 - image_height: int = 0 zoom_level: float = 1 _cuts: BaseInterval | tuple[float, float] = AsymmetricPercentileInterval( upper_percentile=95 @@ -198,6 +198,37 @@ def get_colormap(self, image_label: str | None = None) -> str: # The methods, grouped loosely by purpose + + def get_shape(self, image_label: str | None = None) -> ImageShape: + """ + Get the shape (i.e., width and height in pixels) of the image. + + Parameters + ---------- + image_label : optional + The label of the image to get the shape for. If not given and there is + only one image loaded, the shape for that image is returned. If there are + multiple images and no label is provided, an error is raised. + + Returns + ------- + shape : `ImageShape`, a 2-tuple of ints + A named tuple containing the width and height of the image in pixels. + + Raises + ------ + ValueError + If the `image_label` is not provided when there are multiple images loaded, + or if the `image_label` does not correspond to a loaded image. + """ + image_label = self._resolve_image_label(image_label) + if image_label not in self._images: + raise ValueError( + f"Image label '{image_label}' not found. Please load an image first." + ) + img = self._images[image_label] + return ImageShape(img.image_width, img.image_height) + def get_catalog_style(self, catalog_label=None) -> dict[str, Any]: """ Get the style for the catalog. @@ -370,6 +401,9 @@ def _initialize_image_viewport_stretch_cuts( # Deal with the viewport first height, width = image_data.shape + self._images[image_label].image_width = width + self._images[image_label].image_height = height + # Center the image in the viewport and show the whole image. center = (width / 2, height / 2) fov = max(image_data.shape)