Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Google storage: allow optional use of thread pool #2

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions fafbseg/google/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,16 @@ class GSPointLoader(object):
`Peter Li <https://gist.github.com/chinasaur/5429ef3e0a60aa7a1c38801b0cbfe9bb>`_.
"""

def __init__(self, cloud_volume):
def __init__(self, cloud_volume, use_threads=False):
"""Initialize with zero points.

See add_points to queue some.

Parameters
----------
cloud_volume : cloudvolume.CloudVolume
use_threads : bool, optional
If True, a thread pool is used rather than a process pool.

"""
if not isinstance(cloud_volume, CVtype):
Expand All @@ -54,6 +56,7 @@ def __init__(self, cloud_volume):
self._volume = cloud_volume
self._chunk_map = collections.defaultdict(set)
self._points = None
self._use_threads = use_threads

def add_points(self, points):
"""Add more points to be loaded.
Expand Down Expand Up @@ -127,11 +130,16 @@ def load_all(self, max_workers=4, return_sorted=True, progress=True):
"""
progress_state = self._volume.progress
self._volume.progress = False
if self._use_threads:
Executor = futures.ThreadPoolExecutor
else:
Executor = futures.ProcessPoolExecutor
with tqdm.tqdm(total=len(self._chunk_map),
desc='Segmentation IDs',
leave=False,
disable=not progress) as pbar:
with futures.ProcessPoolExecutor(max_workers=max_workers) as ex:

with Executor(max_workers=max_workers) as ex:
point_futures = [ex.submit(self._load_points, k) for k in self._chunk_map]
for f in futures.as_completed(point_futures):
pbar.update(1)
Expand All @@ -152,7 +160,7 @@ def load_all(self, max_workers=4, return_sorted=True, progress=True):
return points, data


def _get_seg_ids_gs(points, volume, max_workers=4, progress=True):
def _get_seg_ids_gs(points, volume, max_workers=4, progress=True, use_threads=False):
"""Fetch segment IDs using CloudVolume hosted on Google Storage.

This is the default option as it does not require any credentials. Downside:
Expand All @@ -174,7 +182,7 @@ def _get_seg_ids_gs(points, volume, max_workers=4, progress=True):
list : List of segmentation IDs at given locations.

"""
pl = GSPointLoader(volume)
pl = GSPointLoader(volume, use_threads=use_threads)
pl.add_points(points)

points, data = pl.load_all(max_workers=max_workers,
Expand Down Expand Up @@ -335,9 +343,12 @@ def use_google_storage(volume, max_workers=8, progress=True, **kwargs):

volume = cloudvolume.CloudVolume(url, **defaults)

use_threads = kwargs.get('use_threads', False)

_get_seg_ids = lambda x: _get_seg_ids_gs(x, volume,
max_workers=max_workers,
progress=progress)
progress=progress,
use_threads=use_threads)
print('Using Google CloudStorage to retrieve segmentation IDs.')


Expand Down