diff --git a/fafbseg/flywire/synapses.py b/fafbseg/flywire/synapses.py index f59e47d..a0d0d4b 100644 --- a/fafbseg/flywire/synapses.py +++ b/fafbseg/flywire/synapses.py @@ -465,7 +465,8 @@ def get_synapses( "Querying synapse table with `filtered=True` already removes " "synaptic connections with cleft_score <= 50. If you want less " "confident connections set `filtered=False`. Note that this will " - "also drop the de-duplication (see docstring)." + "also drop the de-duplication! " + "See `help(fafbseg.flywire.get_synapses)` for details." ) navis.config.logger.warning(msg) @@ -674,8 +675,9 @@ def get_synapses( def get_adjacency( sources, targets=None, + square=True, materialization="auto", - neuropils=None, + neuropils=False, filtered=True, min_score=None, batch_size=1000, @@ -709,10 +711,14 @@ def get_adjacency( the root ID. If you have a neuron (in FlyWire space) but don't know its ID, use :func:`fafbseg.flywire.neuron_to_segments` first. If ``None``, will assume ```targets = sources``. - neuropils : str | list of str, optional - Provide neuropil (e.g. ``'AL_R'``) or list thereof (e.g. - ``['AL_R', 'AL_L']``) to filter connectivity to these ROIs. - Prefix neuropil with a tilde (e.g. ``~AL_R``) to exclude it. + square : bool, optional + Whether to return a square matrix. If False, will return + a DataFrame with sources as rows and targets as columns. + neuropils : bool | str | list of str, optional + - if True, will return neuropils. This requires `square=False`! + - a neuropil name (e.g. ``'AL_R'``) or list thereof (e.g. + ``['AL_R', 'AL_L']``) to filter connectivity to these ROIs. + - prefix neuropils with a tilde (e.g. ``~AL_R``) to exclude them. filtered : bool Whether to use the filtered synapse table. Briefly, this filter removes redundant and low confidence (<= 50 cleft score) @@ -767,6 +773,17 @@ def get_adjacency( 720575940631406673 5 """ + # This avoids some issues with asking "if neuropils" + if isinstance(neuropils, np.ndarray): + neuropils = neuropils.tolist() + + if isinstance(neuropils, bool) and neuropils and square: + raise ValueError( + "To return neuropils, `square` must be False. " + "Set `square=False` to return a DataFrame with sources as rows " + "and targets as columns." + ) + if isinstance(targets, type(None)): targets = sources @@ -798,7 +815,9 @@ def get_adjacency( else: _check_ids(both, materialization=materialization, dataset=dataset) - columns = ["pre_pt_root_id", "post_pt_root_id", "cleft_score", "id"] + columns = ["pre_pt_root_id", "post_pt_root_id", "cleft_score", "id"] + ( + ["neuropil"] if neuropils else [] + ) sv_cols = ["pre_pt_supervoxel_id", "post_pt_supervoxel_id"] if materialization == "live" and filtered: @@ -821,7 +840,13 @@ def get_adjacency( has_view = "valid_connection_v2" in client.materialize.get_views( materialization ) - no_np = isinstance(neuropils, type(None)) + # Check if we need to pull neuropils (in which case we can't pull connections + # from the view but have to pull individual synapses) + no_np = False + if isinstance(neuropils, type(None)): + no_np = True + elif isinstance(neuropils, bool): + no_np = not neuropils no_score_thresh = not min_score or min_score == 50 if has_view & no_np & no_score_thresh: columns = ["pre_pt_root_id", "post_pt_root_id", "n_syn"] @@ -879,13 +904,17 @@ def get_adjacency( # Combine results from batches if len(syn): syn = pd.concat(syn, axis=0, ignore_index=True) - else: + elif square: adj = pd.DataFrame( np.zeros((len(sources), len(targets))), index=sources, columns=targets ) adj.index.name = "source" adj.columns.name = "target" return adj + elif neuropils: + return pd.DataFrame([], columns=["pre", "post", "weight", "neuropil"]) + else: + return pd.DataFrame([], columns=["pre", "post", "weight"]) # Depending on how queries were batched, we need to drop duplicate synapses if "id" in syn.columns: @@ -896,7 +925,7 @@ def get_adjacency( ) # Subset to the desired neuropils - if not isinstance(neuropils, type(None)): + if isinstance(neuropils, (str, list, np.ndarray)): neuropils = make_iterable(neuropils) if len(neuropils): @@ -927,19 +956,26 @@ def get_adjacency( syn = syn[syn.cleft_score >= min_score] # Aggregate + # (if "weight" is in column then we're not dealing with individual synapses) if "weight" not in syn.columns: - cn = syn.groupby(["pre", "post"], as_index=False).size() + if isinstance(neuropils, bool) and neuropils: + cn = syn.groupby(["pre", "post", "neuropil"], as_index=False).size() + else: + cn = syn.groupby(["pre", "post"], as_index=False).size() else: cn = syn - cn.columns = ["source", "target", "weight"] + cn.columns = ["source", "target", "weight"] + (["neuropil"] if 'neuropil' in cn.columns else []) # Pivot - adj = cn.pivot(index="source", columns="target", values="weight").fillna(0) + if square: + adj = cn.pivot(index="source", columns="target", values="weight").fillna(0) - # Index to match order and add any missing neurons - adj = adj.reindex(index=sources, columns=targets).fillna(0) + # Index to match order and add any missing neurons + adj = adj.reindex(index=sources, columns=targets).fillna(0) - return adj + return adj + else: + return cn @parse_neuroncriteria()