From 2e6e6e688c19f1dd5b86ca189ecf65bce894d319 Mon Sep 17 00:00:00 2001
From: Philipp Schlegel
Date: Thu, 11 Jan 2024 13:35:50 +0000
Subject: [PATCH] doctests: suppress some print statements that are hard to
 test for

---
 conftest.py                    |  6 ++++
 fafbseg/flywire/annotations.py | 28 ++++++++---------
 fafbseg/flywire/l2.py          | 56 +++++++++++++++++++---------------
 fafbseg/flywire/skeletonize.py | 52 ++++++++++++++++++++-----------
 fafbseg/flywire/synapses.py    |  4 +--
 pytest.ini                     |  3 +-
 6 files changed, 89 insertions(+), 60 deletions(-)

diff --git a/conftest.py b/conftest.py
index d0c6631..089f03b 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,3 +1,9 @@
+import os
+
+# This is to avoid some print statements in the code that are semi-random and
+# hence will make doctests fail. I tried adding this environment variable to
+# pytest.ini but that didn't work.
+os.environ['FAFBSEG_TESTING'] = 'TRUE'
 
 SKIP = ['conf.py']
 
diff --git a/fafbseg/flywire/annotations.py b/fafbseg/flywire/annotations.py
index 855513f..1065e23 100644
--- a/fafbseg/flywire/annotations.py
+++ b/fafbseg/flywire/annotations.py
@@ -719,7 +719,7 @@ def get_somas(x=None,
     >>> from fafbseg import flywire
     >>> somas = flywire.get_somas([720575940628842314])
     Using materialization version 630.
-    >>> somas
+    >>> somas # doctest: +SKIP
            id     volume   pt_supervoxel_id          pt_root_id              pt_position  rad_est
     0  5743218  27.935539  80645535832325071  720575940628842314  [584928, 201568, 22720]   2480.0
 
@@ -1329,7 +1329,7 @@ def get_hierarchical_annotations(annotation_version=None,
         root_col = "root_live"
         to_update = ~segmentation.is_latest_root(table.root_live, progress=False)
         if any(to_update):
-            if verbose:
+            if verbose and not os.environ.get('FAFBSEG_TESTING', False):
                 print(
                     "Updating root IDs for hierarchical annotations... ",
                     end="",
@@ -1344,7 +1344,7 @@
     elif materialization:
         root_col = f"root_{materialization}"
        if root_col not in table.columns:
-            if verbose:
+            if verbose and not os.environ.get('FAFBSEG_TESTING', False):
                print(
                    "Caching root IDs for hierarchical annotations at "
                    f"materialization '{materialization}'... ",
@@ -1364,7 +1364,7 @@
        # so we don't have to do it again
        if save:
            table.to_csv(fp, index=False, sep='\t')
-        if verbose:
+        if verbose and not os.environ.get('FAFBSEG_TESTING', False):
            print('Done.', flush=True)
 
    # Make sure "root_id" corresponds to the requested materialization and drop
@@ -1508,7 +1508,6 @@ def search_community_annotations(x,
 
    >>> an = flywire.search_community_annotations(720575940628857210)
    Using materialization version 630.
-    Caching community annotations for materialization version "630"... Done.
    >>> an.iloc[0]
    id                   46699
    pt_position_x       419980
 
    Search for all tags matching a given pattern:
 
    >>> mi1 = flywire.search_community_annotations('Mi1')
-    Caching community annotations for materialization version "630"... Done. 
-    >>> mi1.head()
-           id   pos_x  pos_y  pos_z      supervoxel_id             root_id                      tag  user user_id
-    0   61866  200890  58482   3724  84445998176507710  720575940626235644  Medulla Intrinsic - Mi1  TR77    2843
-    1   61867  192776  41653   4105  83881948711947639  720575940623948432  Medulla Intrinsic - Mi1  TR77    2843
-    2   61869  194704  97128   3905  84026397051849432  720575940632232418  Medulla Intrinsic - Mi1  TR77    2843
-    3   61871  191574  82521   2728  83814259892666307  720575940630475095  Medulla Intrinsic - Mi1  TR77    2843
-    4   61877  195454  43026   4031  84022754919681933  720575940617637204  Medulla Intrinsic - Mi1  TR77    2843
+    >>> mi1.head() # doctest: +SKIP
+           id   pos_x  pos_y  ...                      tag  user user_id
+    0   61866  200890  58482  ...  Medulla Intrinsic - Mi1  TR77    2843
+    1   61867  192776  41653  ...  Medulla Intrinsic - Mi1  TR77    2843
+    2   61869  194704  97128  ...  Medulla Intrinsic - Mi1  TR77    2843
+    3   61871  191574  82521  ...  Medulla Intrinsic - Mi1  TR77    2843
+    4   61877  195454  43026  ...  Medulla Intrinsic - Mi1  TR77    2843
 
     """
     # See if ``x`` is a root ID as string
@@ -1617,14 +1615,14 @@ def _get_community_annotation_table(dataset, materialization, split_positions=Fa
         versions = get_cave_client(dataset=dataset).materialize.get_versions()
         materialization = sorted(versions)[-1]
 
-    if verbose:
+    if verbose and not os.environ.get('FAFBSEG_TESTING', False):
         print(f'Caching community annotations for materialization version "{materialization}"...', end='', flush=True)
 
     table = get_cave_table(table_name=COMMUNITY_ANNOTATION_TABLE,
                            dataset=dataset,
                            split_positions=split_positions,
                            materialization=materialization)
 
-    if verbose:
+    if verbose and not os.environ.get('FAFBSEG_TESTING', False):
         print(' Done.')
 
     return table
diff --git a/fafbseg/flywire/l2.py b/fafbseg/flywire/l2.py
index 0f94114..02fb27a 100644
--- a/fafbseg/flywire/l2.py
+++ b/fafbseg/flywire/l2.py
@@ -73,10 +73,10 @@ def get_l2_info(root_ids, progress=True, max_threads=4, *, dataset=None):
     Examples
     --------
     >>> from fafbseg import flywire
-    >>> info = flywire.get_l2_info(720575940614131061)  # doctest: +ELLIPSIS
-    >>> info
-                  root_id  l2_chunks  chunks_missing    area_um2    size_um3  length_um                                          bounds_nm
-    0  720575940614131061        286               0  2378.16384  163.876526     60.666  [396816.0, 587808.0, 83968.0, 279072.0, 19560....
+    >>> info = flywire.get_l2_info(720575940614131061)
+    >>> info # doctest: +SKIP
+                  root_id  l2_chunks  chunks_missing    area_um2    size_um3  length_um  ...
+    0  720575940614131061        286               0  2378.16384  163.876526     60.666  ...
 
""" if navis.utils.is_iterable(root_ids): @@ -85,11 +85,16 @@ def get_l2_info(root_ids, progress=True, max_threads=4, *, dataset=None): with ThreadPoolExecutor(max_workers=max_threads) as pool: func = partial(get_l2_info, dataset=dataset) futures = pool.map(func, root_ids) - info = [f for f in navis.config.tqdm(futures, - desc='Fetching L2 info', - total=len(root_ids), - disable=not progress or len(root_ids) == 1, - leave=False)] + info = [ + f + for f in navis.config.tqdm( + futures, + desc="Fetching L2 info", + total=len(root_ids), + disable=not progress or len(root_ids) == 1, + leave=False, + ) + ] return pd.concat(info, axis=0).reset_index(drop=True) # Get/Initialize the CAVE client @@ -98,48 +103,51 @@ def get_l2_info(root_ids, progress=True, max_threads=4, *, dataset=None): get_l2_ids = partial(retry(client.chunkedgraph.get_leaves), stop_layer=2) l2_ids = get_l2_ids(root_ids) - attributes = ['area_nm2', 'size_nm3', 'max_dt_nm', 'rep_coord_nm'] + attributes = ["area_nm2", "size_nm3", "max_dt_nm", "rep_coord_nm"] get_l2data = retry(client.l2cache.get_l2data) info = get_l2data(l2_ids.tolist(), attributes=attributes) n_miss = len([v for v in info.values() if not v]) row = [root_ids, len(l2_ids), n_miss] - info_df = pd.DataFrame([row], - columns=['root_id', 'l2_chunks', 'chunks_missing']) + info_df = pd.DataFrame([row], columns=["root_id", "l2_chunks", "chunks_missing"]) # Collect L2 attributes for at in attributes: - if at in ('rep_coord_nm', ): + if at in ("rep_coord_nm",): continue summed = sum([v.get(at, 0) for v in info.values()]) - if at.endswith('3'): + if at.endswith("3"): summed /= 1000**3 - elif at.endswith('2'): + elif at.endswith("2"): summed /= 1000**2 else: summed /= 1000 - info_df[at.replace('_nm', '_um')] = [summed] + info_df[at.replace("_nm", "_um")] = [summed] # Check bounding box - pts = np.array([v['rep_coord_nm'] for v in info.values() if v]) + pts = np.array([v["rep_coord_nm"] for v in info.values() if v]) if len(pts) > 1: bounds = [v for l in zip(pts.min(axis=0), pts.max(axis=0)) for v in l] elif len(pts) == 1: pt = pts[0] - rad = [v['max_dt_nm'] for v in info.values() if v][0] / 2 - bounds = [pt[0] - rad, pt[0] + rad, - pt[1] - rad, pt[1] + rad, - pt[2] - rad, pt[2] + rad] + rad = [v["max_dt_nm"] for v in info.values() if v][0] / 2 + bounds = [ + pt[0] - rad, + pt[0] + rad, + pt[1] - rad, + pt[1] + rad, + pt[2] - rad, + pt[2] + rad, + ] bounds = [int(co) for co in bounds] else: bounds = None - info_df['bounds_nm'] = [bounds] + info_df["bounds_nm"] = [bounds] - info_df.rename({'max_dt_um': 'length_um'}, - axis=1, inplace=True) + info_df.rename({"max_dt_um": "length_um"}, axis=1, inplace=True) return info_df diff --git a/fafbseg/flywire/skeletonize.py b/fafbseg/flywire/skeletonize.py index 392c16d..8025715 100644 --- a/fafbseg/flywire/skeletonize.py +++ b/fafbseg/flywire/skeletonize.py @@ -594,7 +594,7 @@ def get_skeletons(root_id, threads=2, omit_failures=None, max_threads=6, -------- >>> from fafbseg import flywire >>> n = flywire.get_skeletons(720575940603231916) - >>> n + >>> n #doctest: +SKIP type navis.TreeNeuron name skeleton id 720575940603231916 @@ -603,25 +603,31 @@ def get_skeletons(root_id, threads=2, omit_failures=None, max_threads=6, n_branches 586 n_leafs 645 cable_length 2050971.75 - soma [141, 458, 460, 462, 464, 466, 467, 469, 470, ... 
+    soma                        None
     units                1 nanometer
     dtype: object
 
     """
     if str(dataset) not in SKELETON_BASE_URL:
-        raise ValueError('Currently we only provide precomputed skeletons for the '
-                         '630 and 783 data releases.')
+        raise ValueError(
+            "Currently we only provide precomputed skeletons for the "
+            "630 and 783 data releases."
+        )
 
     if omit_failures not in (None, True, False):
-        raise ValueError('`omit_failures` must be either None, True or False. '
-                         f'Got "{omit_failures}".')
+        raise ValueError(
+            "`omit_failures` must be either None, True or False. "
+            f'Got "{omit_failures}".'
+        )
 
     if navis.utils.is_iterable(root_id):
         root_id = np.asarray(root_id, dtype=np.int64)
-        il = is_latest_root(root_id, timestamp=f'mat_{dataset}')
+        il = is_latest_root(root_id, timestamp=f"mat_{dataset}")
         if np.any(~il):
-            msg = (f'{(~il).sum()} root ID(s) did not exists at materialization {dataset}')
+            msg = (
+                f"{(~il).sum()} root ID(s) did not exist at materialization {dataset}"
+            )
             if omit_failures is None:
                 raise ValueError(msg)
             navis.config.logger.warning(msg)
@@ -630,17 +636,27 @@
     if (max_threads > 1) and (len(root_id) > 1):
         with ThreadPoolExecutor(max_workers=max_threads) as pool:
             futures = pool.map(get_skels, root_id)
-            nl = [f for f in navis.config.tqdm(futures,
-                                               desc='Fetching skeletons',
-                                               total=len(root_id),
-                                               disable=not progress or len(root_id) == 1,
-                                               leave=False)]
+            nl = [
+                f
+                for f in navis.config.tqdm(
+                    futures,
+                    desc="Fetching skeletons",
+                    total=len(root_id),
+                    disable=not progress or len(root_id) == 1,
+                    leave=False,
+                )
+            ]
     else:
-        nl = [get_skels(r) for r in navis.config.tqdm(root_id,
-                                                      desc='Fetching skeletons',
-                                                      total=len(root_id),
-                                                      disable=not progress or len(root_id) == 1,
-                                                      leave=False)]
+        nl = [
+            get_skels(r)
+            for r in navis.config.tqdm(
+                root_id,
+                desc="Fetching skeletons",
+                total=len(root_id),
+                disable=not progress or len(root_id) == 1,
+                leave=False,
+            )
+        ]
 
     # Turn into neuron list
     nl = navis.NeuronList(nl)
diff --git a/fafbseg/flywire/synapses.py b/fafbseg/flywire/synapses.py
index 963787b..ea30db3 100644
--- a/fafbseg/flywire/synapses.py
+++ b/fafbseg/flywire/synapses.py
@@ -422,7 +422,7 @@ def get_synapses(
     >>> from fafbseg import flywire
     >>> syn = flywire.get_synapses(720575940603231916)
     Using materialization version 630.
-    >>> syn.head()
+    >>> syn.head() #doctest: +SKIP
                       pre                post  cleft_score   pre_x   pre_y  pre_z  post_x  post_y  post_z        id
     0  720575940631406673  720575940603231916           60  434336  218108  28240  434340  218204   28240   3535576
     1  720575940608044501  720575940603231916          136  429180  212316  51520  429244  212136   51520  15712693
@@ -436,7 +436,7 @@
     >>> sk = flywire.get_skeletons(720575940603231916)
     >>> _ = flywire.get_synapses(sk, attach=True)
     Using materialization version 630.
-    >>> sk.connectors.head()
+    >>> sk.connectors.head() #doctest: +SKIP
        connector_id       x       y       z  cleft_score          partner_id type  node_id
     0             0  356304  146840  145120          145  720575940627592977  pre      217
     1             1  344456  164324  162440          153  720575940537249676  pre        5
diff --git a/pytest.ini b/pytest.ini
index 93ee4ed..1127198 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -2,4 +2,5 @@
 doctest_optionflags = IGNORE_EXCEPTION_DETAIL NUMBER NORMALIZE_WHITESPACE
 addopts = --doctest-modules
 env =
-    NAVIS_HEADLESS=TRUE
\ No newline at end of file
+    NAVIS_HEADLESS=TRUE
+    FAFBSEG_TESTING=TRUE
\ No newline at end of file
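
The gating pattern applied throughout this patch boils down to the sketch below. It is illustrative only and not part of the patch; the helper name and the message are made up, and only the environment-variable check mirrors the added `if verbose and not os.environ.get('FAFBSEG_TESTING', False)` guards.

    import os

    def maybe_print(*args, verbose=True, **kwargs):
        # Skip chatty output whenever the test run has exported FAFBSEG_TESTING,
        # mirroring the guards this patch adds around print() calls.
        if verbose and not os.environ.get('FAFBSEG_TESTING', False):
            print(*args, **kwargs)

    # conftest.py exports the variable before any doctest is collected, so the
    # guarded prints (e.g. the "Caching ..." messages) become no-ops under pytest:
    os.environ['FAFBSEG_TESTING'] = 'TRUE'
    maybe_print('Caching community annotations... Done.')  # prints nothing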