From 92b245f28ca013c2013d535ad4eb41d5ff6210db Mon Sep 17 00:00:00 2001 From: Wenjun Si Date: Thu, 7 Sep 2023 15:27:51 +0800 Subject: [PATCH] Add support for ray context --- mars/deploy/oscar/ray.py | 1 + mars/deploy/oscar/session.py | 2 + mars/deploy/oscar/tests/test_ray_dag.py | 31 ++++++++++ mars/lib/aio/isolation.py | 8 +++ mars/services/context.py | 4 +- mars/services/task/execution/ray/context.py | 23 +++++-- mars/services/task/execution/ray/executor.py | 64 +++++++++++++++----- mars/services/task/execution/ray/fetcher.py | 14 ++++- mars/session.py | 11 +++- mars/utils.py | 4 ++ 10 files changed, 137 insertions(+), 25 deletions(-) diff --git a/mars/deploy/oscar/ray.py b/mars/deploy/oscar/ray.py index 5f4d4ae650..7d23492062 100644 --- a/mars/deploy/oscar/ray.py +++ b/mars/deploy/oscar/ray.py @@ -389,6 +389,7 @@ def new_ray_session( client = new_cluster_in_ray(backend=backend, **new_cluster_kwargs) session_id = session_id or client.session.session_id address = client.address + logger.warning("CLIENT ADDRESS: %s", address) session = new_session( address=address, session_id=session_id, backend=backend, default=default ) diff --git a/mars/deploy/oscar/session.py b/mars/deploy/oscar/session.py index 218bdce7ad..d68a1d66eb 100644 --- a/mars/deploy/oscar/session.py +++ b/mars/deploy/oscar/session.py @@ -514,7 +514,9 @@ async def fetch(self, *tileables, **kwargs) -> list: chunks, chunk_metas, itertools.chain(*fetch_infos_list) ): await fetcher.append(chunk.key, meta, fetch_info.indexes) + logger.warning("FETCH!! %r", fetcher) fetched_data = await fetcher.get() + logger.warning("FETCH2!!") for fetch_info, data in zip( itertools.chain(*fetch_infos_list), fetched_data ): diff --git a/mars/deploy/oscar/tests/test_ray_dag.py b/mars/deploy/oscar/tests/test_ray_dag.py index c245cca8e6..eb52f8445c 100644 --- a/mars/deploy/oscar/tests/test_ray_dag.py +++ b/mars/deploy/oscar/tests/test_ray_dag.py @@ -13,12 +13,14 @@ # limitations under the License. import copy +import logging import os import time import pytest from .... import get_context +from .... import remote as mr from .... import tensor as mt from ....session import new_session, get_default_async_session from ....tests import test_session @@ -125,6 +127,35 @@ def test_sync_execute(ray_start_regular_shared2, config): test_local.test_sync_execute(config) +@require_ray +@pytest.mark.parametrize("config", [{"backend": "ray"}]) +def test_spawn_execution(ray_start_regular_shared2, config): + session = new_session( + backend=config["backend"], + n_cpu=2, + web=False, + use_uvloop=False, + config={"task.execution_config.ray.monitor_interval_seconds": 0}, + ) + + assert session._session.client.web_address is None + assert session.get_web_endpoint() is None + + def f1(c=0): + if c: + executed = mr.spawn(f1).execute() + logging.warning("EXECUTE DONE!") + executed.fetch() + logging.warning("FETCH DONE!") + return c + + with session: + assert 10 == mr.spawn(f1, 10).execute().fetch() + + session.stop_server() + assert get_default_async_session() is None + + @require_ray @pytest.mark.parametrize( "create_cluster", diff --git a/mars/lib/aio/isolation.py b/mars/lib/aio/isolation.py index 3adb61e0c4..2d2145cc75 100644 --- a/mars/lib/aio/isolation.py +++ b/mars/lib/aio/isolation.py @@ -14,9 +14,12 @@ import asyncio import atexit +import logging import threading from typing import Dict, Optional +logger = logging.getLogger(__name__) + class Isolation: loop: asyncio.AbstractEventLoop @@ -31,6 +34,9 @@ def __init__(self, loop: asyncio.AbstractEventLoop, threaded: bool = True): self._thread = None self._thread_ident = None + def __repr__(self): + return f"" + def _run(self): asyncio.set_event_loop(self.loop) self._stopped = asyncio.Event() @@ -72,9 +78,11 @@ def new_isolation( if loop is None: loop = asyncio.new_event_loop() + logger.warning("NEW_LOOP %d", id(loop)) isolation = Isolation(loop, threaded=threaded) isolation.start() + logger.warning("NEW_ISOLATION! loop: %r", loop) _name_to_isolation[name] = isolation return isolation diff --git a/mars/services/context.py b/mars/services/context.py index 596c9f0d13..21d4a233ee 100644 --- a/mars/services/context.py +++ b/mars/services/context.py @@ -47,6 +47,7 @@ def __init__( local_address: str, loop: asyncio.AbstractEventLoop, band: BandType = None, + isolation_threaded: bool = False, ): super().__init__( session_id=session_id, @@ -59,7 +60,8 @@ def __init__( # new isolation with current loop, # so that session created in tile and execute # can get the right isolation - new_isolation(loop=self._loop, threaded=False) + logger.warning("NEW_ISOLATION in ThreadedServiceContext.__init__") + new_isolation(loop=self._loop, threaded=isolation_threaded) self._running_session_id = None self._running_op_key = None diff --git a/mars/services/task/execution/ray/context.py b/mars/services/task/execution/ray/context.py index 80205b0baf..92beea82ea 100644 --- a/mars/services/task/execution/ray/context.py +++ b/mars/services/task/execution/ray/context.py @@ -17,8 +17,9 @@ from typing import Dict, List, Callable from .....core.context import Context +from .....session import ensure_isolation_created from .....storage.base import StorageLevel -from .....typing import ChunkType +from .....typing import ChunkType, SessionType from .....utils import implements, lazy_import, sync_to_async from ....context import ThreadedServiceContext from .config import RayExecutionConfig @@ -187,13 +188,27 @@ def get_worker_addresses(self) -> List[str]: # TODO(fyrestone): Implement more APIs for Ray. -class RayExecutionWorkerContext(_RayRemoteObjectContext, dict): +class RayExecutionWorkerContext(_RayRemoteObjectContext, ThreadedServiceContext, dict): """The context for executing operands.""" - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__( + self, + get_or_create_actor: Callable[[], "ray.actor.ActorHandle"], + *args, + **kwargs, + ): + _RayRemoteObjectContext.__init__(self, get_or_create_actor, *args, loop=None, isolation_threaded=True, **kwargs) + dict.__init__(self) self._current_chunk = None + @implements(Context.get_current_session) + def get_current_session(self) -> SessionType: + from .....session import new_session + + return new_session( + self.supervisor_address, self.session_id, backend="ray", new=False, default=False + ) + @classmethod @implements(Context.new_custom_log_dir) def new_custom_log_dir(cls): diff --git a/mars/services/task/execution/ray/executor.py b/mars/services/task/execution/ray/executor.py index 432180768e..1de8bb794a 100644 --- a/mars/services/task/execution/ray/executor.py +++ b/mars/services/task/execution/ray/executor.py @@ -19,6 +19,7 @@ import itertools import logging import operator +import os import time from dataclasses import dataclass, field from typing import List, Dict, Any, Callable @@ -26,7 +27,7 @@ import numpy as np from .....core import ChunkGraph, Chunk, TileContext -from .....core.context import set_context +from .....core.context import set_context, get_context from .....core.operand import ( Fetch, Fuse, @@ -38,6 +39,7 @@ from .....metrics.api import init_metrics, Metrics from .....resource import Resource from .....serialization import serialize, deserialize +from .....session import AbstractSession, get_default_session from .....typing import BandType from .....utils import ( aiotask_wrapper, @@ -149,10 +151,12 @@ def gc_inputs(self, chunk: Chunk): def execute_subtask( + session_id: str, subtask_id: str, subtask_chunk_graph: ChunkGraph, output_meta_n_keys: int, is_mapper, + address: str, *inputs, ): """ @@ -176,6 +180,9 @@ def execute_subtask( ------- subtask outputs and meta for outputs if `output_meta_keys` is provided. """ + logging.basicConfig(level=logging.INFO) + logger.setLevel(logging.INFO) + init_metrics("ray") started_subtask_number.record(1) ray_task_id = ray.get_runtime_context().get_task_id() @@ -184,7 +191,16 @@ def execute_subtask( # Optimize chunk graph. subtask_chunk_graph = _optimize_subtask_graph(subtask_chunk_graph) fetch_chunks, shuffle_fetch_chunk = _get_fetch_chunks(subtask_chunk_graph) - context = RayExecutionWorkerContext(RayTaskState.get_handle) + + context = RayExecutionWorkerContext( + RayTaskState.get_handle, + session_id, + address, + address, + address, + ) + set_context(context) + if shuffle_fetch_chunk is not None: # The subtask is a reducer subtask. n_mappers = shuffle_fetch_chunk.op.n_mappers @@ -209,19 +225,28 @@ def execute_subtask( # Update non shuffle inputs to context. context.update(zip((start_chunk.key for start_chunk in fetch_chunks), inputs)) - for chunk in subtask_chunk_graph.topological_iter(): - if chunk.key not in context: - try: - context.set_current_chunk(chunk) - execute(context, chunk.op) - except Exception: - logger.exception( - "Execute operand %s of graph %s failed.", - chunk.op, - subtask_chunk_graph.to_dot(), - ) - raise - subtask_gc.gc_inputs(chunk) + default_session = get_default_session() + try: + context.get_current_session().as_default() + + for chunk in subtask_chunk_graph.topological_iter(): + if chunk.key not in context: + try: + context.set_current_chunk(chunk) + execute(context, chunk.op) + except Exception: + logger.exception( + "Execute operand %s of graph %s failed.", + chunk.op, + subtask_chunk_graph.to_dot(), + ) + raise + subtask_gc.gc_inputs(chunk) + finally: + if default_session is not None: + default_session.as_default() + else: + AbstractSession.reset_default() # For non-mapper subtask, output context is chunk key to results. # For mapper subtasks, output context is data key to results. @@ -455,6 +480,7 @@ def __init__( task_chunks_meta: Dict[str, _RayChunkMeta], lifecycle_api: LifecycleAPI, meta_api: MetaAPI, + address: str, ): logger.info( "Start task %s with GC method %s.", @@ -475,6 +501,8 @@ def __init__( self._available_band_resources = None self._result_tileables_lifecycle = None + self._address = address + # For progress and task cancel self._stage_index = 0 self._pre_all_stages_progress = 0.0 @@ -507,6 +535,7 @@ async def create( task_chunks_meta, lifecycle_api, meta_api, + address, ) available_band_resources = await executor.get_available_band_resources() worker_addresses = list( @@ -710,10 +739,12 @@ async def _execute_subtask_graph( memory=subtask_memory, scheduling_strategy="DEFAULT" if len(input_object_refs) else "SPREAD", ).remote( + subtask.session_id, subtask.subtask_id, serialize(subtask_chunk_graph, context={"serializer": "ray"}), subtask.stage_n_outputs, is_mapper, + self._address, *input_object_refs, ) await asyncio.sleep(0) @@ -739,6 +770,7 @@ async def _execute_subtask_graph( task_context[chunk_key] = object_ref logger.info("Submitted %s subtasks of stage %s.", len(subtask_graph), stage_id) + logger.warning("SUBTASK_RUN_1") monitor_context.stage = _RayExecutionStage.WAITING key_to_meta = {} if len(output_meta_object_refs) > 0: @@ -752,6 +784,7 @@ async def _execute_subtask_graph( self._task_chunks_meta[key] = _RayChunkMeta(memory_size=memory_size) logger.info("Got %s metas of stage %s.", meta_count, stage_id) + logger.warning("SUBTASK_RUN_2") chunk_to_meta = {} # ray.wait requires the object ref list is unique. output_object_refs = set() @@ -773,6 +806,7 @@ async def _execute_subtask_graph( await asyncio.to_thread(ray.wait, list(output_object_refs), fetch_local=False) logger.info("Complete stage %s.", stage_id) + logger.warning("%d: SUBTASK_RUN_3: %r", os.getpid(), output_object_refs) return chunk_to_meta async def __aexit__(self, exc_type, exc_val, exc_tb): diff --git a/mars/services/task/execution/ray/fetcher.py b/mars/services/task/execution/ray/fetcher.py index f7efdf625d..addff69303 100644 --- a/mars/services/task/execution/ray/fetcher.py +++ b/mars/services/task/execution/ray/fetcher.py @@ -14,12 +14,15 @@ import asyncio import functools +import logging from collections import namedtuple from typing import Dict, List from .....utils import lazy_import from ..api import Fetcher, register_fetcher_cls +logger = logging.getLogger(__name__) + ray = lazy_import("ray") _FetchInfo = namedtuple("FetchInfo", ["key", "object_ref", "conditions"]) @@ -36,9 +39,10 @@ class RayFetcher(Fetcher): name = "ray" required_meta_keys = ("object_refs",) - def __init__(self, **kwargs): + def __init__(self, loop=None, **kwargs): self._fetch_info_list = [] self._no_conditions = True + self._loop = loop @staticmethod @functools.lru_cache(maxsize=None) # Specify maxsize=None to make it faster @@ -55,9 +59,12 @@ async def append(self, chunk_key: str, chunk_meta: Dict, conditions: List = None async def get(self): if self._no_conditions: + logger.warning(f"FETCHER_0 {self._fetch_info_list}") return await asyncio.gather( - *(info.object_ref for info in self._fetch_info_list) + *(info.object_ref for info in self._fetch_info_list), + loop=self._loop, ) + logger.warning("FETCHER_1") refs = [None] * len(self._fetch_info_list) for index, fetch_info in enumerate(self._fetch_info_list): if fetch_info.conditions is None: @@ -66,4 +73,5 @@ async def get(self): refs[index] = self._remote_query_object_with_condition().remote( fetch_info.object_ref, tuple(fetch_info.conditions) ) - return await asyncio.gather(*refs) + logger.warning("FETCHER_2") + return await asyncio.gather(*refs, loop=self._loop) diff --git a/mars/session.py b/mars/session.py index 5335dea476..881e78d4f7 100644 --- a/mars/session.py +++ b/mars/session.py @@ -177,7 +177,7 @@ def reset_default(cls): AbstractSession._default = None @classproperty - def default(self): + def default(self) -> Optional["AbstractSession"]: return AbstractSession._default @@ -912,12 +912,19 @@ def init( **kwargs, ) -> "AbstractSession": isolation = ensure_isolation_created(kwargs) + logger.warning( + "ISOLATION INFO: %r, cur_thread_ident: %s", + isolation, + threading.current_thread().ident, + ) coro = _get_isolated_session_cls(address).init( address, session_id, backend, new=new, **kwargs ) + logger.warning("CORO INFO: %r, address: %s, kw: %r, loop %d%r", coro, address, kwargs, id(isolation.loop), isolation.loop) fut = asyncio.run_coroutine_threadsafe(coro, isolation.loop) isolated_session = fut.result() - return SyncSession(address, session_id, isolated_session, isolation) + session = SyncSession(address, session_id, isolated_session, isolation) + return session def as_default(self) -> AbstractSession: AbstractSession._default = self._isolated_session diff --git a/mars/utils.py b/mars/utils.py index abcc24c38f..c155866bfc 100644 --- a/mars/utils.py +++ b/mars/utils.py @@ -1217,12 +1217,16 @@ def wrapped(cls, ctx, op): if _enter_counter == 0: # to handle nested call, only set initial session # in first call + logger.warning("Context: %s", type(ctx)) session = ctx.get_current_session() _initial_session = get_default_session() session.as_default() + logger.warning("session %r loop: %d%r", session, id(session._loop), session._loop) + logger.warning("default_session loop 00: %d%r", id(get_default_session()._loop), get_default_session()._loop) _enter_counter += 1 try: + logger.warning("default_session loop: %d%r", id(get_default_session()._loop), get_default_session()._loop) result = func(cls, ctx, op) finally: with AbstractSession._lock: