Skip to content

Commit

Permalink
accept radius='auto' and map colors from skeleton to mesh
Browse files Browse the repository at this point in the history
  • Loading branch information
schlegelp committed Sep 17, 2024
1 parent 63d8fed commit 6314074
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions octarine_navis_plugin/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,12 +152,20 @@ def neuron2gfx(x, color=None, random_ids=False, **kwargs):
else:
object_id = neuron.id # this may also be a random ID

if kwargs.get("radius", False):
# Convert and carry connectors with us
if isinstance(neuron, navis.TreeNeuron):
_neuron = navis.conversion.tree2meshneuron(neuron)
_neuron.connectors = neuron.connectors
neuron = _neuron
if isinstance(neuron, navis.TreeNeuron) and kwargs.get('radius', False) == "auto":
# Number of nodes with radii
n_radii = (neuron.nodes.get("radius", pd.Series([])).fillna(0) > 0).sum()
# If less than 30% of nodes have a radius, we will fall back to lines
if n_radii / neuron.nodes.shape[0] < 0.3:
kwargs['radius'] = False

_neuron = navis.conversion.tree2meshneuron(neuron, warn_missing_radii=False)
_neuron.connectors = neuron.connectors
neuron = _neuron

# See if we need to map colors to vertices
if isinstance(colormap[i], np.ndarray) and colormap[i].ndim == 2:
colormap[i] = colormap[i][neuron.vertex_map]

neuron_color = colormap[i]
if not kwargs.get("connectors_only", False):
Expand Down Expand Up @@ -316,7 +324,9 @@ def skeleton2gfx(neuron, neuron_color, object_id, **kwargs):
if neuron_color.ndim == 1:
coords = navis.plotting.plot_utils.segments_to_coords(neuron)
else:
coords, neuron_color = navis.plotting.plot_utils.segments_to_coords(neuron, node_colors=neuron_color)
coords, neuron_color = navis.plotting.plot_utils.segments_to_coords(
neuron, node_colors=neuron_color
)
# `neuron_color` is now a list of colors for each segment; we have to flatten it
# and add `None` to match the breaks
neuron_color = np.vstack(
Expand Down Expand Up @@ -428,4 +438,4 @@ def skeletor2gfx(s, **kwargs):
import navis

s = navis.TreeNeuron(s, soma=None, id=getattr(s, "id", uuid.uuid4()))
return neuron2gfx(s, **kwargs)
return neuron2gfx(s, **kwargs)

0 comments on commit 6314074

Please sign in to comment.