Skip to content

Commit

Permalink
Merge pull request #1213 from effigies/typ/loadsave
Browse files Browse the repository at this point in the history
TYP: Annotate loadsave and image header classes
  • Loading branch information
effigies authored Apr 3, 2023
2 parents 40e31e8 + 9f189c6 commit ed95d8d
Show file tree
Hide file tree
Showing 15 changed files with 50 additions and 33 deletions.
5 changes: 2 additions & 3 deletions nibabel/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@
"""
from __future__ import annotations

from typing import Type

import numpy as np

from .arrayproxy import ArrayProxy
Expand Down Expand Up @@ -895,7 +893,8 @@ def may_contain_header(klass, binaryblock):
class AnalyzeImage(SpatialImage):
"""Class for basic Analyze format image"""

header_class: Type[AnalyzeHeader] = AnalyzeHeader
header_class: type[AnalyzeHeader] = AnalyzeHeader
header: AnalyzeHeader
_meta_sniff_len = header_class.sizeof_hdr
files_types: tuple[tuple[str, str], ...] = (('image', '.img'), ('header', '.hdr'))
valid_exts: tuple[str, ...] = ('.img', '.hdr')
Expand Down
1 change: 1 addition & 0 deletions nibabel/brikhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ class AFNIImage(SpatialImage):
"""

header_class = AFNIHeader
header: AFNIHeader
valid_exts = ('.brik', '.head')
files_types = (('image', '.brik'), ('header', '.head'))
_compressed_suffixes = ('.gz', '.bz2', '.Z', '.zst')
Expand Down
1 change: 1 addition & 0 deletions nibabel/cifti2/cifti2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1411,6 +1411,7 @@ class Cifti2Image(DataobjImage, SerializableImage):
"""Class for single file CIFTI-2 format image"""

header_class = Cifti2Header
header: Cifti2Header
valid_exts = Nifti2Image.valid_exts
files_types = Nifti2Image.files_types
makeable = False
Expand Down
2 changes: 1 addition & 1 deletion nibabel/ecat.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ class EcatImage(SpatialImage):
valid_exts = ('.v',)
files_types = (('image', '.v'), ('header', '.v'))

_header: EcatHeader
header: EcatHeader
_subheader: EcatSubHeader

ImageArrayProxy = EcatImageArrayProxy
Expand Down
4 changes: 1 addition & 3 deletions nibabel/filebasedimages.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import io
import typing as ty
from copy import deepcopy
from typing import Type
from urllib import request

from ._compression import COMPRESSION_ERRORS
Expand Down Expand Up @@ -158,8 +157,7 @@ class FileBasedImage:
work.
"""

header_class: Type[FileBasedHeader] = FileBasedHeader
_header: FileBasedHeader
header_class: type[FileBasedHeader] = FileBasedHeader
_meta_sniff_len: int = 0
files_types: tuple[ExtensionSpec, ...] = (('image', None),)
valid_exts: tuple[str, ...] = ()
Expand Down
1 change: 1 addition & 0 deletions nibabel/freesurfer/mghformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ class MGHImage(SpatialImage, SerializableImage):
"""Class for MGH format image"""

header_class = MGHHeader
header: MGHHeader
valid_exts = ('.mgh', '.mgz')
# Register that .mgz extension signals gzip compression
ImageOpener.compress_ext_map['.mgz'] = ImageOpener.gz_def
Expand Down
10 changes: 7 additions & 3 deletions nibabel/imageclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Define supported image classes and names"""
from __future__ import annotations

from .analyze import AnalyzeImage
from .brikhead import AFNIImage
from .cifti2 import Cifti2Image
from .dataobj_images import DataobjImage
from .filebasedimages import FileBasedImage
from .freesurfer import MGHImage
from .gifti import GiftiImage
from .minc1 import Minc1Image
Expand All @@ -21,7 +25,7 @@
from .spm99analyze import Spm99AnalyzeImage

# Ordered by the load/save priority.
all_image_classes = [
all_image_classes: list[type[FileBasedImage]] = [
Nifti1Pair,
Nifti1Image,
Nifti2Pair,
Expand All @@ -41,7 +45,7 @@
# Image classes known to require spatial axes to be first in index ordering.
# When adding an image class, consider whether the new class should be listed
# here.
KNOWN_SPATIAL_FIRST = (
KNOWN_SPATIAL_FIRST: tuple[type[FileBasedImage], ...] = (
Nifti1Pair,
Nifti1Image,
Nifti2Pair,
Expand All @@ -55,7 +59,7 @@
)


def spatial_axes_first(img):
def spatial_axes_first(img: DataobjImage) -> bool:
"""True if spatial image axes for `img` always precede other axes
Parameters
Expand Down
42 changes: 25 additions & 17 deletions nibabel/loadsave.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
# module imports
"""Utilities to load and save image objects"""
from __future__ import annotations

import os
import typing as ty

import numpy as np

Expand All @@ -22,7 +25,18 @@
_compressed_suffixes = ('.gz', '.bz2', '.zst')


def _signature_matches_extension(filename):
if ty.TYPE_CHECKING: # pragma: no cover
from .filebasedimages import FileBasedImage
from .filename_parser import FileSpec

P = ty.ParamSpec('P')

class Signature(ty.TypedDict):
signature: bytes
format_name: str


def _signature_matches_extension(filename: FileSpec) -> tuple[bool, str]:
"""Check if signature aka magic number matches filename extension.
Parameters
Expand All @@ -42,7 +56,7 @@ def _signature_matches_extension(filename):
the empty string otherwise.
"""
signatures = {
signatures: dict[str, Signature] = {
'.gz': {'signature': b'\x1f\x8b', 'format_name': 'gzip'},
'.bz2': {'signature': b'BZh', 'format_name': 'bzip2'},
'.zst': {'signature': b'\x28\xb5\x2f\xfd', 'format_name': 'ztsd'},
Expand All @@ -64,7 +78,7 @@ def _signature_matches_extension(filename):
return False, f'File {filename} is not a {format_name} file'


def load(filename, **kwargs):
def load(filename: FileSpec, **kwargs) -> FileBasedImage:
r"""Load file given filename, guessing at file type
Parameters
Expand Down Expand Up @@ -126,7 +140,7 @@ def guessed_image_type(filename):
raise ImageFileError(f'Cannot work out file type of "{filename}"')


def save(img, filename, **kwargs):
def save(img: FileBasedImage, filename: FileSpec, **kwargs) -> None:
r"""Save an image to file adapting format to `filename`
Parameters
Expand Down Expand Up @@ -161,19 +175,17 @@ def save(img, filename, **kwargs):
from .nifti1 import Nifti1Image, Nifti1Pair
from .nifti2 import Nifti2Image, Nifti2Pair

klass = None
converted = None

converted: FileBasedImage
if type(img) == Nifti1Image and lext in ('.img', '.hdr'):
klass = Nifti1Pair
converted = Nifti1Pair.from_image(img)
elif type(img) == Nifti2Image and lext in ('.img', '.hdr'):
klass = Nifti2Pair
converted = Nifti2Pair.from_image(img)
elif type(img) == Nifti1Pair and lext == '.nii':
klass = Nifti1Image
converted = Nifti1Image.from_image(img)
elif type(img) == Nifti2Pair and lext == '.nii':
klass = Nifti2Image
converted = Nifti2Image.from_image(img)
else: # arbitrary conversion
valid_klasses = [klass for klass in all_image_classes if ext in klass.valid_exts]
valid_klasses = [klass for klass in all_image_classes if lext in klass.valid_exts]
if not valid_klasses: # if list is empty
raise ImageFileError(f'Cannot work out file type of "{filename}"')

Expand All @@ -186,13 +198,9 @@ def save(img, filename, **kwargs):
break
except Exception as e:
err = e
# ... and if none of them work, raise an error.
if converted is None:
else:
raise err

# Here, we either have a klass or a converted image.
if converted is None:
converted = klass.from_image(img)
converted.to_filename(filename, **kwargs)


Expand Down
4 changes: 2 additions & 2 deletions nibabel/minc1.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from __future__ import annotations

from numbers import Integral
from typing import Type

import numpy as np

Expand Down Expand Up @@ -307,7 +306,8 @@ class Minc1Image(SpatialImage):
load.
"""

header_class: Type[MincHeader] = Minc1Header
header_class: type[MincHeader] = Minc1Header
header: MincHeader
_meta_sniff_len: int = 4
valid_exts: tuple[str, ...] = ('.mnc',)
files_types: tuple[tuple[str, str], ...] = (('image', '.mnc'),)
Expand Down
1 change: 1 addition & 0 deletions nibabel/minc2.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class Minc2Image(Minc1Image):
# MINC2 does not do compressed whole files
_compressed_suffixes = ()
header_class = Minc2Header
header: Minc2Header

@classmethod
def from_file_map(klass, file_map, *, mmap=True, keep_file_open=None):
Expand Down
8 changes: 4 additions & 4 deletions nibabel/nifti1.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import warnings
from io import BytesIO
from typing import Type

import numpy as np
import numpy.linalg as npl
Expand Down Expand Up @@ -90,8 +89,8 @@
# datatypes not in analyze format, with codes
if have_binary128():
# Only enable 128 bit floats if we really have IEEE binary 128 longdoubles
_float128t: Type[np.generic] = np.longdouble
_complex256t: Type[np.generic] = np.longcomplex
_float128t: type[np.generic] = np.longdouble
_complex256t: type[np.generic] = np.longcomplex
else:
_float128t = np.void
_complex256t = np.void
Expand Down Expand Up @@ -1817,7 +1816,8 @@ class Nifti1PairHeader(Nifti1Header):
class Nifti1Pair(analyze.AnalyzeImage):
"""Class for NIfTI1 format image, header pair"""

header_class: Type[Nifti1Header] = Nifti1PairHeader
header_class: type[Nifti1Header] = Nifti1PairHeader
header: Nifti1Header
_meta_sniff_len = header_class.sizeof_hdr
rw = True

Expand Down
1 change: 1 addition & 0 deletions nibabel/parrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,6 +1253,7 @@ class PARRECImage(SpatialImage):
"""PAR/REC image"""

header_class = PARRECHeader
header: PARRECHeader
valid_exts = ('.rec', '.par')
files_types = (('image', '.rec'), ('header', '.par'))

Expand Down
1 change: 1 addition & 0 deletions nibabel/spatialimages.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,7 @@ class SpatialImage(DataobjImage):
ImageSlicer: type[SpatialFirstSlicer] = SpatialFirstSlicer

_header: SpatialHeader
header: SpatialHeader

def __init__(
self,
Expand Down
1 change: 1 addition & 0 deletions nibabel/spm2analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class Spm2AnalyzeImage(spm99.Spm99AnalyzeImage):
"""Class for SPM2 variant of basic Analyze image"""

header_class = Spm2AnalyzeHeader
header: Spm2AnalyzeHeader


load = Spm2AnalyzeImage.from_filename
Expand Down
1 change: 1 addition & 0 deletions nibabel/spm99analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ class Spm99AnalyzeImage(analyze.AnalyzeImage):
"""Class for SPM99 variant of basic Analyze image"""

header_class = Spm99AnalyzeHeader
header: Spm99AnalyzeHeader
files_types = (('image', '.img'), ('header', '.hdr'), ('mat', '.mat'))
has_affine = True
makeable = True
Expand Down

0 comments on commit ed95d8d

Please sign in to comment.