Commit

doctests: suppress some print statements that are hard to test for
schlegelp committed Jan 11, 2024
1 parent 7dfe926 commit 2e6e6e6
Showing 6 changed files with 89 additions and 60 deletions.
6 changes: 6 additions & 0 deletions conftest.py
@@ -1,3 +1,9 @@
import os

# This is to avoid some print statements in the code that are semi-random and
# hence will make doctests fail. I tried adding this environment variable to
# pytest.ini but that didn't work.
os.environ['FAFBSEG_TESTING'] = 'TRUE'

SKIP = ['conf.py']

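A minimal sketch of the gating pattern this flag enables (added for illustration, not part of the commit): library code reads the FAFBSEG_TESTING variable set above and skips chatty progress prints while the doctest suite runs. The helper name maybe_print is hypothetical. Because os.environ.get returns a string here, the check is on truthiness, so any non-empty value silences the output.

import os

def maybe_print(*args, verbose=True, **kwargs):
    # Print only if verbose output was requested AND we are not inside the
    # doctest suite (conftest.py sets FAFBSEG_TESTING='TRUE' before collection).
    if verbose and not os.environ.get('FAFBSEG_TESTING', False):
        print(*args, **kwargs)

maybe_print('Caching annotations... ', end='', flush=True)  # silent under pytest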
28 changes: 13 additions & 15 deletions fafbseg/flywire/annotations.py
@@ -719,7 +719,7 @@ def get_somas(x=None,
>>> from fafbseg import flywire
>>> somas = flywire.get_somas([720575940628842314])
Using materialization version 630.
>>> somas
>>> somas # doctest: +SKIP
id volume pt_supervoxel_id pt_root_id pt_position rad_est
0 5743218 27.935539 80645535832325071 720575940628842314 [584928, 201568, 22720] 2480.0
@@ -1329,7 +1329,7 @@ def get_hierarchical_annotations(annotation_version=None,
root_col = "root_live"
to_update = ~segmentation.is_latest_root(table.root_live, progress=False)
if any(to_update):
if verbose:
if verbose and not os.environ.get('FAFBSEG_TESTING', False):
print(
"Updating root IDs for hierarchical annotations... ",
end="",
@@ -1344,7 +1344,7 @@ elif materialization:
elif materialization:
root_col = f"root_{materialization}"
if root_col not in table.columns:
if verbose:
if verbose and not os.environ.get('FAFBSEG_TESTING', False):
print(
"Caching root IDs for hierarchical annotations at "
f"materialization '{materialization}'... ",
@@ -1364,7 +1364,7 @@ def get_hierarchical_annotations(annotation_version=None,
# so we don't have to do it again
if save:
table.to_csv(fp, index=False, sep='\t')
if verbose:
if verbose and not os.environ.get('FAFBSEG_TESTING', False):
print('Done.', flush=True)

# Make sure "root_id" corresponds to the requested materialization and drop
@@ -1508,7 +1508,6 @@ def search_community_annotations(x,
>>> an = flywire.search_community_annotations(720575940628857210)
Using materialization version 630.
Caching community annotations for materialization version "630"... Done.
>>> an.iloc[0]
id 46699
pt_position_x 419980
@@ -1524,14 +1523,13 @@ def search_community_annotations(x,
Search for all tags matching a given pattern:
>>> mi1 = flywire.search_community_annotations('Mi1')
Caching community annotations for materialization version "630"... Done.
>>> mi1.head()
id pos_x pos_y pos_z supervoxel_id root_id tag user user_id
0 61866 200890 58482 3724 84445998176507710 720575940626235644 Medulla Intrinsic - Mi1 TR77 2843
1 61867 192776 41653 4105 83881948711947639 720575940623948432 Medulla Intrinsic - Mi1 TR77 2843
2 61869 194704 97128 3905 84026397051849432 720575940632232418 Medulla Intrinsic - Mi1 TR77 2843
3 61871 191574 82521 2728 83814259892666307 720575940630475095 Medulla Intrinsic - Mi1 TR77 2843
4 61877 195454 43026 4031 84022754919681933 720575940617637204 Medulla Intrinsic - Mi1 TR77 2843
>>> mi1.head() # doctest: +SKIP
id pos_x pos_y ... tag user user_id
0 61866 200890 58482 ... Medulla Intrinsic - Mi1 TR77 2843
1 61867 192776 41653 ... Medulla Intrinsic - Mi1 TR77 2843
2 61869 194704 97128 ... Medulla Intrinsic - Mi1 TR77 2843
3 61871 191574 82521 ... Medulla Intrinsic - Mi1 TR77 2843
4 61877 195454 43026 ... Medulla Intrinsic - Mi1 TR77 2843
"""
# See if ``x`` is a root ID as string
@@ -1617,14 +1615,14 @@ def _get_community_annotation_table(dataset, materialization, split_positions=Fa
versions = get_cave_client(dataset=dataset).materialize.get_versions()
materialization = sorted(versions)[-1]

if verbose:
if verbose and not os.environ.get('FAFBSEG_TESTING', False):
print(f'Caching community annotations for materialization version "{materialization}"...',
end='', flush=True)
table = get_cave_table(table_name=COMMUNITY_ANNOTATION_TABLE,
dataset=dataset,
split_positions=split_positions,
materialization=materialization)
if verbose:
if verbose and not os.environ.get('FAFBSEG_TESTING', False):
print(' Done.')
return table

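The same `verbose and not os.environ.get('FAFBSEG_TESTING', False)` guard now appears several times in this file. One possible consolidation, sketched here by the editor and not part of the commit, resolves the flag once at import time; the names IN_TESTING and vprint are hypothetical. Since conftest.py sets the variable before pytest imports any module, an import-time check sees the correct value.

import os

# Resolved once at import time; truthy whenever the doctest suite is running.
IN_TESTING = bool(os.environ.get('FAFBSEG_TESTING', False))

def vprint(*args, verbose=True, **kwargs):
    # Progress messages are suppressed during doctest runs.
    if verbose and not IN_TESTING:
        print(*args, **kwargs)

vprint('Updating root IDs for hierarchical annotations... ', end='', flush=True)

The trade-off is that changing the environment variable after import has no effect, which does not matter for a test-only flag.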
56 changes: 32 additions & 24 deletions fafbseg/flywire/l2.py
@@ -73,10 +73,10 @@ def get_l2_info(root_ids, progress=True, max_threads=4, *, dataset=None):
Examples
--------
>>> from fafbseg import flywire
>>> info = flywire.get_l2_info(720575940614131061) # doctest: +ELLIPSIS
>>> info
root_id l2_chunks chunks_missing area_um2 size_um3 length_um bounds_nm
0 720575940614131061 286 0 2378.16384 163.876526 60.666 [396816.0, 587808.0, 83968.0, 279072.0, 19560....
>>> info = flywire.get_l2_info(720575940614131061)
>>> info # doctest: +SKIP
root_id l2_chunks chunks_missing area_um2 size_um3 length_um ...
0 720575940614131061 286 0 2378.16384 163.876526 60.666 ...
"""
if navis.utils.is_iterable(root_ids):
@@ -85,11 +85,16 @@ def get_l2_info(root_ids, progress=True, max_threads=4, *, dataset=None):
with ThreadPoolExecutor(max_workers=max_threads) as pool:
func = partial(get_l2_info, dataset=dataset)
futures = pool.map(func, root_ids)
info = [f for f in navis.config.tqdm(futures,
desc='Fetching L2 info',
total=len(root_ids),
disable=not progress or len(root_ids) == 1,
leave=False)]
info = [
f
for f in navis.config.tqdm(
futures,
desc="Fetching L2 info",
total=len(root_ids),
disable=not progress or len(root_ids) == 1,
leave=False,
)
]
return pd.concat(info, axis=0).reset_index(drop=True)

# Get/Initialize the CAVE client
@@ -98,48 +103,51 @@ def get_l2_info(root_ids, progress=True, max_threads=4, *, dataset=None):
get_l2_ids = partial(retry(client.chunkedgraph.get_leaves), stop_layer=2)
l2_ids = get_l2_ids(root_ids)

attributes = ['area_nm2', 'size_nm3', 'max_dt_nm', 'rep_coord_nm']
attributes = ["area_nm2", "size_nm3", "max_dt_nm", "rep_coord_nm"]
get_l2data = retry(client.l2cache.get_l2data)
info = get_l2data(l2_ids.tolist(), attributes=attributes)
n_miss = len([v for v in info.values() if not v])

row = [root_ids, len(l2_ids), n_miss]
info_df = pd.DataFrame([row],
columns=['root_id', 'l2_chunks', 'chunks_missing'])
info_df = pd.DataFrame([row], columns=["root_id", "l2_chunks", "chunks_missing"])

# Collect L2 attributes
for at in attributes:
if at in ('rep_coord_nm', ):
if at in ("rep_coord_nm",):
continue

summed = sum([v.get(at, 0) for v in info.values()])
if at.endswith('3'):
if at.endswith("3"):
summed /= 1000**3
elif at.endswith('2'):
elif at.endswith("2"):
summed /= 1000**2
else:
summed /= 1000

info_df[at.replace('_nm', '_um')] = [summed]
info_df[at.replace("_nm", "_um")] = [summed]

# Check bounding box
pts = np.array([v['rep_coord_nm'] for v in info.values() if v])
pts = np.array([v["rep_coord_nm"] for v in info.values() if v])

if len(pts) > 1:
bounds = [v for l in zip(pts.min(axis=0), pts.max(axis=0)) for v in l]
elif len(pts) == 1:
pt = pts[0]
rad = [v['max_dt_nm'] for v in info.values() if v][0] / 2
bounds = [pt[0] - rad, pt[0] + rad,
pt[1] - rad, pt[1] + rad,
pt[2] - rad, pt[2] + rad]
rad = [v["max_dt_nm"] for v in info.values() if v][0] / 2
bounds = [
pt[0] - rad,
pt[0] + rad,
pt[1] - rad,
pt[1] + rad,
pt[2] - rad,
pt[2] + rad,
]
bounds = [int(co) for co in bounds]
else:
bounds = None
info_df['bounds_nm'] = [bounds]
info_df["bounds_nm"] = [bounds]

info_df.rename({'max_dt_um': 'length_um'},
axis=1, inplace=True)
info_df.rename({"max_dt_um": "length_um"}, axis=1, inplace=True)

return info_df

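As a worked check of the unit handling in get_l2_info above (an editor's sketch with illustrative numbers matching the doctest, not real L2-cache data): attribute names ending in "3" are volumes and are divided by 1000**3 (nm^3 to um^3), names ending in "2" are areas divided by 1000**2 (nm^2 to um^2), and everything else is a length divided by 1000 (nm to um).

summed = {"area_nm2": 2.37816384e9, "size_nm3": 1.63876526e11, "max_dt_nm": 60666.0}

converted = {}
for at, value in summed.items():
    if at.endswith("3"):
        value /= 1000**3   # nm^3 -> um^3
    elif at.endswith("2"):
        value /= 1000**2   # nm^2 -> um^2
    else:
        value /= 1000      # nm -> um
    converted[at.replace("_nm", "_um")] = value

print(converted)
# {'area_um2': 2378.16384, 'size_um3': 163.876526, 'max_dt_um': 60.666}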
52 changes: 34 additions & 18 deletions fafbseg/flywire/skeletonize.py
@@ -594,7 +594,7 @@ def get_skeletons(root_id, threads=2, omit_failures=None, max_threads=6,
--------
>>> from fafbseg import flywire
>>> n = flywire.get_skeletons(720575940603231916)
>>> n
>>> n #doctest: +SKIP
type navis.TreeNeuron
name skeleton
id 720575940603231916
@@ -603,25 +603,31 @@ def get_skeletons(root_id, threads=2, omit_failures=None, max_threads=6,
n_branches 586
n_leafs 645
cable_length 2050971.75
soma [141, 458, 460, 462, 464, 466, 467, 469, 470, ...
soma None
units 1 nanometer
dtype: object
"""
if str(dataset) not in SKELETON_BASE_URL:
raise ValueError('Currently we only provide precomputed skeletons for the '
'630 and 783 data releases.')
raise ValueError(
"Currently we only provide precomputed skeletons for the "
"630 and 783 data releases."
)

if omit_failures not in (None, True, False):
raise ValueError('`omit_failures` must be either None, True or False. '
f'Got "{omit_failures}".')
raise ValueError(
"`omit_failures` must be either None, True or False. "
f'Got "{omit_failures}".'
)

if navis.utils.is_iterable(root_id):
root_id = np.asarray(root_id, dtype=np.int64)

il = is_latest_root(root_id, timestamp=f'mat_{dataset}')
il = is_latest_root(root_id, timestamp=f"mat_{dataset}")
if np.any(~il):
msg = (f'{(~il).sum()} root ID(s) did not exist at materialization {dataset}')
msg = (
f"{(~il).sum()} root ID(s) did not exists at materialization {dataset}"
)
if omit_failures is None:
raise ValueError(msg)
navis.config.logger.warning(msg)
@@ -630,17 +636,27 @@ def get_skeletons(root_id, threads=2, omit_failures=None, max_threads=6,
if (max_threads > 1) and (len(root_id) > 1):
with ThreadPoolExecutor(max_workers=max_threads) as pool:
futures = pool.map(get_skels, root_id)
nl = [f for f in navis.config.tqdm(futures,
desc='Fetching skeletons',
total=len(root_id),
disable=not progress or len(root_id) == 1,
leave=False)]
nl = [
f
for f in navis.config.tqdm(
futures,
desc="Fetching skeletons",
total=len(root_id),
disable=not progress or len(root_id) == 1,
leave=False,
)
]
else:
nl = [get_skels(r) for r in navis.config.tqdm(root_id,
desc='Fetching skeletons',
total=len(root_id),
disable=not progress or len(root_id) == 1,
leave=False)]
nl = [
get_skels(r)
for r in navis.config.tqdm(
root_id,
desc="Fetching skeletons",
total=len(root_id),
disable=not progress or len(root_id) == 1,
leave=False,
)
]

# Turn into neuron list
nl = navis.NeuronList(nl)
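For reference, a self-contained sketch of the thread-pool-plus-progress-bar pattern used in get_skeletons above (editor's addition; fetch_one stands in for the real per-root-ID download, and plain tqdm replaces navis.config.tqdm):

from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

def fetch_one(root_id):
    # Placeholder for the real skeleton download.
    return {"root_id": root_id}

root_ids = [720575940603231916, 720575940614131061]

with ThreadPoolExecutor(max_workers=2) as pool:
    futures = pool.map(fetch_one, root_ids)
    skeletons = [
        s
        for s in tqdm(
            futures,
            desc="Fetching skeletons",
            total=len(root_ids),
            disable=len(root_ids) == 1,  # no progress bar for a single neuron
            leave=False,
        )
    ]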
4 changes: 2 additions & 2 deletions fafbseg/flywire/synapses.py
@@ -422,7 +422,7 @@ def get_synapses(
>>> from fafbseg import flywire
>>> syn = flywire.get_synapses(720575940603231916)
Using materialization version 630.
>>> syn.head()
>>> syn.head() #doctest: +SKIP
pre post cleft_score pre_x pre_y pre_z post_x post_y post_z id
0 720575940631406673 720575940603231916 60 434336 218108 28240 434340 218204 28240 3535576
1 720575940608044501 720575940603231916 136 429180 212316 51520 429244 212136 51520 15712693
@@ -436,7 +436,7 @@ def get_synapses(
>>> sk = flywire.get_skeletons(720575940603231916)
>>> _ = flywire.get_synapses(sk, attach=True)
Using materialization version 630.
>>> sk.connectors.head()
>>> sk.connectors.head() #doctest: +SKIP
connector_id x y z cleft_score partner_id type node_id
0 0 356304 146840 145120 145 720575940627592977 pre 217
1 1 344456 164324 162440 153 720575940537249676 pre 5
3 changes: 2 additions & 1 deletion pytest.ini
@@ -2,4 +2,5 @@
doctest_optionflags = IGNORE_EXCEPTION_DETAIL NUMBER NORMALIZE_WHITESPACE
addopts = --doctest-modules
env =
NAVIS_HEADLESS=TRUE
NAVIS_HEADLESS=TRUE
FAFBSEG_TESTING=TRUE
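For context (an editor's sketch, not from the repository): with `addopts = --doctest-modules`, pytest collects examples from docstrings like the one below. An example tagged `# doctest: +SKIP` is still shown in the documentation but never executed, which is why outputs containing volatile IDs or cache messages are marked that way in this commit; the FAFBSEG_TESTING flag then silences the remaining prints emitted by lines whose output is checked. The `env =` section relies on the pytest-env plugin.

def answer():
    """Return a fixed value.

    >>> answer()
    42
    >>> answer()  # doctest: +SKIP
    'this expected output is never compared'
    """
    return 42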
