diff --git a/torchft/multiprocessing.py b/torchft/multiprocessing.py
index 6e038a9..71b6aa3 100644
--- a/torchft/multiprocessing.py
+++ b/torchft/multiprocessing.py
@@ -1,95 +1,32 @@
 import queue
 import time
 from datetime import timedelta
+from multiprocessing.connection import Connection
 from typing import Union
 
 import torch.multiprocessing as mp
 
 
-class _MonitoredQueue:
-    def __init__(
-        self,
-        p: mp.Process,
-        q: mp.Queue,
-        poll_interval: timedelta = timedelta(seconds=1),
-    ) -> None:
-        """
-        Args:
-            p: process to monitor
-            q: queue to monitor
-            poll_interval: interval to poll the Process health when calling get/put
-        """
-        self._p = p
-        self._q = q
-        self._poll_interval_s: float = poll_interval.total_seconds()
+class _MonitoredPipe:
+    def __init__(self, pipe: "Connection[object, object]") -> None:
+        self._pipe = pipe
 
-    def get(self, timeout: Union[float, timedelta]) -> object:
-        """
-        Get an item from the queue. If the process is not alive, raise RuntimeError.
-        If the queue is empty, wait for up to timeout seconds for an item to be
-        available. If no item is available after timeout seconds, raise TimeoutError.
-
-        Args:
-            timeout: timeout in seconds
-        """
+    def send(self, obj: object) -> None:
+        self._pipe.send(obj)
 
+    def recv(self, timeout: Union[float, timedelta]) -> object:
         if isinstance(timeout, timedelta):
             timeout = timeout.total_seconds()
-
-        start = time.perf_counter()
-        while True:
-            try:
-                v = self._q.get(timeout=self._poll_interval_s)
-                break
-            except queue.Empty:
-                pass
-
-            elapsed = time.perf_counter() - start
-            if elapsed > timeout:
-                raise TimeoutError(f"queue.get() timed out after {timeout} seconds")
-
-            # polling the process can be slow so we only do it every poll_interval
-            if not self._p.is_alive():
-                raise RuntimeError(f"process is not alive {self._p.exitcode}")
-
-        if isinstance(v, Exception):
-            raise v
-        return v
-
-    def put(self, obj: object, timeout: Union[float, timedelta]) -> None:
-        """
-        Put an item into the queue. If the process is not alive, raise RuntimeError.
-        If the queue is full, wait for up to timeout seconds for space to be
-        available. If queue is full after timeout seconds, raise TimeoutError.
-
-        If an exception is put into the queue, it will be raised when calling get().
-
-        Args:
-            obj: object to put into the queue
-            timeout: timeout in seconds
-        """
-        if isinstance(timeout, timedelta):
-            timeout = timeout.total_seconds()
-
-        start = time.perf_counter()
-        while True:
-            try:
-                self._q.put(obj, timeout=self._poll_interval_s)
-                break
-            except queue.Full:
-                pass
-
-            elapsed = time.perf_counter() - start
-            if elapsed > timeout:
-                raise TimeoutError(f"queue.put() timed out after {timeout} seconds")
-
-            # polling the process can be slow so we only do it every poll_interval
-            if not self._p.is_alive():
-                raise RuntimeError(f"process is not alive {self._p.exitcode}")
+        if self._pipe.poll(timeout):
+            out = self._pipe.recv()
+            if isinstance(out, Exception):
+                raise out
+            return out
+        else:
+            raise TimeoutError(f"pipe.recv() timed out after {timeout} seconds")
 
     def close(self) -> None:
-        self._q.close()
+        self._pipe.close()
 
     def closed(self) -> bool:
-        # pyre-ignore[16]: no attribute _closed
-        return self._q._closed
+        return self._pipe.closed
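The poll-then-recv pattern above is what lets a dead worker surface as an exception instead of a hang: once the child exits, its end of the pipe closes, `poll()` reports readable, and `recv()` raises `EOFError`. A minimal usage sketch (not part of the patch; `child` and the `__main__` guard are illustrative):

```python
from multiprocessing.connection import Connection

import torch.multiprocessing as mp

from torchft.multiprocessing import _MonitoredPipe


def child(pipe: Connection) -> None:
    pipe.send("ready")
    # the child's end of the pipe closes implicitly when the process exits


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    local, remote = ctx.Pipe()
    p = ctx.Process(target=child, args=(remote,), daemon=True)
    p.start()
    del remote  # drop the parent's copy so EOF propagates once the child exits

    pipe = _MonitoredPipe(local)
    assert pipe.recv(timeout=10.0) == "ready"
    try:
        pipe.recv(timeout=10.0)  # child exited: poll() sees EOF, recv() raises
    except EOFError:
        pass
```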
diff --git a/torchft/multiprocessing_test.py b/torchft/multiprocessing_test.py
index 47459e2..61655f9 100644
--- a/torchft/multiprocessing_test.py
+++ b/torchft/multiprocessing_test.py
@@ -1,48 +1,55 @@
+from multiprocessing.connection import Connection
 from unittest import TestCase
 
 import torch.multiprocessing as mp
 
-from torchft.multiprocessing import _MonitoredQueue
+from torchft.multiprocessing import _MonitoredPipe
 
 
-def queue_get(q: mp.Queue) -> None:
-    q.get()
+def pipe_get(q: "Connection[object, object]") -> None:
    q.recv()
 
 
-def queue_put(q: mp.Queue) -> None:
-    q.put(1)
+def pipe_put(q: "Connection[object, object]") -> None:
+    q.recv()
+    q.send(1)
 
 
 class MultiprocessingTest(TestCase):
     def test_monitored_queue_put(self) -> None:
         ctx = mp.get_context("fork")
-        q = ctx.Queue(maxsize=1)
-        p = ctx.Process(target=queue_get, args=(q,), daemon=True)
+        local, remote = ctx.Pipe()
+        p = ctx.Process(target=pipe_get, args=(remote,), daemon=True)
         p.start()
+        del remote
 
-        mq = _MonitoredQueue(p, q)
-        mq.put(1, timeout=10)
-        mq.put(1, timeout=10)
-        with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
-            mq.put(1, timeout=10)
-
-        with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
-            mq.put(1, timeout=0.0)
+        mq = _MonitoredPipe(local)
+        mq.send(1)
+        with self.assertRaisesRegex(ConnectionResetError, "Connection reset by peer"):
+            while True:
+                mq.send(1)
 
         mq.close()
+        assert mq.closed()
 
     def test_monitored_queue_get(self) -> None:
         ctx = mp.get_context("fork")
-        q = ctx.Queue(maxsize=1)
-        p = ctx.Process(target=queue_put, args=(q,), daemon=True)
+        local, remote = ctx.Pipe()
+        p = ctx.Process(target=pipe_put, args=(remote,), daemon=True)
         p.start()
+        del remote
 
-        mq = _MonitoredQueue(p, q)
-        self.assertEqual(mq.get(timeout=10), 1)
-        with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
-            mq.get(timeout=10)
+        mq = _MonitoredPipe(local)
 
         with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
-            mq.get(timeout=0.0)
+            mq.recv(timeout=0.0)
+
+        # continue
+        mq.send(1)
+
+        self.assertEqual(mq.recv(timeout=10), 1)
+        with self.assertRaises(EOFError):
+            mq.recv(timeout=10)
 
         mq.close()
+        assert mq.closed()
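A subtlety both tests depend on: the parent must drop its handle to the child's end of the pipe (`del remote`), otherwise EOF is never reported even after the child dies, because the parent itself still holds that end open. A standalone sketch of the mechanism (names here are illustrative, not from the patch):

```python
from multiprocessing import Pipe, Process
from multiprocessing.connection import Connection


def child(conn: Connection) -> None:
    conn.send(1)


if __name__ == "__main__":
    local, remote = Pipe()
    p = Process(target=child, args=(remote,))
    p.start()
    remote.close()  # same effect as `del remote`: drop the parent's copy

    assert local.recv() == 1
    p.join()
    try:
        local.recv()  # raises only because no open handle to `remote` remains
    except EOFError:
        pass
```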
diff --git a/torchft/process_group.py b/torchft/process_group.py
index 3ce4dcb..0b7507d 100644
--- a/torchft/process_group.py
+++ b/torchft/process_group.py
@@ -21,6 +21,7 @@
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
 from datetime import timedelta
+from multiprocessing.connection import Connection
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -66,7 +67,7 @@
 from torch.futures import Future
 from torch.utils._pytree import tree_any
 
-from torchft.multiprocessing import _MonitoredQueue
+from torchft.multiprocessing import _MonitoredPipe
 
 if TYPE_CHECKING:
     from torchft.manager import Manager
@@ -329,6 +330,12 @@ def unregister(self) -> None:
         """
         """
         dist.destroy_process_group(self)
 
+    def shutdown(self) -> None:
+        """
+        Shuts down the process group.
+        """
+        pass
+
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}()"
@@ -357,6 +364,10 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
 
         self._pg = self._create_pg(store, rank, world_size)
 
+    def shutdown(self) -> None:
+        # TODO: abort PG if possible
+        self._pg = None
+
     def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
         raise NotImplementedError("not implemented")
@@ -916,13 +927,14 @@ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
         self._world_size = -1
 
         self._p: Optional[mp.Process] = None
-        self._tx: Optional[_MonitoredQueue] = None
-        self._rx: Optional[_MonitoredQueue] = None
-        self._future_queue: Optional[_MonitoredQueue] = None
+        self._pipe: Optional[_MonitoredPipe] = None
+        self._future_pipe: Optional[_MonitoredPipe] = None
         self._future_thread: Optional[threading.Thread] = None
         self._futures: Dict[int, Future[object]] = {}
         self._futures_lock = threading.Lock()
 
+        self._next_op_id = 0
+
         if isinstance(timeout, timedelta):
             timeout = timeout.total_seconds()
@@ -938,24 +950,21 @@ def shutdown(self) -> None:
 
         ProcessGroup can be reconfigured after shutdown.
         """
-        if self._tx is not None:
-            self._tx.close()
-        if self._rx is not None:
-            self._rx.close()
+        if self._pipe is not None:
+            self._pipe.close()
 
-        future_queue = self._future_queue
-        if future_queue is not None:
+        future_pipe = self._future_pipe
+        if future_pipe is not None:
             # wait for the future thread to exit and then close the queue
-            future_queue.put(_QUEUE_CLOSE, timeout=timedelta(seconds=10.0))
+            future_pipe.close()
 
             future_thread = self._future_thread
             assert future_thread is not None
+            future_thread.join(timeout=10.0)
             if future_thread.is_alive():
                 raise RuntimeError("future thread did not exit")
 
-            future_queue.close()
-
         # Kill after closing queues to avoid log spam.
         if self._p is not None:
             self._p.kill()
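`shutdown()` now unblocks the future-handler thread by closing its pipe instead of enqueueing the old `_QUEUE_CLOSE` sentinel: the handler notices either via an immediate `OSError` on its next call or at the end of its bounded poll, which is why `recv()` keeps a timeout loop and why the `join()` is bounded. A rough standalone sketch of that interaction (timeouts and names are assumptions, not the patch's code):

```python
import threading
from multiprocessing.connection import Pipe

local, remote = Pipe()


def handler() -> None:
    # mirrors _future_handler: poll with a bounded timeout so a close()
    # from the main thread is noticed within one poll interval
    while True:
        try:
            if local.poll(1.0):
                local.recv()
        except (OSError, EOFError):
            return  # pipe closed, exit the thread


t = threading.Thread(target=handler, daemon=True)
t.start()
local.close()  # replaces the old _QUEUE_CLOSE sentinel message
t.join(timeout=10.0)
assert not t.is_alive()
```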
@@ -966,9 +975,11 @@ def configure(self, store_addr: str, rank: int, world_size: int) -> None:
         self.shutdown()
 
         ctx = mp.get_context("spawn")
-        tx = ctx.Queue()
-        rx = ctx.Queue()
-        future_queue = ctx.Queue()
+        req_local, req_remote = ctx.Pipe()
+        future_local, future_remote = ctx.Pipe()
+
+        self._pipe = req_local = _MonitoredPipe(req_local)
+        self._future_pipe = future_local = _MonitoredPipe(future_local)
 
         self._p = p = ctx.Process(
             target=self._worker,
@@ -976,32 +987,27 @@
                 store_addr,
                 rank,
                 world_size,
-                tx,
-                rx,
-                future_queue,
+                req_remote,
+                future_remote,
             ),
             daemon=True,
         )
         p.start()
 
-        self._tx = tx = _MonitoredQueue(p, tx)
-        self._rx = rx = _MonitoredQueue(p, rx)
-        self._future_queue = future_queue = _MonitoredQueue(p, future_queue)
-
         # futures need thread to fire callbacks
         # this lock needs to be held when manipulating _futures
         self._futures_lock = threading.Lock()
         self._futures = {}
         self._future_thread = threading.Thread(
             target=self._future_handler,
-            args=(future_queue,),
+            args=(future_local,),
             daemon=True,
         )
         self._future_thread.start()
 
         # fetch the status of the PG init
         # if an exception was returned get will throw
-        assert rx.get(self._timeout) is None
+        assert req_local.recv(self._timeout) is None
 
     @classmethod
     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
@@ -1016,9 +1022,8 @@ def _worker(
         store_addr: str,
         rank: int,
         world_size: int,
-        rx: mp.Queue,
-        tx: mp.Queue,
-        future_queue: mp.Queue,
+        req_pipe: "Connection[object, object]",
+        future_pipe: "Connection[object, object]",
     ) -> None:
         try:
             store = create_store_client(store_addr)
@@ -1027,19 +1032,32 @@
             pg = cls._create_pg(store, rank, world_size)
         except Exception as e:
             logger.exception(f"got exception in worker: {e}")
-            tx.put(e)
+            req_pipe.send(e)
             return
-        tx.put(None)
+        req_pipe.send(None)
 
         streams: Dict[str, torch.cuda.Stream] = {}
         work: Dict[int, _OpMetadata] = {}
-        next_op_id: int = 0
 
         while True:
-            op = rx.get()
+            op = cast(list[object], req_pipe.recv())
             cmd = op[0]
             if cmd == "func":
-                func_name, args, kwargs, stream_device, stream_id, event = op[1:]
+                op_id: int
+                op_id, func_name, args, kwargs, stream_device, stream_id, event = (
+                    cast(
+                        Tuple[
+                            int,
+                            str,
+                            list[object],
+                            dict[str, object],
+                            int,
+                            int,
+                            Optional[torch.cuda.Event],
+                        ],
+                        op[1:],
+                    )
+                )
 
                 # To avoid potential deadlocks we need to preserve the
                 # stream/synchronization behavior of the parent process.
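With the parent assigning `op_id` up front (`self._next_op_id`) and embedding it in the `"func"` message, the worker no longer has to echo an id back for every issued op, removing one round-trip per collective. A schematic of that framing, reduced to its control flow (a hypothetical stand-in for the real dispatch, not the patch's worker):

```python
from multiprocessing.connection import Connection
from typing import Dict


def worker_loop(req_pipe: Connection) -> None:
    work: Dict[int, object] = {}
    while True:
        op = req_pipe.recv()
        cmd = op[0]
        if cmd == "func":
            # the parent already chose op_id, so nothing is sent back here
            op_id, func_name, args = op[1], op[2], op[3]
            work[op_id] = (func_name, args)  # stand-in for fn(*args, **kwargs)
        elif cmd == "wait":
            op_id = op[1]
            req_pipe.send((op_id, None))  # replies still carry the id back
        elif cmd == "del":
            del work[op[1]]
        else:
            raise ValueError(f"unknown cmd: {cmd}")
```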
@@ -1068,15 +1086,12 @@
                 args = _PickleSafeOptions.unsafe_args(args)
                 fn = getattr(pg, func_name)
 
-                work[next_op_id] = _OpMetadata(
+                work[op_id] = _OpMetadata(
                     work=fn(*args, **kwargs),
                     stream=stream,
                 )
-                tx.put(next_op_id)
-                next_op_id += 1
             elif cmd == "wait":
-                op_id: int = op[1]
-                timeout: Optional[timedelta] = op[2]
+                op_id, timeout = cast(tuple[int, timedelta], op[1:])
 
                 metadata = work[op_id]
@@ -1098,42 +1113,39 @@
                     else None
                 )
 
-                tx.put((op_id, event))
+                req_pipe.send((op_id, event))
             elif cmd == "del":
-                op_id: int = op[1]
+                op_id: int = cast(int, op[1])
                 del work[op_id]
             elif cmd == "future":
-                op_id: int = op[1]
+                op_id: int = cast(int, op[1])
 
                 def callback(fut: Future[object]) -> None:
                     try:
                         fut.wait()
-                        future_queue.put((op_id, _FUTURE_RESULT, None))
+                        future_pipe.send((op_id, _FUTURE_RESULT, None))
                     except Exception as e:
-                        future_queue.put((op_id, _FUTURE_EXCEPTION, e))
+                        future_pipe.send((op_id, _FUTURE_EXCEPTION, e))
 
                 work[op_id].work.get_future().add_done_callback(callback)
-                tx.put(op_id)
             elif cmd == "num_active_work":
-                tx.put(len(work))
+                req_pipe.send(len(work))
             else:
                 raise ValueError(f"unknown cmd: {cmd}")
 
         except Exception as e:
             logger.exception(f"worker errored: {e}")
-            tx.put(e)
+            req_pipe.send(e)
             raise
 
-    def _future_handler(self, future_queue: _MonitoredQueue) -> None:
+    def _future_handler(self, future_pipe: _MonitoredPipe) -> None:
         try:
             while True:
                 try:
-                    # timeout doesn't really matter here
-                    cmd = future_queue.get(timeout=timedelta(seconds=10.0))
+                    cmd = future_pipe.recv(timedelta(seconds=10))
                 except TimeoutError:
                     continue
-                if cmd == _QUEUE_CLOSE:
-                    break
+
+                op_id, mode, data = cast(Tuple[int, str, object], cmd)
 
                 with self._futures_lock:
                     fut = self._futures[op_id]
@@ -1151,22 +1163,20 @@ def _get_future(self, op_id: int) -> Future[object]:
         with self._futures_lock:
             fut = Future()  # pyre-fixme[29]: is not a function
             self._futures[op_id] = fut
-        assert self._tx is not None
-        self._tx.put(("future", op_id), timeout=self._timeout)
+        assert self._pipe is not None
+        self._pipe.send(("future", op_id))
 
-        assert self._rx is not None
-        assert self._rx.get(self._timeout) == op_id
 
         # TODO: return correct tensor instead of None
         return fut
 
     def _wait(self, op_id: int, timeout: Optional[timedelta] = None) -> bool:
-        assert self._tx is not None
-        self._tx.put(("wait", op_id, timeout), timeout=self._timeout)
+        assert self._pipe is not None
+        self._pipe.send(("wait", op_id, timeout))
 
-        assert self._rx is not None
+        assert self._pipe is not None
         op_id, event = cast(
             Tuple[int, Optional[torch.cuda.Event]],
-            self._rx.get(timeout or self._timeout),
+            self._pipe.recv(timeout or self._timeout),
         )
         assert op_id == op_id
         if event is not None:
@@ -1175,14 +1185,12 @@
         return True
 
     def _del(self, op_id: int) -> None:
-        assert self._tx is not None
-        self._tx.put(("del", op_id), timeout=self._timeout)
+        assert self._pipe is not None
+        self._pipe.send(("del", op_id))
 
     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
-        rx = self._rx
-        tx = self._tx
-        assert rx is not None
-        assert tx is not None
+        pipe = self._pipe
+        assert pipe is not None
 
         is_cuda = _is_any_cuda(args)
@@ -1196,9 +1204,13 @@
             else None
         )
 
-        tx.put(
+        op_id = self._next_op_id
+        self._next_op_id += 1
+
+        pipe.send(
             (
                 "func",
+                op_id,
                 func,
                 _PickleSafeOptions.safe_args(args),
                 kwargs,
@@ -1206,12 +1218,8 @@ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
                 stream_id,
                 event,
             ),
-            timeout=self._timeout,
         )
 
-        op_id = rx.get(self._timeout)
-        assert isinstance(op_id, int), f"invalid return {op_id}"
-
         return _BabyWork(pg=self, op_id=op_id)
 
     def allgather(
@@ -1329,11 +1337,11 @@ def size(self) -> int:
         return self._world_size
 
     def num_active_work(self) -> int:
-        assert self._tx is not None
-        self._tx.put(("num_active_work",), timeout=self._timeout)
+        assert self._pipe is not None
+        self._pipe.send(("num_active_work",))
 
-        assert self._rx is not None
-        return cast(int, self._rx.get(self._timeout))
+        assert self._pipe is not None
+        return cast(int, self._pipe.recv(self._timeout))
 
 
 @dataclass
diff --git a/torchft/process_group_test.py b/torchft/process_group_test.py
index 236f773..29fd3ad 100644
--- a/torchft/process_group_test.py
+++ b/torchft/process_group_test.py
@@ -574,30 +574,30 @@ def test_reconfigure_baby_process_group(self) -> None:
         a = ProcessGroupBabyGloo()
         a.configure(store_addr, 0, 1)
         future_thread_1 = a._future_thread
-        future_queue_1 = a._future_queue
+        future_pipe_1 = a._future_pipe
         p_1 = a._p
 
         store_addr = f"localhost:{store.port}/prefix2"
         a.configure(store_addr, 0, 1)
         future_thread_2 = a._future_thread
-        future_queue_2 = a._future_queue
+        future_pipe_2 = a._future_pipe
         p_2 = a._p
 
         self.assertNotEqual(future_thread_1, future_thread_2)
-        self.assertNotEqual(future_queue_1, future_queue_2)
+        self.assertNotEqual(future_pipe_1, future_pipe_2)
         self.assertNotEqual(p_1, p_2)
 
         assert future_thread_1 is not None
         self.assertFalse(future_thread_1.is_alive())
-        assert future_queue_1 is not None
-        self.assertTrue(future_queue_1.closed())
+        assert future_pipe_1 is not None
+        self.assertTrue(future_pipe_1.closed())
         assert p_1 is not None
         self.assertFalse(p_1.is_alive())
 
         assert future_thread_2 is not None
         self.assertTrue(future_thread_2.is_alive())
-        assert future_queue_2 is not None
-        self.assertFalse(future_queue_2.closed())
+        assert future_pipe_2 is not None
+        self.assertFalse(future_pipe_2.closed())
         assert p_2 is not None
         self.assertTrue(p_2.is_alive())
@@ -609,14 +609,22 @@ def test_baby_gloo_apis(self) -> None:
         store_addr = f"localhost:{store.port}/prefix"
 
         a = ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
-        a.configure(store_addr, 0, 1)
+        try:
+            a.configure(store_addr, 0, 1)
 
-        _test_pg(a)
+            _test_pg(a)
 
-        # force collection to ensure no BabyWork objects remain
-        gc.collect()
+            # force collection to ensure no BabyWork objects remain
+            gc.collect()
 
-        self.assertEqual(a.num_active_work(), 0)
+            self.assertEqual(a.num_active_work(), 0)
+
+        finally:
+            a.shutdown()
+
+        t = torch.zeros(10)
+        with self.assertRaisesRegex(OSError, "handle is closed"):
+            a.allreduce([t], AllreduceOptions()).wait()
 
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @skipUnless(torch.cuda.is_available(), "needs CUDA")
@@ -648,6 +656,10 @@ def test_baby_nccl_apis(self) -> None:
         torch.cuda.synchronize()
         torch.cuda.empty_cache()
 
+        t = torch.zeros(10)
+        with self.assertRaisesRegex(OSError, "handle is closed"):
+            a.allreduce([t], AllreduceOptions()).wait()
+
     def test_dummy(self) -> None:
         pg = ProcessGroupDummy(0, 1)
         m = nn.Linear(3, 4)
@@ -884,6 +896,7 @@ def worker(pg: ProcessGroup, rank: int, dev: str) -> str:
             t1 = torch.tensor([rank + 1], device=dev, dtype=torch.float32)
             # Simulate failure on the fault rank, but other ranks should still succeed.
             if rank == fault_rank:
+                pg.shutdown()
                 return f"Rank{rank} crashed"
             try:
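Taken together, the new lifecycle looks like the sketch below, mirroring what `test_baby_gloo_apis` now asserts: after `shutdown()` the wrapped pipe is closed, so further collectives fail fast with `OSError("handle is closed")` instead of hanging. The store setup follows the tests and is illustrative, not prescriptive:

```python
from datetime import timedelta

import torch
from torch.distributed import AllreduceOptions, TCPStore

from torchft.process_group import ProcessGroupBabyGloo

store = TCPStore("localhost", 0, is_master=True, wait_for_workers=False)

pg = ProcessGroupBabyGloo(timeout=timedelta(seconds=10))
pg.configure(f"localhost:{store.port}/prefix", 0, 1)

t = torch.ones(10)
pg.allreduce([t], AllreduceOptions()).wait()

pg.shutdown()
try:
    # the request pipe is closed, so this raises instead of blocking
    pg.allreduce([t], AllreduceOptions()).wait()
except OSError as e:
    assert "handle is closed" in str(e)
```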