diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index 02e11e4f2..ddefd4995 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -2,6 +2,7 @@ import os import warnings +from pathlib import Path from .array_sequence import ArraySequence from .header import Field @@ -22,8 +23,8 @@ def is_supported(fileobj): Parameters ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object pointing + fileobj : path-like or file-like object + If path-like, a filename; otherwise an open file-like object pointing to a streamlines file (and ready to read from the beginning of the header) @@ -39,8 +40,8 @@ def detect_format(fileobj): Parameters ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object pointing + fileobj : path-like or file-like object + If path-like, a filename; otherwise an open file-like object pointing to a tractogram file (and ready to read from the beginning of the header) @@ -56,8 +57,8 @@ def detect_format(fileobj): except OSError: pass - if isinstance(fileobj, str): - _, ext = os.path.splitext(fileobj) + if isinstance(fileobj, (str, Path)): + ext = Path(fileobj).suffix return FORMATS.get(ext.lower()) return None @@ -68,8 +69,8 @@ def load(fileobj, lazy_load=False): Parameters ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object + fileobj : path-like or file-like object + If path-like, a filename; otherwise an open file-like object pointing to a streamlines file (and ready to read from the beginning of the streamlines file's header). lazy_load : {False, True}, optional @@ -106,7 +107,7 @@ def save(tractogram, filename, **kwargs): provided keyword arguments. If :class:`TractogramFile` object, the file format is known and will be used to save its content to `filename`. - filename : str + filename : path-like object Name of the file where the tractogram will be saved. \*\*kwargs : keyword arguments Keyword arguments passed to :class:`TractogramFile` constructor. diff --git a/nibabel/streamlines/tck.py b/nibabel/streamlines/tck.py index 358c57936..107a22478 100644 --- a/nibabel/streamlines/tck.py +++ b/nibabel/streamlines/tck.py @@ -112,8 +112,8 @@ def load(cls, fileobj, lazy_load=False): Parameters ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object in + fileobj : path-like or file-like object + If path-like, a filename; otherwise an open file-like object in binary mode pointing to TCK file (and ready to read from the beginning of the TCK header). Note that calling this function does not change the file position. @@ -167,8 +167,8 @@ def save(self, fileobj): Parameters ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object in + fileobj : path-like or or file-like object + If path-like, a filename; otherwise an open file-like object in binary mode pointing to TCK file (and ready to write from the beginning of the TCK header data). """ @@ -403,8 +403,8 @@ def _read(cls, fileobj, header, buffer_size=4): Parameters ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object in + fileobj : path-like or file-like object + If path-like, a filename; otherwise an open file-like object in binary mode pointing to TCK file (and ready to read from the beginning of the TCK header). Note that calling this function does not change the file position. diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index 8811ddcfa..9373cac7e 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -3,6 +3,7 @@ import warnings from io import BytesIO from os.path import join as pjoin +from pathlib import Path import numpy as np import pytest @@ -90,6 +91,7 @@ def test_is_supported_detect_format(tmp_path): assert not nib.streamlines.is_supported('') assert nib.streamlines.detect_format(f) is None assert nib.streamlines.detect_format('') is None + assert nib.streamlines.detect_format(Path('')) is None # Valid file without extension for tfile_cls in FORMATS.values(): @@ -128,6 +130,12 @@ def test_is_supported_detect_format(tmp_path): assert nib.streamlines.is_supported(f) assert nib.streamlines.detect_format(f) == tfile_cls + # Good extension, Path only + for ext, tfile_cls in FORMATS.items(): + f = Path('my_tractogram' + ext) + assert nib.streamlines.is_supported(f) + assert nib.streamlines.detect_format(f) == tfile_cls + # Extension should not be case-sensitive. for ext, tfile_cls in FORMATS.items(): f = 'my_tractogram' + ext.upper() @@ -149,7 +157,7 @@ def test_load_empty_file(self): with pytest.warns(Warning) if lazy_load else error_warnings(): assert_tractogram_equal(tfile.tractogram, DATA['empty_tractogram']) - def test_load_simple_file(self): + def test_load_simple_file_str(self): for lazy_load in [False, True]: for simple_filename in DATA['simple_filenames']: tfile = nib.streamlines.load(simple_filename, lazy_load=lazy_load) @@ -163,6 +171,20 @@ def test_load_simple_file(self): with pytest.warns(Warning) if lazy_load else error_warnings(): assert_tractogram_equal(tfile.tractogram, DATA['simple_tractogram']) + def test_load_simple_file_path(self): + for lazy_load in [False, True]: + for simple_filename in DATA['simple_filenames']: + tfile = nib.streamlines.load(Path(simple_filename), lazy_load=lazy_load) + assert isinstance(tfile, TractogramFile) + + if lazy_load: + assert type(tfile.tractogram), Tractogram + else: + assert type(tfile.tractogram), LazyTractogram + + with pytest.warns(Warning) if lazy_load else error_warnings(): + assert_tractogram_equal(tfile.tractogram, DATA['simple_tractogram']) + def test_load_complex_file(self): for lazy_load in [False, True]: for complex_filename in DATA['complex_filenames']: @@ -205,6 +227,11 @@ def test_save_tractogram_file(self): tfile = nib.streamlines.load('dummy.trk', lazy_load=False) assert_tractogram_equal(tfile.tractogram, tractogram) + with InTemporaryDirectory(): + nib.streamlines.save(trk_file, Path('dummy.trk')) + tfile = nib.streamlines.load('dummy.trk', lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) + def test_save_empty_file(self): tractogram = Tractogram(affine_to_rasmm=np.eye(4)) for ext in FORMATS: diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index c434619d6..6fa212a8e 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -294,8 +294,8 @@ def load(cls, fileobj, lazy_load=False): Parameters ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object + fileobj : path-like or file-like object + If path-like, a filename; otherwise an open file-like object pointing to TRK file (and ready to read from the beginning of the TRK header). Note that calling this function does not change the file position. @@ -401,8 +401,8 @@ def save(self, fileobj): Parameters ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object + fileobj : path-like or file-like object + If path-like, a filename; otherwise an open file-like object pointing to TRK file (and ready to write from the beginning of the TRK header data). """ @@ -550,8 +550,8 @@ def _read_header(fileobj): Parameters ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object + fileobj : path-like or file-like object + If path-like, a filename; otherwise an open file-like object pointing to TRK file (and ready to read from the beginning of the TRK header). Note that calling this function does not change the file position.