Skip to content

Commit

Permalink
prune_twigs: fix mask when precise=True
Browse files Browse the repository at this point in the history
  • Loading branch information
schlegelp committed Sep 22, 2024
1 parent 06836a0 commit 031bc6b
Showing 1 changed file with 31 additions and 3 deletions.
34 changes: 31 additions & 3 deletions navis/morpho/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,7 @@ def _prune_twigs_precise(
neuron: "core.TreeNeuron",
size: float,
inplace: bool = False,
mask: Optional[Union[Sequence[int], Callable]] = None,
recursive: Union[int, bool, float] = False,
) -> Optional[NeuronObject]:
"""Prune twigs using precise method."""
Expand All @@ -534,17 +535,38 @@ def _prune_twigs_precise(
res = tree.query_ball_point(neuron.leafs[["x", "y", "z"]].values, r=size)
candidates = neuron.nodes.node_id.values[np.unique(np.concatenate(res))]

if callable(mask):
mask = mask(neuron)

if mask is not None:
if mask.dtype == bool:
if len(mask) != neuron.n_nodes:
raise ValueError("Mask length must match number of nodes")
mask_nodes = neuron.nodes.node_id.values[mask]
elif mask.dtype in (int, np.int32, np.int64):
mask_nodes = mask
else:
raise TypeError(
f"Mask must be boolean or list of node IDs, got {mask.dtype}"
)

candidates = np.intersect1d(candidates, mask_nodes)

if not len(candidates):
return neuron

# For each node in neuron find out which leafs are directly distal to it
# `distal` is a matrix with all nodes in columns and leafs in rows
distal = graph.distal_to(neuron, a=leafs, b=candidates)

# Turn matrix into dictionary {'node': [leafs, distal, to, it]}
melted = distal.reset_index(drop=False).melt(id_vars="index")
melted = melted[melted.value]
melted.groupby("variable")["index"].apply(list)

# `distal` is now a dictionary for {'node_id': [leaf1, leaf2, ..], ..}
distal = melted.groupby("variable")["index"].apply(list).to_dict()

# For each node find the distance to any leaf - note we are using `length`
# For each node find the distance to any leaf - note we are using `size`
# as cutoff here
# `path_len` is a dict mapping {nodeA: {nodeB: length, ...}, ...}
# if nodeB is not in dictionary, it's not within reach
Expand All @@ -571,6 +593,12 @@ def _prune_twigs_precise(
# For each of the new leafs check their shortest distance to the
# original leafs to get the remainder
is_new_leaf = (neuron.nodes.type == "end").values

# If there is a mask, we have to exclude old leafs which would not have
# been in the mask
if mask is not None:
is_new_leaf = is_new_leaf & np.isin(neuron.nodes.node_id, mask_nodes)

new_leafs = neuron.nodes[is_new_leaf].node_id.values
max_len = [max([path_len[l1][l2] for l2 in distal[l1]]) for l1 in new_leafs]

Expand All @@ -581,7 +609,7 @@ def _prune_twigs_precise(
# Get vectors from leafs to their parents
nodes = neuron.nodes.set_index("node_id")
parents = nodes.loc[new_leafs, "parent_id"].values
loc1 = neuron.leafs[["x", "y", "z"]].values
loc1 = nodes.loc[new_leafs, ["x", "y", "z"]].values
loc2 = nodes.loc[parents, ["x", "y", "z"]].values
vec = loc1 - loc2
vec_len = np.linalg.norm(vec, axis=1)
Expand Down

0 comments on commit 031bc6b

Please sign in to comment.