Skip to content

Commit

Permalink
nrrd_io.py: formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
schlegelp committed Sep 27, 2024
1 parent bb3efca commit ce7fe74
Showing 1 changed file with 95 additions and 68 deletions.
163 changes: 95 additions & 68 deletions navis/io/nrrd_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@
logger = config.get_logger(__name__)


def write_nrrd(x: 'core.NeuronObject',
filepath: Union[str, Path],
compression_level: int = 3,
attrs: Optional[Dict[str, Any]] = None) -> None:
def write_nrrd(
x: "core.NeuronObject",
filepath: Union[str, Path],
compression_level: int = 3,
attrs: Optional[Dict[str, Any]] = None,
) -> None:
"""Write VoxelNeurons or Dotprops to NRRD file(s).
Parameters
Expand Down Expand Up @@ -106,43 +108,43 @@ def write_nrrd(x: 'core.NeuronObject',
compression_level = int(compression_level)

if (compression_level < 1) or (compression_level > 9):
raise ValueError('`compression_level` must be 1-9, got '
f'{compression_level}')
raise ValueError("`compression_level` must be 1-9, got " f"{compression_level}")

writer = base.Writer(_write_nrrd, ext='.nrrd')
writer = base.Writer(_write_nrrd, ext=".nrrd")

return writer.write_any(x,
filepath=filepath,
compression_level=compression_level,
**(attrs or {}))
return writer.write_any(
x, filepath=filepath, compression_level=compression_level, **(attrs or {})
)


def _write_nrrd(x: Union['core.VoxelNeuron', 'core.Dotprops'],
filepath: Optional[str] = None,
compression_level: int = 1,
**attrs) -> None:
def _write_nrrd(
x: Union["core.VoxelNeuron", "core.Dotprops"],
filepath: Optional[str] = None,
compression_level: int = 1,
**attrs,
) -> None:
"""Write single neuron to NRRD file."""
if not isinstance(x, (core.VoxelNeuron, core.Dotprops)):
raise TypeError(f'Expected VoxelNeuron or Dotprops, got "{type(x)}"')

header = getattr(x, "nrrd_header", {})
header['space dimension'] = 3
header['space directions'] = np.diag(x.units_xyz.magnitude)
header['space units'] = [str(x.units_xyz.units)] * 3
header["space dimension"] = 3
header["space directions"] = np.diag(x.units_xyz.magnitude)
header["space units"] = [str(x.units_xyz.units)] * 3
header.update(attrs or {})

if isinstance(x, core.VoxelNeuron):
data = x.grid
if data.dtype == bool:
data = data.astype('uint8')
data = data.astype("uint8")
else:
# For dotprops make a horizontal stack from points + vectors
data = np.hstack((x.points, x.vect))
header['k'] = x.k
header["k"] = x.k

nrrd.write(str(filepath),
data=data,
header=header,
nrrd.write(
str(filepath), data=data, header=header, compression_level=compression_level
)


def read_nrrd(
Expand Down Expand Up @@ -260,34 +262,53 @@ def read_nrrd(
n_cores = int(parallel)

with mp.Pool(processes=n_cores) as pool:
results = pool.imap(_worker_wrapper, [dict(f=x,
threshold=threshold,
output=output,
errors=errors,
include_subdirs=include_subdirs,
parallel=False) for x in f],
chunksize=1)

res = list(config.tqdm(results,
desc='Importing',
total=len(f),
disable=config.pbar_hide,
leave=config.pbar_leave))
results = pool.imap(
_worker_wrapper,
[
dict(
f=x,
threshold=threshold,
output=output,
errors=errors,
include_subdirs=include_subdirs,
parallel=False,
)
for x in f
],
chunksize=1,
)

res = list(
config.tqdm(
results,
desc="Importing",
total=len(f),
disable=config.pbar_hide,
leave=config.pbar_leave,
)
)

else:
# If not parallel just import the good 'ole way: sequentially
res = [read_nrrd(x,
threshold=threshold,
include_subdirs=include_subdirs,
output=output,
errors=errors,
parallel=parallel,
**kwargs)
for x in config.tqdm(f, desc='Importing',
disable=config.pbar_hide,
leave=config.pbar_leave)]

if output == 'raw':
res = [
read_nrrd(
x,
threshold=threshold,
include_subdirs=include_subdirs,
output=output,
errors=errors,
parallel=parallel,
**kwargs,
)
for x in config.tqdm(
f,
desc="Importing",
disable=config.pbar_hide,
leave=config.pbar_leave,
)
]

if output == "raw":
return [r[0] for r in res], [r[1] for r in res]

return core.NeuronList([r for r in res if r])
Expand All @@ -314,19 +335,19 @@ def read_nrrd(
units = None
su = None
voxdim = np.array([1, 1, 1])
if 'space directions' in header:
sd = np.asarray(header['space directions'])
if "space directions" in header:
sd = np.asarray(header["space directions"])
if sd.ndim == 2:
voxdim = np.diag(sd)[:3]
if 'space units' in header:
su = header['space units']
if "space units" in header:
su = header["space units"]
if len(su) == 3:
units = [f'{m} {u}' for m, u in zip(voxdim, su)]
units = [f"{m} {u}" for m, u in zip(voxdim, su)]
else:
units = voxdim

try:
if output == 'dotprops':
if output == "dotprops":
# If we're trying to get voxels from an image
if data.ndim == 3:
if threshold:
Expand Down Expand Up @@ -356,30 +377,36 @@ def read_nrrd(
elif data.shape[1] == 7:
points, vect, alpha = data[:, :3], data[:, 3:6], data[:, 6]
else:
raise ValueError('Expected data to be either (N, 3), (N, 6) '
f'or (N, 7) but NRRD file contains {data.shape}')
raise ValueError(
"Expected data to be either (N, 3), (N, 6) "
f"or (N, 7) but NRRD file contains {data.shape}"
)
# Get `k` either from provided kwargs or the file's header
k = kwargs.pop('k', header.get('k', 20))
k = kwargs.pop("k", header.get("k", 20))

x = core.Dotprops(points, k=k, vect=vect, alpha=alpha, **kwargs)
else:
raise ValueError('Data must be 2- or 3-dimensional to extract '
f'Dotprops, got {data.ndim}')
raise ValueError(
"Data must be 2- or 3-dimensional to extract "
f"Dotprops, got {data.ndim}"
)

if su and len(su) == 3:
x.units = [f'1 {s}' for s in su]
x.units = [f"1 {s}" for s in su]
else:
if data.ndim == 2:
logger.warning(f'Data in NRRD file is of shape {data.shape} - '
'i.e. 2D. Could this be a point cloud/dotprops '
'instead of voxels?')
logger.warning(
f"Data in NRRD file is of shape {data.shape} - "
"i.e. 2D. Could this be a point cloud/dotprops "
"instead of voxels?"
)
x = core.VoxelNeuron(data, units=units)
except BaseException as e:
msg = f'Error converting file {fname} to neuron.'
if errors == 'raise':
raise ImportError(msg) from e
elif errors == 'log':
logger.error(f'{msg}: {e}')
msg = f"Error converting file {fname} to neuron."
if errors == "raise":
raise e
elif errors == "log":
logger.error(f"{msg}: {e}")
return

# Add some additional properties
Expand Down

0 comments on commit ce7fe74

Please sign in to comment.