diff --git a/octarine_navis_plugin/objects.py b/octarine_navis_plugin/objects.py index 187f076..4c7ce00 100644 --- a/octarine_navis_plugin/objects.py +++ b/octarine_navis_plugin/objects.py @@ -189,9 +189,21 @@ def connectors2gfx(neuron, neuron_color, object_id, **kwargs): if kwargs.get("cn_layout", None): cn_lay.update(kwargs.get("cn_layout", {})) + which_cn = kwargs.get('connectors', None) + if isinstance(which_cn, (list, np.ndarray, tuple)): + connectors = neuron.connectors[neuron.connectors.type.isin(which_cn)] + elif which_cn == 'pre': + connectors = neuron.presynapses + elif which_cn == 'post': + connectors = neuron.postsynapses + elif isinstance(which_cn, str): + connectors = neuron.connectors[neuron.connectors.type == which_cn] + else: + connectors = neuron.connectors + visuals = [] cn_colors = kwargs.get("cn_colors", None) - for j in neuron.connectors.type.unique(): + for j, this_cn in connectors.groupby('type'): if isinstance(cn_colors, dict): color = cn_colors.get(j, cn_lay.get(j, {}).get("color", (0.1, 0.1, 0.1))) elif cn_colors == "neuron": @@ -203,8 +215,6 @@ def connectors2gfx(neuron, neuron_color, object_id, **kwargs): color = navis.plotting.colors.eval_color(color, color_range=1) - this_cn = neuron.connectors[neuron.connectors.type == j] - pos = ( this_cn[["x", "y", "z"]] .apply(pd.to_numeric)