Commit

refactor tensor indexing, modify context either (#1011)
Xuye (Chris) Qin authored Feb 22, 2020
1 parent 54f6631 commit 900cbea
Showing 7 changed files with 1,040 additions and 484 deletions.
47 changes: 18 additions & 29 deletions mars/context.py
@@ -308,13 +308,12 @@ def get_chunks_data(self, worker: str, chunk_keys: List[str], indexes: List=None
compression_types=compression_types)

# Fetch tileable data by tileable keys and indexes.
def get_tileable_data(self, tileable_key: str, indexes: List=None,
compression_types: List[str]=None):
def get_tileable_data(self, tileable_key: str, indexes: List = None,
compression_types: List[str] = None):
from .serialize import dataserializer
from .utils import merge_chunks
from .tensor.core import TENSOR_TYPE
from .tensor.datasource import empty
from .tensor.indexing.getitem import TensorIndexTilesHandler, calc_pos
from .tensor.indexing.index_lib import NDArrayIndexesHandler

nsplits, chunk_keys, chunk_indexes = self.get_tileable_metas([tileable_key])[0]
chunk_idx_to_keys = dict(zip(chunk_indexes, chunk_keys))
@@ -326,37 +325,33 @@ def get_tileable_data(self, tileable_key: str, indexes: List=None,
[chunk_workers[e].append(chunk_key) for chunk_key, e in chunk_keys_to_worker.items()]

chunk_results = dict()
select_pos = None

if not indexes:
if indexes is None or len(indexes) == 0:
datas = []
for endpoint, chunks in chunk_workers.items():
datas.append(self.get_chunks_data(endpoint, chunks, compression_types=compression_types))
datas = [d.result() for d in datas]
for (endpoint, chunks), d in zip(chunk_workers.items(), datas):
d = [dataserializer.loads(db) for db in d]
chunk_results.update(dict(zip([chunk_keys_to_idx[k] for k in chunks], d)))

chunk_results = [(k, v) for k, v in chunk_results.items()]
if len(chunk_results) == 1:
return chunk_results[0][1]
else:
return merge_chunks(chunk_results)
else:
# TODO: make a common util to handle indexes
if any(isinstance(ind, TENSOR_TYPE) for ind in indexes):
raise TypeError("Doesn't support indexing by tensors")
# Reuse the getitem logic to get each chunk's indexes
tileable_shape = tuple(sum(s) for s in nsplits)
empty_tileable = empty(tileable_shape, chunk_size=nsplits)._inplace_tile()
indexed = empty_tileable[tuple(indexes)]
index_handler = TensorIndexTilesHandler(indexed.op)
index_handler._extract_indexes_info()
index_handler._preprocess_fancy_indexes()
index_handler._process_fancy_indexes()
index_handler._process_in_tensor()

# Select by order
if len(index_handler._fancy_index_infos) != 0:
index_shape = index_handler._fancy_index_info.chunk_unified_fancy_indexes[0].shape
select_pos = calc_pos(index_shape, index_handler._fancy_index_info.chunk_index_to_pos)
indexes_handler = NDArrayIndexesHandler()
try:
context = indexes_handler.handle(indexed.op, return_context=True)
except TypeError:
raise TypeError("Doesn't support indexing by tensors")

result_chunks = dict()
for c in index_handler._out_chunks:
for c in context.processed_chunks:
result_chunks[chunk_idx_to_keys[c.inputs[0].index]] = [c.index, c.op.indexes]

chunk_datas = dict()
@@ -378,14 +373,8 @@ def get_tileable_data(self, tileable_key: str, indexes: List=None,
d = [dataserializer.loads(db) for db in d]
chunk_results.update(dict(zip(idx, d)))

chunk_results = [(k, v) for k, v in chunk_results.items()]
if len(chunk_results) == 1:
ret = chunk_results[0][1]
else:
ret = merge_chunks(chunk_results)
if select_pos is not None:
ret = ret[select_pos]
return ret
chunk_results = [(k, v) for k, v in chunk_results.items()]
return indexes_handler.aggregate_result(context, chunk_results)

def create_lock(self):
return self._actor_ctx.lock()
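The rewrite above swaps the hand-rolled TensorIndexTilesHandler / calc_pos path for a single NDArrayIndexesHandler, which both produces the per-chunk index operations (context.processed_chunks) and reassembles the fetched pieces (indexes_handler.aggregate_result). As a rough mental model only, not Mars's implementation, here is a NumPy sketch of that pattern: map one global selection onto per-chunk selections, then concatenate the chunk results. The helper name split_slice_over_chunks is hypothetical.

```python
# Illustrative only: a NumPy sketch of the pattern the refactor centralizes:
# turn one global selection into per-chunk selections, fetch each chunk's
# piece, then stitch the pieces back together.  The helper below is
# hypothetical and not part of Mars.
import numpy as np


def split_slice_over_chunks(global_slice, nsplit):
    """Yield (chunk_index, local_slice) pairs for a 1-D slice with positive step."""
    start, stop, step = global_slice.indices(sum(nsplit))
    offset = 0
    for chunk_index, size in enumerate(nsplit):
        lo = max(start, offset)
        hi = min(stop, offset + size)
        if lo < hi:
            # keep the local start on the stride grid of the global slice
            if (lo - start) % step:
                lo += step - (lo - start) % step
            if lo < hi:
                yield chunk_index, slice(lo - offset, hi - offset, step)
        offset += size


if __name__ == '__main__':
    data = np.arange(20)
    nsplit = (6, 6, 8)                                # chunk sizes along the axis
    chunks = np.split(data, np.cumsum(nsplit)[:-1])   # pretend these live on workers

    selection = slice(3, 17, 2)
    pieces = [chunks[i][local]
              for i, local in split_slice_over_chunks(selection, nsplit)]
    result = np.concatenate(pieces)                   # the "aggregate result" step

    assert np.array_equal(result, data[selection])
    print(result)                                     # [ 3  5  7  9 11 13 15]
```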
2 changes: 1 addition & 1 deletion mars/tensor/base/argsort.py
@@ -121,7 +121,7 @@ def argsort(a, axis=-1, kind=None, parallel_kind=None, psrs_kinds=None, order=No
a, axis, kind, parallel_kind, psrs_kinds, order = _validate_sort_arguments(
a, axis, kind, parallel_kind, psrs_kinds, order)

op = TensorSort(axis=axis ,kind=kind, parallel_kind=parallel_kind,
op = TensorSort(axis=axis, kind=kind, parallel_kind=parallel_kind,
order=order, psrs_kinds=psrs_kinds,
return_value=False, return_indices=True,
dtype=a.dtype, gpu=a.op.gpu)
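The argsort change itself only fixes the misplaced comma in the TensorSort construction; as the diff shows, the op is built with return_value=False and return_indices=True, so only the index tensor is produced. A minimal usage sketch, assuming a local Mars session of this era in which execute() returns the computed ndarray:

```python
# Hedged usage sketch; assumes a local default session where execute()
# returns the computed NumPy array.
import mars.tensor as mt

a = mt.tensor([3, 1, 2], chunk_size=2)
idx = mt.argsort(a)          # builds a TensorSort op that only returns indices
print(idx.execute())         # expected: [1 2 0]
```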
1 change: 0 additions & 1 deletion mars/tensor/indexing/core.py
@@ -96,7 +96,6 @@ def preprocess_index(index):
raise IndexError(_INDEX_ERROR_MSG)
if ind.dtype.kind == 'b':
# bool indexing
ind = astensor(ind)
has_bool_index = True
else:
# fancy indexing
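preprocess_index classifies each index as boolean or fancy; after this commit a raw boolean ndarray is no longer eagerly wrapped with astensor() here (presumably the refactored index handling converts it where needed). A hedged sketch of the two index kinds this branch distinguishes, again assuming a local Mars session:

```python
# Hedged sketch of boolean vs. fancy indexing on a Mars tensor; assumes a
# local default session where execute() returns the computed NumPy array.
import numpy as np
import mars.tensor as mt

t = mt.arange(6, chunk_size=3)
bool_mask = np.array([True, False, True, False, True, False])  # dtype.kind == 'b'
fancy_idx = np.array([5, 1, 3])                                # integer fancy index

print(t[bool_mask].execute())   # boolean indexing -> [0 2 4]
print(t[fancy_idx].execute())   # fancy indexing   -> [5 1 3]
```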
(diffs for the remaining 4 changed files not shown)
