diff --git a/mars/tensor/arithmetic/add.py b/mars/tensor/arithmetic/add.py
index 8c23399dca..327ee1cec2 100644
--- a/mars/tensor/arithmetic/add.py
+++ b/mars/tensor/arithmetic/add.py
@@ -18,6 +18,7 @@
 from functools import reduce
 
 from ... import opcodes as OperandDef
+from ...serialization.serializables import BoolField
 from ..array_utils import device, as_same_device
 from ..datasource import scalar
 from ..utils import infer_dtype
@@ -89,6 +90,8 @@ class TensorTreeAdd(TensorMultiOp):
     _op_type_ = OperandDef.TREE_ADD
     _func_name = "add"
 
+    ignore_empty_input = BoolField("ignore_empty_input", default=False)
+
     @classmethod
     def _is_sparse(cls, *args):
         if args and all(hasattr(x, "issparse") and x.issparse() for x in args):
@@ -96,10 +99,12 @@ def _is_sparse(cls, *args):
         return False
 
     @classmethod
-    def execute(cls, ctx, op):
+    def execute(cls, ctx, op: "TensorTreeAdd"):
         inputs, device_id, xp = as_same_device(
             [ctx[c.key] for c in op.inputs], device=op.device, ret_extra=True
         )
+        if op.ignore_empty_input:
+            inputs = [inp for inp in inputs if not hasattr(inp, "size") or inp.size > 0]
 
         with device(device_id):
             ctx[op.outputs[0].key] = reduce(xp.add, inputs)
diff --git a/mars/tensor/arithmetic/multiply.py b/mars/tensor/arithmetic/multiply.py
index cbc85aa36f..54bc462d9d 100644
--- a/mars/tensor/arithmetic/multiply.py
+++ b/mars/tensor/arithmetic/multiply.py
@@ -18,6 +18,7 @@
 from functools import reduce
 
 from ... import opcodes as OperandDef
+from ...serialization.serializables import BoolField
 from ..array_utils import device, as_same_device
 from ..datasource import scalar
 from ..utils import infer_dtype
@@ -88,6 +89,8 @@ class TensorTreeMultiply(TensorMultiOp):
     _op_type_ = OperandDef.TREE_MULTIPLY
     _func_name = "multiply"
 
+    ignore_empty_input = BoolField("ignore_empty_input", default=False)
+
     def __init__(self, sparse=False, **kw):
         super().__init__(sparse=sparse, **kw)
 
@@ -106,6 +109,8 @@ def execute(cls, ctx, op):
         inputs, device_id, xp = as_same_device(
             [ctx[c.key] for c in op.inputs], device=op.device, ret_extra=True
         )
+        if op.ignore_empty_input:
+            inputs = [inp for inp in inputs if not hasattr(inp, "size") or inp.size > 0]
 
         with device(device_id):
             ctx[op.outputs[0].key] = reduce(xp.multiply, inputs)
diff --git a/mars/tensor/reduction/core.py b/mars/tensor/reduction/core.py
index a3f6824a26..2ecd335dbb 100644
--- a/mars/tensor/reduction/core.py
+++ b/mars/tensor/reduction/core.py
@@ -586,7 +586,11 @@ def tile(cls, op):
                     to_cum_chunks.append(sliced_chunk)
                 to_cum_chunks.append(chunk)
 
-                bin_op = bin_op_type(args=to_cum_chunks, dtype=chunk.dtype)
+                # GH#3132: some chunks of to_cum_chunks may be empty,
+                # so we tell tree_add&tree_multiply to ignore them
+                bin_op = bin_op_type(
+                    args=to_cum_chunks, dtype=chunk.dtype, ignore_empty_input=True
+                )
                 output_chunk = bin_op.new_chunk(
                     to_cum_chunks,
                     shape=chunk.shape,
diff --git a/mars/tensor/reduction/tests/test_reduction_execution.py b/mars/tensor/reduction/tests/test_reduction_execution.py
index 340d5c551f..0059092e69 100644
--- a/mars/tensor/reduction/tests/test_reduction_execution.py
+++ b/mars/tensor/reduction/tests/test_reduction_execution.py
@@ -497,6 +497,16 @@ def test_cum_reduction(setup):
         np.cumsum(np.array(list("abcdefghi"), dtype=object)),
     )
 
+    # test empty chunks
+    raw = np.random.rand(100)
+    arr = tensor(raw, chunk_size=((0, 100),))
+    res = arr.cumsum().execute().fetch()
+    expected = raw.cumsum()
+    np.testing.assert_allclose(res, expected)
+    res = arr.cumprod().execute().fetch()
+    expected = raw.cumprod()
+    np.testing.assert_allclose(res, expected)
+
 
 def test_nan_cum_reduction(setup):
     raw = np.random.randint(5, size=(8, 8, 8)).astype(float)
diff --git a/mars/tensor/reshape/reshape.py b/mars/tensor/reshape/reshape.py
index edee28d5e8..a921fc65f6 100644
--- a/mars/tensor/reshape/reshape.py
+++ b/mars/tensor/reshape/reshape.py
@@ -603,7 +603,9 @@ def reshape(a, newshape, order="C"):
     tensor_order = get_order(order, a.order, available_options="CFA")
 
-    if a.shape == newshape and tensor_order == a.order:
+    if a.shape == newshape and (
+        a.ndim <= 1 or (a.ndim > 1 and tensor_order == a.order)
+    ):
         # does not need to reshape
         return a
 
     return _reshape(
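
A note on the failure mode being patched: `reduce(xp.add, inputs)` fails outright when one of the chunk inputs has zero length, because NumPy cannot broadcast shape (0,) against a non-empty shape. The sketch below illustrates that raw failure and then replays the end-to-end repro mirrored from the new test case; it is a minimal sketch, not part of the patch, and assumes a default local Mars session (`import mars.tensor as mt` is the conventional alias):

    from functools import reduce

    import numpy as np
    import mars.tensor as mt

    # The raw failure that ignore_empty_input now guards against:
    # broadcasting an empty chunk against a non-empty one raises ValueError.
    try:
        reduce(np.add, [np.empty(0), np.arange(3)])
    except ValueError as exc:
        print(exc)  # operands could not be broadcast together ...

    # End-to-end repro of GH#3132: a tensor with an explicitly empty
    # first chunk; its cumulative reductions now match plain NumPy.
    raw = np.random.rand(100)
    t = mt.tensor(raw, chunk_size=((0, 100),))
    np.testing.assert_allclose(t.cumsum().execute().fetch(), raw.cumsum())
    np.testing.assert_allclose(t.cumprod().execute().fetch(), raw.cumprod())

The reshape.py hunk supports the same fix from another angle: it widens the existing same-shape shortcut so that tensors with ndim <= 1 are returned as-is even when a different memory order is requested, which is sound because C and F layouts coincide for zero- and one-dimensional arrays.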