Skip to content

covert shape-getting to method #70

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions src/astro_image_display_api/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,17 @@ def wcs(self):
return w

@pytest.fixture
def catalog(self, wcs: WCS) -> Table:
def catalog(self, wcs: WCS, data: np.ndarray) -> Table:
"""
A catalog fixture that returns an empty table with the
expected columns.
"""
self.image.load_image(NDData(data=data, wcs=wcs))

rng = np.random.default_rng(45328975)
x = rng.uniform(0, self.image.image_width, size=10)
y = rng.uniform(0, self.image.image_height, size=10)

x = rng.uniform(0, self.image.get_shape().width, size=10)
y = rng.uniform(0, self.image.get_shape().height, size=10)
coord = wcs.pixel_to_world(x, y)

cat = Table(
Expand All @@ -70,7 +73,7 @@ def setup(self):
Subclasses MUST define ``image_widget_class`` -- doing so as a
class variable does the trick.
"""
self.image = self.image_widget_class(image_width=250, image_height=100)
self.image = self.image_widget_class()

def _assert_empty_catalog_table(self, table):
assert isinstance(table, Table)
Expand All @@ -81,16 +84,10 @@ def _get_catalog_names_as_set(self):
marks = self.image.get_catalog_names()
return set(marks)

def test_width_height(self):
assert self.image.image_width == 250
assert self.image.image_height == 100

width = 200
height = 300
self.image.image_width = width
self.image.image_height = height
assert self.image.image_width == width
assert self.image.image_height == height
def test_width_height(self, data: np.ndarray):
self.image.load_image(NDData(data=data))
assert self.image.get_shape().width == 150
assert self.image.get_shape().height == 100
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please also add a test that get_shape returns an ImageShape?


@pytest.mark.parametrize("load_type", ["fits", "nddata", "array"])
def test_load(self, data, tmp_path, load_type):
Expand Down
40 changes: 37 additions & 3 deletions src/astro_image_display_api/image_viewer_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from astropy.wcs.utils import proj_plane_pixel_scales
from numpy.typing import ArrayLike

from .interface_definition import ImageViewerInterface
from .interface_definition import ImageViewerInterface, ImageShape


@dataclass
Expand All @@ -48,6 +48,8 @@ class ViewportInfo:
stretch: BaseStretch | None = None
cuts: BaseInterval | tuple[numbers.Real, numbers.Real] | None = None
colormap: str | None = None
image_width: int | None = None
image_height: int | None = None


@dataclass
Expand All @@ -60,8 +62,6 @@ class ImageViewerLogic:
# These are attributes, not methods. The type annotations are there
# to make sure Protocol knows they are attributes. Python does not
# do any checking at all of these types.
image_width: int = 0
image_height: int = 0
zoom_level: float = 1
_cuts: BaseInterval | tuple[float, float] = AsymmetricPercentileInterval(
upper_percentile=95
Expand Down Expand Up @@ -198,6 +198,37 @@ def get_colormap(self, image_label: str | None = None) -> str:

# The methods, grouped loosely by purpose


def get_shape(self, image_label: str | None = None) -> ImageShape:
"""
Get the shape (i.e., width and height in pixels) of the image.

Parameters
----------
image_label : optional
The label of the image to get the shape for. If not given and there is
only one image loaded, the shape for that image is returned. If there are
multiple images and no label is provided, an error is raised.

Returns
-------
shape : `ImageShape`, a 2-tuple of ints
A named tuple containing the width and height of the image in pixels.

Raises
------
ValueError
If the `image_label` is not provided when there are multiple images loaded,
or if the `image_label` does not correspond to a loaded image.
"""
image_label = self._resolve_image_label(image_label)
if image_label not in self._images:
raise ValueError(
f"Image label '{image_label}' not found. Please load an image first."
)
img = self._images[image_label]
return ImageShape(img.image_width, img.image_height)

def get_catalog_style(self, catalog_label=None) -> dict[str, Any]:
"""
Get the style for the catalog.
Expand Down Expand Up @@ -370,6 +401,9 @@ def _initialize_image_viewport_stretch_cuts(

# Deal with the viewport first
height, width = image_data.shape
self._images[image_label].image_width = width
self._images[image_label].image_height = height

# Center the image in the viewport and show the whole image.
center = (width / 2, height / 2)
fov = max(image_data.shape)
Expand Down
34 changes: 28 additions & 6 deletions src/astro_image_display_api/interface_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from abc import abstractmethod
from typing import Any, Protocol, runtime_checkable
from collections import namedtuple

from astropy.coordinates import SkyCoord
from astropy.table import Table
Expand All @@ -10,17 +11,13 @@

__all__ = [
"ImageViewerInterface",
"ImageShape",
]

ImageShape = namedtuple("ImageShape", ["width", "height"])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thinking out loud here...is there any way to avoid declaring ImageShape as something that needs to be imported by backends? For type hinting we could have get_shape(..) -> tuple[int, int] and then have get_shape return an ImgeShape that is defined inside of get_shape.

Or, I suppose, we could follow the Array API specification and just return a plain tuple...

Not against named tuples, but trying to think of ways to minimize the imports for backends.


@runtime_checkable
class ImageViewerInterface(Protocol):
# These are attributes, not methods. The type annotations are there
# to make sure Protocol knows they are attributes. Python does not
# do any checking at all of these types.
image_width: int
image_height: int

# The methods, grouped loosely by purpose

# Method for loading image data
Expand Down Expand Up @@ -201,6 +198,31 @@ def get_colormap(self, image_label: str | None = None) -> str:
"""
raise NotImplementedError

@abstractmethod
def get_shape(self, image_label: str | None = None) -> ImageShape:
"""
Get the shape (i.e., width and height in pixels) of the image.

Parameters
----------
image_label : optional
The label of the image to get the shape for. If not given and there is
only one image loaded, the shape for that image is returned. If there are
multiple images and no label is provided, an error is raised.

Returns
-------
shape : `ImageShape`, a 2-tuple of ints
A named tuple containing the width and height of the image in pixels.

Raises
------
ValueError
If the `image_label` is not provided when there are multiple images loaded,
or if the `image_label` does not correspond to a loaded image.
"""
raise NotImplementedError

# Saving contents of the view and accessing the view
@abstractmethod
def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None:
Expand Down