Skip to content

Commit

Permalink
[BACKPORT] Add preliminary remote function support (#1238) (#1239)
Browse files Browse the repository at this point in the history
  • Loading branch information
Xuye (Chris) Qin authored May 23, 2020
1 parent d3cca24 commit 5fecd78
Show file tree
Hide file tree
Showing 17 changed files with 361 additions and 15 deletions.
2 changes: 1 addition & 1 deletion mars/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import subprocess
import os

version_info = (0, 4, 0, 'rc1')
version_info = (0, 4, 0)
_num_index = max(idx if isinstance(v, int) else 0
for idx, v in enumerate(version_info))
__version__ = '.'.join(map(str, version_info[:_num_index + 1])) + \
Expand Down
23 changes: 23 additions & 0 deletions mars/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ def h(*args, **kwargs):
return func(*args, **kwargs)
return h

def get_current_session(self):
    """
    Get current session.

    This implementation only raises; it acts as a placeholder that a
    concrete context is expected to replace with a real implementation.

    :return: Session
    :raises NotImplementedError: always, in this implementation
    """
    raise NotImplementedError

# ---------------
# Meta relative
# ---------------
Expand Down Expand Up @@ -185,6 +193,13 @@ def copy(self):
new_d.update(self)
return new_d

def get_current_session(self):
    """
    Build a session object that wraps this context's local session.

    :return: the object returned by ``new_session()``, with its ``_sess``
        attribute set to ``self._local_session``
    """
    from .session import new_session

    session = new_session()
    # point the freshly created wrapper at the session held by this context
    session._sess = self._local_session
    return session

def set_ncores(self, ncores):
self._ncores = ncores

Expand Down Expand Up @@ -280,6 +295,14 @@ def running_mode(self):
def session_id(self):
return self._session_id

def get_current_session(self):
    """
    Build a session object bound to this context's scheduler and session id.

    :return: the object returned by ``new_session()``, with its ``_sess``
        attribute set to a ``ClusterSession`` constructed from
        ``self._scheduler_address`` and ``self._session_id``
    """
    from .session import ClusterSession, new_session

    session = new_session()
    # the wrapper is created first, then its backend is swapped for a
    # cluster session that reuses the session id known to this context
    backend = ClusterSession(self._scheduler_address,
                             session_id=self._session_id)
    session._sess = backend
    return session

def get_scheduler_addresses(self):
return self._cluster_info.get_schedulers()

Expand Down
4 changes: 2 additions & 2 deletions mars/dataframe/base/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from ... import opcodes
from ...config import options
from ...serialize import AnyField, BoolField, TupleField, DictField, FunctionField
from ...serialize import AnyField, BoolField, TupleField, DictField
from ..core import DATAFRAME_CHUNK_TYPE, DATAFRAME_TYPE
from ..operands import DataFrameOperandMixin, DataFrameOperand, ObjectType
from ..utils import build_empty_df, build_empty_series, validate_axis, parse_index
Expand All @@ -26,7 +26,7 @@
class TransformOperand(DataFrameOperand, DataFrameOperandMixin):
_op_type_ = opcodes.TRANSFORM

_func = FunctionField('func')
_func = AnyField('func')
_axis = AnyField('axis')
_convert_dtype = BoolField('convert_dtype')
_args = TupleField('args')
Expand Down
4 changes: 2 additions & 2 deletions mars/dataframe/groupby/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pandas as pd

from ... import opcodes
from ...serialize import BoolField, TupleField, DictField, FunctionField
from ...serialize import BoolField, TupleField, DictField, AnyField
from ..operands import DataFrameOperandMixin, DataFrameOperand, ObjectType
from ..utils import build_empty_df, build_empty_series, parse_index

Expand All @@ -25,7 +25,7 @@ class GroupByTransform(DataFrameOperand, DataFrameOperandMixin):
_op_type_ = opcodes.TRANSFORM
_op_module_ = 'dataframe.groupby'

_func = FunctionField('func')
_func = AnyField('func')
_args = TupleField('args')
_kwds = DictField('kwds')

Expand Down
19 changes: 19 additions & 0 deletions mars/deploy/local/tests/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from mars import tensor as mt
from mars import dataframe as md
from mars import remote as mr
from mars.tensor.operands import TensorOperand
from mars.tensor.arithmetic.core import TensorElementWise
from mars.tensor.arithmetic.abs import TensorAbs
Expand Down Expand Up @@ -940,3 +941,21 @@ def testStoreHDF5ForLocalCluster(self):
with h5py.File(filename, 'r') as f:
result = np.asarray(f[dataset])
np.testing.assert_array_equal(result, raw)

def testRemoteFunctionInLocalCluster(self):
    """Run chained ``mr.spawn`` calls through a local cluster session."""
    cluster_options = dict(scheduler_n_process=2, worker_n_process=2,
                           shared_memory='20M', modules=[__name__], web=True)
    with new_cluster(**cluster_options) as cluster:
        sess = cluster.session

        def add_one(x):
            return x + 1

        def multiply(x, y):
            return x * y

        first = mr.spawn(add_one, 3)
        second = mr.spawn(add_one, 4)
        # results of earlier spawns are fed as arguments to a later one
        product = mr.spawn(multiply, (first, second))

        # (3 + 1) * (4 + 1) == 20
        outcome = sess.run(product, timeout=_exec_timeout)
        self.assertEqual(outcome, 20)
3 changes: 2 additions & 1 deletion mars/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,5 +973,6 @@ def register_default(op_cls):
from . import dataframe
from . import optimizes
from . import learn
from . import remote

del tensor, dataframe, optimizes, learn
del tensor, dataframe, optimizes, learn, remote
4 changes: 3 additions & 1 deletion mars/operands.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,9 @@ def _create_chunk(self, output_idx, index, **kw):
return ObjectChunk(data)

def _create_tileable(self, output_idx, **kw):
data = ObjectData(op=self, i=output_idx, **kw)
if 'i' not in kw:
kw['i'] = output_idx
data = ObjectData(op=self, **kw)
return Object(data)

def get_fetch_op_cls(self, obj):
Expand Down
15 changes: 15 additions & 0 deletions mars/remote/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .core import spawn
145 changes: 145 additions & 0 deletions mars/remote/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial

from .. import opcodes
from ..core import Entity, Base
from ..serialize import FunctionField, ListField, DictField
from ..operands import ObjectOperand, ObjectOperandMixin
from ..tensor.core import TENSOR_TYPE
from ..dataframe.core import DATAFRAME_TYPE, SERIES_TYPE, INDEX_TYPE
from .utils import replace_inputs, find_objects


class RemoteFunction(ObjectOperand, ObjectOperandMixin):
    """
    Operand holding a function plus the positional and keyword arguments
    it should be called with.

    Arguments may contain outputs of other operands; those are collected
    as inputs in ``__call__`` and substituted with their data in
    ``execute`` before the function is invoked.
    """
    _op_type_ = opcodes.REMOTE_FUNCATION
    _op_module_ = 'remote'

    _function = FunctionField('function')
    _function_args = ListField('function_args')
    _function_kwargs = DictField('function_kwargs')

    def __init__(self, function=None, function_args=None,
                 function_kwargs=None, **kw):
        super().__init__(_function=function, _function_args=function_args,
                         _function_kwargs=function_kwargs, **kw)

    @property
    def function(self):
        """The wrapped function."""
        return self._function

    @property
    def function_args(self):
        """Positional arguments for the wrapped function."""
        return self._function_args

    @property
    def function_kwargs(self):
        """Keyword arguments for the wrapped function."""
        return self._function_kwargs

    @classmethod
    def _no_prepare(cls, tileable):
        """Tell whether ``tileable`` is a tensor, DataFrame, series or index type."""
        unprepared_types = (TENSOR_TYPE, DATAFRAME_TYPE,
                            SERIES_TYPE, INDEX_TYPE)
        return isinstance(tileable, unprepared_types)

    def _set_inputs(self, inputs):
        """
        Set inputs, then rewrite the stored function arguments so that they
        reference the objects now held in ``self._inputs``.
        """
        # captured before delegating to the base implementation
        previous_inputs = getattr(self, '_inputs', None)
        super()._set_inputs(inputs)

        # inputs produced by other remote functions, consumed in order below
        remote_inputs = (inp for inp in self._inputs
                         if isinstance(inp.op, RemoteFunction))
        replacements = dict(zip(inputs, self._inputs))
        if previous_inputs is not None:
            for old_inp in previous_inputs:
                if self._no_prepare(old_inp):  # pragma: no cover
                    raise NotImplementedError
                replacements[old_inp] = next(remote_inputs)
        self._function_args = replace_inputs(self._function_args, replacements)
        self._function_kwargs = replace_inputs(self._function_kwargs, replacements)

    def __call__(self):
        """
        Collect every ``Entity``/``Base`` object found in the arguments and
        create the output tileable with them as inputs.

        :raises NotImplementedError: if any collected input is a tensor,
            DataFrame, series or index type
        """
        collect = partial(find_objects, types=(Entity, Base))
        inputs = collect(self._function_args) + collect(self._function_kwargs)
        for inp in inputs:
            if self._no_prepare(inp):  # pragma: no cover
                raise NotImplementedError('For now DataFrame, Tensor etc '
                                          'cannot be passed to arguments')
        return self.new_tileable(inputs)

    @classmethod
    def tile(cls, op):
        """Tile ``op`` into a single chunk fed by all chunks of its inputs."""
        out = op.outputs[0]

        chunk_op = op.copy().reset_key()
        chunk_params = out.params
        chunk_params['index'] = ()

        chunk_inputs = []
        prepare_inputs = []
        for inp in op.inputs:
            # if input is tensor, DataFrame etc., do not prepare its data,
            # because the data may be too huge, and users can choose to
            # fetch a slice of the data themselves
            needs_prepare = not cls._no_prepare(inp)
            prepare_inputs.extend([needs_prepare] * len(inp.chunks))
            chunk_inputs.extend(inp.chunks)
        chunk_op._prepare_inputs = prepare_inputs
        chunk = chunk_op.new_chunk(chunk_inputs, kws=[chunk_params])

        tiled_op = op.copy()
        tileable_params = out.params
        tileable_params['chunks'] = [chunk]
        tileable_params['nsplits'] = ()
        return tiled_op.new_tileables(op.inputs, kws=[tileable_params])

    @classmethod
    def execute(cls, ctx, op):
        """
        Call the wrapped function with input objects replaced by their data
        and store the result in ``ctx`` under the output key.

        The session obtained from ``ctx`` is made the default while the
        function runs; the previous default is restored afterwards, even
        when the function raises.
        """
        from ..session import Session

        session = ctx.get_current_session()
        prev_default_session = Session.default

        # only inputs flagged for preparation have their data looked up
        data_by_input = {}
        for inp, prepared in zip(op.inputs, op.prepare_inputs):
            if prepared:
                data_by_input[inp] = ctx[inp.key]

        try:
            # set session created from context as default one
            session.as_default()

            call_args = replace_inputs(op.function_args, data_by_input)
            call_kwargs = replace_inputs(op.function_kwargs, data_by_input)
            ctx[op.outputs[0].key] = op.function(*call_args, **call_kwargs)
        finally:
            # set back default session
            Session._set_default_session(prev_default_session)


def spawn(func, args=(), kwargs=None):
    """
    Wrap ``func`` and its arguments in a ``RemoteFunction`` operand and
    return the result of calling that operand.

    :param func: function to wrap
    :param args: positional arguments; a tuple is unpacked into a list,
        any other value is treated as one single positional argument
    :param kwargs: keyword arguments as a dict, ``None`` means no keyword
        arguments
    :return: whatever calling the ``RemoteFunction`` operand returns
    :raises TypeError: if ``kwargs`` is neither ``None`` nor a dict
    """
    call_args = list(args) if isinstance(args, tuple) else [args]
    call_kwargs = dict() if kwargs is None else kwargs
    if not isinstance(call_kwargs, dict):
        raise TypeError('kwargs has to be a dict')

    remote_op = RemoteFunction(function=func, function_args=call_args,
                               function_kwargs=call_kwargs)
    return remote_op()
13 changes: 13 additions & 0 deletions mars/remote/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
51 changes: 51 additions & 0 deletions mars/remote/tests/test_remote_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from mars.remote import spawn
from mars.tests.core import TestBase, ExecutorForTest


class Test(TestBase):
    """Checks ``spawn`` by executing the resulting tileables with a test executor."""

    def setUp(self) -> None:
        super().setUp()
        self.executor = ExecutorForTest('numpy')
        context, executor = self._create_test_context(self.executor)
        self.ctx = context
        self.executor = executor
        # the context is entered here and left again in tearDown
        self.ctx.__enter__()

    def tearDown(self) -> None:
        self.ctx.__exit__(None, None, None)

    def testRemoteFunction(self):
        def add_one(x):
            return x + 1

        def weighted_product(x, y, z=None):
            return x * y * (z[0] + z[1])

        rng = np.random.RandomState(0)
        left = rng.rand(10, 10)
        right = rng.rand(10, 10)

        spawned_left = spawn(add_one, left)
        spawned_right = spawn(add_one, right)
        # spawned results are passed both positionally and nested in kwargs
        combined = spawn(weighted_product, (spawned_left, spawned_right),
                         {'z': [spawned_left, spawned_right]})

        actual = self.executor.execute_tileables([combined])[0]
        expected = (left + 1) * (right + 1) * (left + 1 + right + 1)
        np.testing.assert_almost_equal(actual, expected)

        # kwargs must be a dict; a tuple is rejected
        with self.assertRaises(TypeError):
            spawn(weighted_product, (spawned_left, spawned_right), kwargs=())
Loading

0 comments on commit 5fecd78

Please sign in to comment.