diff --git a/fafbseg/flywire/segmentation.py b/fafbseg/flywire/segmentation.py index ca0a8b7..9df4785 100644 --- a/fafbseg/flywire/segmentation.py +++ b/fafbseg/flywire/segmentation.py @@ -1559,26 +1559,27 @@ def is_valid_root(x, raise_exc=False, *, dataset=None): """ client = get_cave_client(dataset=dataset) - vol = cv.CloudVolume(client.chunkedgraph.cloudvolume_path, use_https=True, progress=False) + vol = get_cloudvolume(client.chunkedgraph.cloudvolume_path) + + def _is_valid(x, raise_exc): + try: + is_valid = vol.get_chunk_layer(x) == vol.info["graph"]["n_layers"] + except ValueError: + is_valid = False + + if raise_exc and not is_valid: + raise ValueError(f"{x} is not a valid root ID") + + return is_valid if navis.utils.is_iterable(x): - is_valid = np.array([is_valid_root(r, dataset=dataset) for r in x]) + is_valid = np.array([_is_valid(r, raise_exc=False) for r in x]) if raise_exc and not all(is_valid): invalid = set(np.asarray(x)[~is_valid].tolist()) raise ValueError(f"Invalid root IDs found: {invalid}") return is_valid - - try: - # Note: FlyWire has 10 layers but FANC has 9 - # by using the volume's info we stay flexible - is_valid = vol.get_chunk_layer(x) == vol.info['graph']['n_layers'] - except ValueError: - is_valid = False - - if raise_exc and not is_valid: - raise ValueError(f"{x} is not a valid root ID") - - return is_valid + else: + return _is_valid(x, raise_exc=raise_exc) @inject_dataset(disallowed=["flat_630", "flat_571"]) @@ -1610,24 +1611,28 @@ def is_valid_supervoxel(x, raise_exc=False, *, dataset=None): Use this function to check if a root ID is valid. """ - vol = get_cloudvolume(dataset) + client = get_cave_client(dataset=dataset) + vol = get_cloudvolume(client.chunkedgraph.cloudvolume_path) + + def _is_valid(x, raise_exc): + try: + is_valid = vol.get_chunk_layer(x) == 1 + except ValueError: + is_valid = False + + if raise_exc and not is_valid: + raise ValueError(f"{x} is not a valid supervoxel ID") + + return is_valid if navis.utils.is_iterable(x): - is_valid = np.array([is_valid_supervoxel(r, dataset=vol) for r in x]) + is_valid = np.array([_is_valid(r, raise_exc=False) for r in x]) if raise_exc and not all(is_valid): invalid = set(np.asarray(x)[~is_valid].tolist()) raise ValueError(f"Invalid supervoxel IDs found: {invalid}") return is_valid - - try: - is_valid = vol.get_chunk_layer(x) == 1 - except BaseException: - is_valid = False - - if raise_exc and not is_valid: - raise ValueError(f"{x} is not a valid supervoxel ID") - - return is_valid + else: + return _is_valid(x, raise_exc=raise_exc) @inject_dataset(disallowed=["flat_630", "flat_571"])