Skip to content

Commit 1312493

Browse files
authored
Merge pull request #1096 from effigies/enh/dtype_aliases
ENH: Add static and dynamic dtype aliases to NIfTI images
2 parents a7e1e0e + 58d37a2 commit 1312493

File tree

2 files changed

+341
-9
lines changed

2 files changed

+341
-9
lines changed

nibabel/nifti1.py

+284-9
Original file line numberDiff line numberDiff line change
@@ -898,26 +898,28 @@ def set_data_dtype(self, datatype):
898898
>>> hdr.set_data_dtype(np.dtype(np.uint8))
899899
>>> hdr.get_data_dtype()
900900
dtype('uint8')
901-
>>> hdr.set_data_dtype('implausible') #doctest: +IGNORE_EXCEPTION_DETAIL
901+
>>> hdr.set_data_dtype('implausible')
902902
Traceback (most recent call last):
903903
...
904-
HeaderDataError: data dtype "implausible" not recognized
905-
>>> hdr.set_data_dtype('none') #doctest: +IGNORE_EXCEPTION_DETAIL
904+
nibabel.spatialimages.HeaderDataError: data dtype "implausible" not recognized
905+
>>> hdr.set_data_dtype('none')
906906
Traceback (most recent call last):
907907
...
908-
HeaderDataError: data dtype "none" known but not supported
909-
>>> hdr.set_data_dtype(np.void) #doctest: +IGNORE_EXCEPTION_DETAIL
908+
nibabel.spatialimages.HeaderDataError: data dtype "none" known but not supported
909+
>>> hdr.set_data_dtype(np.void)
910910
Traceback (most recent call last):
911911
...
912-
HeaderDataError: data dtype "<type 'numpy.void'>" known but not supported
913-
>>> hdr.set_data_dtype('int') #doctest: +IGNORE_EXCEPTION_DETAIL
912+
nibabel.spatialimages.HeaderDataError: data dtype "<class 'numpy.void'>" known
913+
but not supported
914+
>>> hdr.set_data_dtype('int')
914915
Traceback (most recent call last):
915916
...
916917
ValueError: Invalid data type 'int'. Specify a sized integer, e.g., 'uint8' or numpy.int16.
917-
>>> hdr.set_data_dtype(int) #doctest: +IGNORE_EXCEPTION_DETAIL
918+
>>> hdr.set_data_dtype(int)
918919
Traceback (most recent call last):
919920
...
920-
ValueError: Invalid data type 'int'. Specify a sized integer, e.g., 'uint8' or numpy.int16.
921+
ValueError: Invalid data type <class 'int'>. Specify a sized integer, e.g., 'uint8' or
922+
numpy.int16.
921923
>>> hdr.set_data_dtype('int64')
922924
>>> hdr.get_data_dtype() == np.dtype('int64')
923925
True
@@ -1799,6 +1801,10 @@ class Nifti1Pair(analyze.AnalyzeImage):
17991801
_meta_sniff_len = header_class.sizeof_hdr
18001802
rw = True
18011803

1804+
# If a _dtype_alias has been set, it can only be resolved by inspecting
1805+
# the data at serialization time
1806+
_dtype_alias = None
1807+
18021808
def __init__(self, dataobj, affine, header=None,
18031809
extra=None, file_map=None, dtype=None):
18041810
# Special carve-out for 64 bit integers
@@ -2043,6 +2049,137 @@ def set_sform(self, affine, code=None, **kwargs):
20432049
else:
20442050
self._affine[:] = self._header.get_best_affine()
20452051

2052+
def set_data_dtype(self, datatype):
2053+
""" Set numpy dtype for data from code, dtype, type or alias
2054+
2055+
Using :py:class:`int` or ``"int"`` is disallowed, as these types
2056+
will be interpreted as ``np.int64``, which is almost never desired.
2057+
``np.int64`` is permitted for those intent on making poor choices.
2058+
2059+
The following aliases are defined to allow for flexible specification:
2060+
2061+
* ``'mask'`` - Alias for ``uint8``
2062+
* ``'compat'`` - The nearest Analyze-compatible datatype
2063+
(``uint8``, ``int16``, ``int32``, ``float32``)
2064+
* ``'smallest'`` - The smallest Analyze-compatible integer
2065+
(``uint8``, ``int16``, ``int32``)
2066+
2067+
Dynamic aliases are resolved when ``get_data_dtype()`` is called
2068+
with a ``finalize=True`` flag. Until then, these aliases are not
2069+
written to the header and will not persist to new images.
2070+
2071+
Examples
2072+
--------
2073+
>>> ints = np.arange(24, dtype='i4').reshape((2,3,4))
2074+
2075+
>>> img = Nifti1Image(ints, np.eye(4))
2076+
>>> img.set_data_dtype(np.uint8)
2077+
>>> img.get_data_dtype()
2078+
dtype('uint8')
2079+
>>> img.set_data_dtype('mask')
2080+
>>> img.get_data_dtype()
2081+
dtype('uint8')
2082+
>>> img.set_data_dtype('compat')
2083+
>>> img.get_data_dtype()
2084+
'compat'
2085+
>>> img.get_data_dtype(finalize=True)
2086+
dtype('<i4')
2087+
>>> img.get_data_dtype()
2088+
dtype('<i4')
2089+
>>> img.set_data_dtype('smallest')
2090+
>>> img.get_data_dtype()
2091+
'smallest'
2092+
>>> img.get_data_dtype(finalize=True)
2093+
dtype('uint8')
2094+
>>> img.get_data_dtype()
2095+
dtype('uint8')
2096+
2097+
Note that floating point values will not be coerced to ``int``
2098+
2099+
>>> floats = np.arange(24, dtype='f4').reshape((2,3,4))
2100+
>>> img = Nifti1Image(floats, np.eye(4))
2101+
>>> img.set_data_dtype('smallest')
2102+
>>> img.get_data_dtype(finalize=True)
2103+
Traceback (most recent call last):
2104+
...
2105+
ValueError: Cannot automatically cast array (of type float32) to an integer
2106+
type with fewer than 64 bits. Please set_data_dtype() to an explicit data type.
2107+
2108+
>>> arr = np.arange(1000, 1024, dtype='i4').reshape((2,3,4))
2109+
>>> img = Nifti1Image(arr, np.eye(4))
2110+
>>> img.set_data_dtype('smallest')
2111+
>>> img.set_data_dtype('implausible')
2112+
Traceback (most recent call last):
2113+
...
2114+
nibabel.spatialimages.HeaderDataError: data dtype "implausible" not recognized
2115+
>>> img.set_data_dtype('none')
2116+
Traceback (most recent call last):
2117+
...
2118+
nibabel.spatialimages.HeaderDataError: data dtype "none" known but not supported
2119+
>>> img.set_data_dtype(np.void)
2120+
Traceback (most recent call last):
2121+
...
2122+
nibabel.spatialimages.HeaderDataError: data dtype "<class 'numpy.void'>" known
2123+
but not supported
2124+
>>> img.set_data_dtype('int')
2125+
Traceback (most recent call last):
2126+
...
2127+
ValueError: Invalid data type 'int'. Specify a sized integer, e.g., 'uint8' or numpy.int16.
2128+
>>> img.set_data_dtype(int)
2129+
Traceback (most recent call last):
2130+
...
2131+
ValueError: Invalid data type <class 'int'>. Specify a sized integer, e.g., 'uint8' or
2132+
numpy.int16.
2133+
>>> img.set_data_dtype('int64')
2134+
>>> img.get_data_dtype() == np.dtype('int64')
2135+
True
2136+
"""
2137+
# Comparing dtypes to strings, numpy will attempt to call, e.g., dtype('mask'),
2138+
# so only check for aliases if the type is a string
2139+
# See https://github.com/numpy/numpy/issues/7242
2140+
if isinstance(datatype, str):
2141+
# Static aliases
2142+
if datatype == 'mask':
2143+
datatype = 'u1'
2144+
# Dynamic aliases
2145+
elif datatype in ('compat', 'smallest'):
2146+
self._dtype_alias = datatype
2147+
return
2148+
2149+
self._dtype_alias = None
2150+
super().set_data_dtype(datatype)
2151+
2152+
def get_data_dtype(self, finalize=False):
2153+
""" Get numpy dtype for data
2154+
2155+
If ``set_data_dtype()`` has been called with an alias
2156+
and ``finalize`` is ``False``, return the alias.
2157+
If ``finalize`` is ``True``, determine the appropriate dtype
2158+
from the image data object and set the final dtype in the
2159+
header before returning it.
2160+
"""
2161+
if self._dtype_alias is None:
2162+
return super().get_data_dtype()
2163+
if not finalize:
2164+
return self._dtype_alias
2165+
2166+
datatype = None
2167+
if self._dtype_alias == 'compat':
2168+
datatype = _get_analyze_compat_dtype(self._dataobj)
2169+
descrip = "an Analyze-compatible dtype"
2170+
elif self._dtype_alias == 'smallest':
2171+
datatype = _get_smallest_dtype(self._dataobj)
2172+
descrip = "an integer type with fewer than 64 bits"
2173+
else:
2174+
raise ValueError(f"Unknown dtype alias {self._dtype_alias}.")
2175+
if datatype is None:
2176+
dt = get_obj_dtype(self._dataobj)
2177+
raise ValueError(f"Cannot automatically cast array (of type {dt}) to {descrip}."
2178+
" Please set_data_dtype() to an explicit data type.")
2179+
2180+
self.set_data_dtype(datatype) # Clears the alias
2181+
return super().get_data_dtype()
2182+
20462183
def as_reoriented(self, ornt):
20472184
"""Apply an orientation change and return a new image
20482185
@@ -2136,3 +2273,141 @@ def save(img, filename):
21362273
Nifti1Image.instance_to_filename(img, filename)
21372274
except ImageFileError:
21382275
Nifti1Pair.instance_to_filename(img, filename)
2276+
2277+
2278+
def _get_smallest_dtype(
2279+
arr,
2280+
itypes=(np.uint8, np.int16, np.int32),
2281+
ftypes=(),
2282+
):
2283+
""" Return the smallest "sensible" dtype that will hold the array data
2284+
2285+
The purpose of this function is to support automatic type selection
2286+
for serialization, so "sensible" here means well-supported in the NIfTI-1 world.
2287+
2288+
For floating point data, select between single- and double-precision.
2289+
For integer data, select among uint8, int16 and int32.
2290+
2291+
The test is for min/max range, so float64 is pretty unlikely to be hit.
2292+
2293+
Returns ``None`` if these dtypes do not suffice.
2294+
2295+
>>> _get_smallest_dtype(np.array([0, 1]))
2296+
dtype('uint8')
2297+
>>> _get_smallest_dtype(np.array([-1, 1]))
2298+
dtype('int16')
2299+
>>> _get_smallest_dtype(np.array([0, 256]))
2300+
dtype('int16')
2301+
>>> _get_smallest_dtype(np.array([-65536, 65536]))
2302+
dtype('int32')
2303+
>>> _get_smallest_dtype(np.array([-2147483648, 2147483648]))
2304+
2305+
By default floating point types are not searched:
2306+
2307+
>>> _get_smallest_dtype(np.array([1.]))
2308+
>>> _get_smallest_dtype(np.array([2. ** 1000]))
2309+
>>> _get_smallest_dtype(np.longdouble(2) ** 2000)
2310+
>>> _get_smallest_dtype(np.array([1+0j]))
2311+
2312+
However, this function can be passed "legal" floating point types, and
2313+
the logic works the same.
2314+
2315+
>>> _get_smallest_dtype(np.array([1.]), ftypes=('float32',))
2316+
dtype('float32')
2317+
>>> _get_smallest_dtype(np.array([2. ** 1000]), ftypes=('float32',))
2318+
>>> _get_smallest_dtype(np.longdouble(2) ** 2000, ftypes=('float32',))
2319+
>>> _get_smallest_dtype(np.array([1+0j]), ftypes=('float32',))
2320+
"""
2321+
arr = np.asanyarray(arr)
2322+
if np.issubdtype(arr.dtype, np.floating):
2323+
test_dts = ftypes
2324+
info = np.finfo
2325+
elif np.issubdtype(arr.dtype, np.integer):
2326+
test_dts = itypes
2327+
info = np.iinfo
2328+
else:
2329+
return None
2330+
2331+
mn, mx = np.min(arr), np.max(arr)
2332+
for dt in test_dts:
2333+
dtinfo = info(dt)
2334+
if dtinfo.min <= mn and mx <= dtinfo.max:
2335+
return np.dtype(dt)
2336+
2337+
2338+
def _get_analyze_compat_dtype(arr):
2339+
""" Return an Analyze-compatible dtype that ``arr`` can be safely cast to
2340+
2341+
Analyze-compatible types are returned without inspection:
2342+
2343+
>>> _get_analyze_compat_dtype(np.uint8([0, 1]))
2344+
dtype('uint8')
2345+
>>> _get_analyze_compat_dtype(np.int16([0, 1]))
2346+
dtype('int16')
2347+
>>> _get_analyze_compat_dtype(np.int32([0, 1]))
2348+
dtype('int32')
2349+
>>> _get_analyze_compat_dtype(np.float32([0, 1]))
2350+
dtype('float32')
2351+
2352+
Signed ``int8`` are cast to ``uint8`` or ``int16`` based on value ranges:
2353+
2354+
>>> _get_analyze_compat_dtype(np.int8([0, 1]))
2355+
dtype('uint8')
2356+
>>> _get_analyze_compat_dtype(np.int8([-1, 1]))
2357+
dtype('int16')
2358+
2359+
Unsigned ``uint16`` are cast to ``int16`` or ``int32`` based on value ranges:
2360+
2361+
>>> _get_analyze_compat_dtype(np.uint16([32767]))
2362+
dtype('int16')
2363+
>>> _get_analyze_compat_dtype(np.uint16([65535]))
2364+
dtype('int32')
2365+
2366+
``int32`` is returned for integer types and ``float32`` for floating point types:
2367+
2368+
>>> _get_analyze_compat_dtype(np.array([-1, 1]))
2369+
dtype('int32')
2370+
>>> _get_analyze_compat_dtype(np.array([-1., 1.]))
2371+
dtype('float32')
2372+
2373+
If the value ranges exceed 4 bytes or cannot be cast, then a ``ValueError`` is raised:
2374+
2375+
>>> _get_analyze_compat_dtype(np.array([0, 4294967295]))
2376+
Traceback (most recent call last):
2377+
...
2378+
ValueError: Cannot find analyze-compatible dtype for array with dtype=int64
2379+
(min=0, max=4294967295)
2380+
2381+
>>> _get_analyze_compat_dtype([0., 2.e40])
2382+
Traceback (most recent call last):
2383+
...
2384+
ValueError: Cannot find analyze-compatible dtype for array with dtype=float64
2385+
(min=0.0, max=2e+40)
2386+
2387+
Note that real-valued complex arrays cannot be safely cast.
2388+
2389+
>>> _get_analyze_compat_dtype(np.array([1+0j]))
2390+
Traceback (most recent call last):
2391+
...
2392+
ValueError: Cannot find analyze-compatible dtype for array with dtype=complex128
2393+
(min=(1+0j), max=(1+0j))
2394+
"""
2395+
arr = np.asanyarray(arr)
2396+
dtype = arr.dtype
2397+
if dtype in (np.uint8, np.int16, np.int32, np.float32):
2398+
return dtype
2399+
2400+
if dtype == np.int8:
2401+
return np.dtype('uint8' if arr.min() >= 0 else 'int16')
2402+
elif dtype == np.uint16:
2403+
return np.dtype('int16' if arr.max() <= np.iinfo(np.int16).max else 'int32')
2404+
2405+
mn, mx = arr.min(), arr.max()
2406+
if np.can_cast(mn, np.int32) and np.can_cast(mx, np.int32):
2407+
return np.dtype('int32')
2408+
if np.can_cast(mn, np.float32) and np.can_cast(mx, np.float32):
2409+
return np.dtype('float32')
2410+
2411+
raise ValueError(
2412+
f"Cannot find analyze-compatible dtype for array with dtype={dtype} (min={mn}, max={mx})"
2413+
)

nibabel/tests/test_nifti1.py

+57
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,63 @@ def test_write_scaling(self):
11191119
with np.errstate(invalid='ignore'):
11201120
self._check_write_scaling(slope, inter, e_slope, e_inter)
11211121

1122+
def test_dynamic_dtype_aliases(self):
1123+
for in_dt, mn, mx, alias, effective_dt in [
1124+
(np.uint8, 0, 255, 'compat', np.uint8),
1125+
(np.int8, 0, 127, 'compat', np.uint8),
1126+
(np.int8, -128, 127, 'compat', np.int16),
1127+
(np.int16, -32768, 32767, 'compat', np.int16),
1128+
(np.uint16, 0, 32767, 'compat', np.int16),
1129+
(np.uint16, 0, 65535, 'compat', np.int32),
1130+
(np.int32, -2**31, 2**31-1, 'compat', np.int32),
1131+
(np.uint32, 0, 2**31-1, 'compat', np.int32),
1132+
(np.uint32, 0, 2**32-1, 'compat', None),
1133+
(np.int64, -2**31, 2**31-1, 'compat', np.int32),
1134+
(np.uint64, 0, 2**31-1, 'compat', np.int32),
1135+
(np.int64, 0, 2**32-1, 'compat', None),
1136+
(np.uint64, 0, 2**32-1, 'compat', None),
1137+
(np.float32, 0, 1e30, 'compat', np.float32),
1138+
(np.float64, 0, 1e30, 'compat', np.float32),
1139+
(np.float64, 0, 1e40, 'compat', None),
1140+
(np.int64, 0, 255, 'smallest', np.uint8),
1141+
(np.int64, 0, 256, 'smallest', np.int16),
1142+
(np.int64, -1, 255, 'smallest', np.int16),
1143+
(np.int64, 0, 32768, 'smallest', np.int32),
1144+
(np.int64, 0, 4294967296, 'smallest', None),
1145+
(np.float32, 0, 1, 'smallest', None),
1146+
(np.float64, 0, 1, 'smallest', None)
1147+
]:
1148+
arr = np.arange(24, dtype=in_dt).reshape((2, 3, 4))
1149+
arr[0, 0, :2] = [mn, mx]
1150+
img = self.image_class(arr, np.eye(4), dtype=alias)
1151+
# Stored as alias
1152+
assert img.get_data_dtype() == alias
1153+
if effective_dt is None:
1154+
with pytest.raises(ValueError):
1155+
img.get_data_dtype(finalize=True)
1156+
continue
1157+
# Finalizing sets and clears the alias
1158+
assert img.get_data_dtype(finalize=True) == effective_dt
1159+
assert img.get_data_dtype() == effective_dt
1160+
# Re-set to alias
1161+
img.set_data_dtype(alias)
1162+
assert img.get_data_dtype() == alias
1163+
img_rt = bytesio_round_trip(img)
1164+
assert img_rt.get_data_dtype() == effective_dt
1165+
# Seralizing does not finalize the source image
1166+
assert img.get_data_dtype() == alias
1167+
1168+
def test_static_dtype_aliases(self):
1169+
for alias, effective_dt in [
1170+
("mask", np.uint8),
1171+
]:
1172+
for orig_dt in ('u1', 'i8', 'f4'):
1173+
arr = np.arange(24, dtype=orig_dt).reshape((2, 3, 4))
1174+
img = self.image_class(arr, np.eye(4), dtype=alias)
1175+
assert img.get_data_dtype() == effective_dt
1176+
img_rt = bytesio_round_trip(img)
1177+
assert img_rt.get_data_dtype() == effective_dt
1178+
11221179

11231180
class TestNifti1Image(TestNifti1Pair):
11241181
# Run analyze-flavor spatialimage tests

0 commit comments

Comments
 (0)