Skip to content

Commit

Permalink
[BACKPORT] Refactor tensor indexing (#1011) (#1012)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xuye (Chris) Qin authored Feb 22, 2020
1 parent 928a28e commit 30f33a2
Show file tree
Hide file tree
Showing 7 changed files with 953 additions and 453 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ jobs:
run: |
source ./.github/workflows/reload-env.sh
export DEFAULT_VENV=$VIRTUAL_ENV
if [[ "$PYTHON" =~ "2.7" ]]; then
pip install enum34\<1.1.8
fi
if [[ "$PYTHON" =~ "3.8" ]]; then
conda install -n test --quiet --yes python=$PYTHON numpy pyarrow
# remove three lines below once gevent 1.5 is released
Expand Down
2 changes: 1 addition & 1 deletion mars/tensor/base/argsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def argsort(a, axis=-1, kind=None, parallel_kind=None, psrs_kinds=None, order=No
a, axis, kind, parallel_kind, psrs_kinds, order = _validate_sort_arguments(
a, axis, kind, parallel_kind, psrs_kinds, order)

op = TensorSort(axis=axis ,kind=kind, parallel_kind=parallel_kind,
op = TensorSort(axis=axis, kind=kind, parallel_kind=parallel_kind,
order=order, psrs_kinds=psrs_kinds,
return_value=False, return_indices=True,
dtype=a.dtype, gpu=a.op.gpu)
Expand Down
1 change: 0 additions & 1 deletion mars/tensor/indexing/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def preprocess_index(index):
raise IndexError(_INDEX_ERROR_MSG)
if ind.dtype.kind == 'b':
# bool indexing
ind = astensor(ind)
has_bool_index = True
else:
# fancy indexing
Expand Down
417 changes: 14 additions & 403 deletions mars/tensor/indexing/getitem.py

Large diffs are not rendered by default.

Loading

0 comments on commit 30f33a2

Please sign in to comment.