Skip to content

Commit

Permalink
ProcessGroupBaby: use pipe for improved performance (#121)
Browse files Browse the repository at this point in the history
  • Loading branch information
d4l3k authored Mar 6, 2025
1 parent 082753c commit 2ab329e
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 187 deletions.
95 changes: 16 additions & 79 deletions torchft/multiprocessing.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,32 @@
import queue
import time
from datetime import timedelta
from multiprocessing.connection import Connection
from typing import Union

import torch.multiprocessing as mp


class _MonitoredQueue:
    """Queue wrapper that also watches the health of an associated process.

    ``get`` and ``put`` wait on the queue in slices of ``poll_interval``.
    After every slice that ends without success they first check the
    caller's overall timeout and then check whether the monitored process
    is still alive.
    """

    def __init__(
        self,
        p: mp.Process,
        q: mp.Queue,
        poll_interval: timedelta = timedelta(seconds=1),
    ) -> None:
        """
        Args:
            p: process to monitor
            q: queue to monitor
            poll_interval: interval to poll the Process health when calling get/put
        """
        self._p = p
        self._q = q
        self._poll_interval_s: float = poll_interval.total_seconds()

    def get(self, timeout: Union[float, timedelta]) -> object:
        """
        Get an item from the queue.

        If the queue is empty, wait for up to timeout seconds for an item to
        become available. The timeout is only evaluated after a full poll
        interval has elapsed without an item, so the call can take up to one
        poll interval longer than ``timeout``.

        Args:
            timeout: timeout in seconds (or a timedelta)

        Returns:
            The item taken from the queue.

        Raises:
            TimeoutError: no item became available within the timeout.
            RuntimeError: the monitored process is not alive.
            Exception: if the item itself is an Exception instance, it is
                raised instead of being returned.
        """
        if isinstance(timeout, timedelta):
            timeout = timeout.total_seconds()

        start = time.perf_counter()
        while True:
            try:
                v = self._q.get(timeout=self._poll_interval_s)
                break
            except queue.Empty:
                pass

            elapsed = time.perf_counter() - start
            if elapsed > timeout:
                raise TimeoutError(f"queue.get() timed out after {timeout} seconds")

            # polling the process can be slow so we only do it every poll_interval
            if not self._p.is_alive():
                raise RuntimeError(f"process is not alive {self._p.exitcode}")

        if isinstance(v, Exception):
            raise v
        return v

    def put(self, obj: object, timeout: Union[float, timedelta]) -> None:
        """
        Put an item into the queue.

        If the queue is full, wait for up to timeout seconds for a free slot.
        The timeout is only evaluated after a full poll interval has elapsed
        without success. If an exception is put into the queue, it will be
        raised when calling get().

        Args:
            obj: object to put into the queue
            timeout: timeout in seconds (or a timedelta)

        Raises:
            TimeoutError: the queue stayed full for longer than the timeout.
            RuntimeError: the monitored process is not alive.
        """
        if isinstance(timeout, timedelta):
            timeout = timeout.total_seconds()

        start = time.perf_counter()
        while True:
            try:
                self._q.put(obj, timeout=self._poll_interval_s)
                break
            except queue.Full:
                pass

            elapsed = time.perf_counter() - start
            if elapsed > timeout:
                raise TimeoutError(f"queue.put() timed out after {timeout} seconds")

            # polling the process can be slow so we only do it every poll_interval
            if not self._p.is_alive():
                raise RuntimeError(f"process is not alive {self._p.exitcode}")

    def close(self) -> None:
        """Close the wrapped queue."""
        self._q.close()

    def closed(self) -> bool:
        """Return the wrapped queue's private ``_closed`` flag."""
        # pyre-ignore[16]: no attribute _closed
        return self._q._closed


class _MonitoredPipe:
    """Wrapper around one end of a pipe with a timeout-aware ``recv``.

    Exception instances received from the peer are raised instead of being
    returned.
    """

    def __init__(self, pipe: "Connection[object, object]") -> None:
        """
        Args:
            pipe: the connection end that this wrapper sends on and receives from
        """
        self._pipe = pipe

    def send(self, obj: object) -> None:
        """Send ``obj`` over the pipe; errors from the connection propagate."""
        self._pipe.send(obj)

    def recv(self, timeout: Union[float, timedelta]) -> object:
        """
        Receive one object from the pipe.

        Args:
            timeout: how long to wait for data, in seconds (or a timedelta)

        Returns:
            The received object.

        Raises:
            TimeoutError: nothing became readable within the timeout.
            Exception: if the received object is an Exception instance, it is
                raised instead of being returned.

        Errors raised by the underlying connection (the accompanying test
        expects EOFError once the peer is gone) propagate unchanged.
        """
        if isinstance(timeout, timedelta):
            timeout = timeout.total_seconds()

        if self._pipe.poll(timeout):
            out = self._pipe.recv()
            if isinstance(out, Exception):
                raise out
            return out
        else:
            raise TimeoutError(f"pipe.recv() timed out after {timeout} seconds")

    def close(self) -> None:
        """Close this end of the pipe."""
        self._pipe.close()

    def closed(self) -> bool:
        """Return whether this end of the pipe has been closed."""
        return self._pipe.closed
51 changes: 29 additions & 22 deletions torchft/multiprocessing_test.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,55 @@
from multiprocessing.connection import Connection
from unittest import TestCase

import torch.multiprocessing as mp

from torchft.multiprocessing import _MonitoredQueue
from torchft.multiprocessing import _MonitoredPipe


def queue_get(q: mp.Queue) -> None:
    """Take one item from ``q`` and discard it.

    Used as a child-process target by ``MultiprocessingTest``.
    """
    q.get()
def pipe_get(q: "Connection[object, object]") -> None:
    """Receive one object from the connection ``q`` and discard it.

    Used as a child-process target by ``MultiprocessingTest``.
    """
    q.recv()


def queue_put(q: mp.Queue) -> None:
    """Put the integer ``1`` onto ``q``.

    Used as a child-process target by ``MultiprocessingTest``.
    """
    q.put(1)
def pipe_put(q: "Connection[object, object]") -> None:
    """Wait for one object on ``q``, discard it, then send the integer ``1``.

    The initial receive means nothing is sent until the other side has sent
    something first. Used as a child-process target by
    ``MultiprocessingTest``.
    """
    q.recv()
    q.send(1)


# NOTE(review): this span is a rendered diff with indentation stripped. Under each
# method header, queue-based statements (ctx.Queue, _MonitoredQueue, mq.put / mq.get)
# and pipe-based statements (ctx.Pipe, _MonitoredPipe, mq.send / mq.recv) are
# interleaved. Presumably the queue-based lines are the ones the commit removes --
# confirm against the repository before treating this as runnable code.
class MultiprocessingTest(TestCase):
def test_monitored_queue_put(self) -> None:
# Sending side: a child process consumes a single item and then returns.
ctx = mp.get_context("fork")
q = ctx.Queue(maxsize=1)
p = ctx.Process(target=queue_get, args=(q,), daemon=True)
local, remote = ctx.Pipe()
p = ctx.Process(target=pipe_get, args=(remote,), daemon=True)
p.start()
# presumably dropped so that only the child holds the remote end -- TODO confirm
del remote

mq = _MonitoredQueue(p, q)
mq.put(1, timeout=10)
mq.put(1, timeout=10)
with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
mq.put(1, timeout=10)

with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
mq.put(1, timeout=0.0)
mq = _MonitoredPipe(local)
mq.send(1)
# keep sending until the connection reports that the peer is gone
with self.assertRaisesRegex(ConnectionResetError, "Connection reset by peer"):
while True:
mq.send(1)

mq.close()
assert mq.closed()

def test_monitored_queue_get(self) -> None:
# Receiving side: a child process produces a single item (the integer 1).
ctx = mp.get_context("fork")
q = ctx.Queue(maxsize=1)
p = ctx.Process(target=queue_put, args=(q,), daemon=True)
local, remote = ctx.Pipe()
p = ctx.Process(target=pipe_put, args=(remote,), daemon=True)
p.start()
# presumably dropped so that only the child holds the remote end -- TODO confirm
del remote

mq = _MonitoredQueue(p, q)
self.assertEqual(mq.get(timeout=10), 1)
with self.assertRaisesRegex(RuntimeError, "process is not alive 0"):
mq.get(timeout=10)
mq = _MonitoredPipe(local)

with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
mq.get(timeout=0.0)
mq.recv(timeout=0.0)

# unblock the child: pipe_put waits to receive one object before it sends 1
mq.send(1)

self.assertEqual(mq.recv(timeout=10), 1)
# a further receive is expected to fail with EOFError rather than return data
with self.assertRaises(EOFError):
mq.recv(timeout=10)

mq.close()
assert mq.closed()
Loading

0 comments on commit 2ab329e

Please sign in to comment.