diff --git a/fafbseg/flywire/synapses.py b/fafbseg/flywire/synapses.py index 6ecc5dc..538b821 100644 --- a/fafbseg/flywire/synapses.py +++ b/fafbseg/flywire/synapses.py @@ -805,14 +805,24 @@ def get_adjacency( _check_ids(both, materialization=materialization, dataset=dataset) columns = ["pre_pt_root_id", "post_pt_root_id", "cleft_score", "id"] + sv_cols = ["pre_pt_supervoxel_id", "post_pt_supervoxel_id"] - if materialization == "live": + if materialization == "live" and filtered: + raise ValueError( + "It is currently not possible to fetch filtered " + "synapses in live queries. You can set `filtered=False` " + "but please be aware that this will query the " + "unfiltered synapse table. See docs for details." + ) + elif materialization == "live": func = partial( retry(client.materialize.live_query), table=client.materialize.synapse_table, timestamp=dt.datetime.utcnow(), - select_columns=columns, - ) + # nb there is a bug in CAVE which causes empty results if we don't + # ask for supervoxels + select_columns=columns + sv_cols, + ) elif filtered: has_view = "valid_connection_v2" in client.materialize.get_views( materialization @@ -858,7 +868,7 @@ def get_adjacency( for k in range(0, len(targets), batch_size): target_batch = targets[k : k + batch_size] - if not filtered: + if not filtered or materialization == "live": filter_in_dict = dict( post_pt_root_id=target_batch, pre_pt_root_id=source_batch ) @@ -868,15 +878,16 @@ def get_adjacency( post_pt_root_id=target_batch, pre_pt_root_id=source_batch ) ) - - this = func(filter_in_dict=filter_in_dict) + this = func(filter_in_dict=filter_in_dict) # We need to drop the .attrs (which contain meta data from queries) # Otherwise we run into issues when concatenating this.attrs = {} if not this.empty: - syn.append(this) + syn.append(this.drop( + sv_cols, axis=1, errors="ignore" + )) # Combine results from batches if len(syn): @@ -1106,6 +1117,7 @@ def get_connectivity( _check_ids(ids, materialization=materialization, dataset=dataset) columns = ["pre_pt_root_id", "post_pt_root_id", "cleft_score", "id"] + sv_cols = ["pre_pt_supervoxel_id", "post_pt_supervoxel_id"] if transmitters: columns += ["gaba", "ach", "glut", "oct", "ser", "da"] @@ -1122,7 +1134,7 @@ def get_connectivity( retry(client.materialize.live_query), table=client.materialize.synapse_table, timestamp=dt.datetime.utcnow(), - select_columns=columns, + select_columns=columns + sv_cols, ) elif filtered: has_view = "valid_connection_v2" in client.materialize.get_views( @@ -1169,20 +1181,23 @@ def get_connectivity( ): batch = ids[i : i + batch_size] if upstream: - if not filtered: + if not filtered or materialization == "live": filter_in_dict = dict(post_pt_root_id=batch) else: filter_in_dict = dict(synapses_nt_v1=dict(post_pt_root_id=batch)) syn.append(func(filter_in_dict=filter_in_dict)) if downstream: - if not filtered: + if not filtered or materialization == "live": filter_in_dict = dict(pre_pt_root_id=batch) else: filter_in_dict = dict(synapses_nt_v1=dict(pre_pt_root_id=batch)) syn.append(func(filter_in_dict=filter_in_dict)) - # Drop attrs to avoid issues when concatenating + # Some clean-up for df in syn: + # Drop supervoxel columns (if they exist) + df.drop(sv_cols, axis=1, errors="ignore", inplace=True) + # Drop `attrs`` to avoid issues when concatenating df.attrs = {} # Combine results from batches