Skip to content

ProcessGroupBaby: use pipe for improved performance #121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit on
Mar 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 16 additions & 79 deletions torchft/multiprocessing.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,32 @@
import queue
import time
from datetime import timedelta
from multiprocessing.connection import Connection
from typing import Union

import torch.multiprocessing as mp


class _MonitoredQueue:
def __init__(
self,
p: mp.Process,
q: mp.Queue,
poll_interval: timedelta = timedelta(seconds=1),
) -> None:
"""
Args:
p: process to monitor
q: queue to monitor
poll_interval: interval to poll the Process health when calling get/put
"""
self._p = p
self._q = q
self._poll_interval_s: float = poll_interval.total_seconds()
class _MonitoredPipe:
def __init__(self, pipe: "Connection[object, object]") -> None:
self._pipe = pipe

def get(self, timeout: Union[float, timedelta]) -> object:
"""
Get an item from the queue. If the process is not alive, raise RuntimeError.
If the queue is empty, wait for up to timeout seconds for an item to be
available. If no item is available after timeout seconds, raise TimeoutError.

Args:
timeout: timeout in seconds
"""
def send(self, obj: object) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

send doesnt need a timeout?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's not an easy mechanism to handle timeouts on write and from my testing the messages are small enough that the messages effectively instantly send as they just copy into the kernel default buffer size

If the remote process exits this will fail with a BrokenPipe error as well so that case is covered

self._pipe.send(obj)

def recv(self, timeout: Union[float, timedelta]) -> object:
if isinstance(timeout, timedelta):
timeout = timeout.total_seconds()

start = time.perf_counter()
while True:
try:
v = self._q.get(timeout=self._poll_interval_s)
break
except queue.Empty:
pass

elapsed = time.perf_counter() - start
if elapsed > timeout:
raise TimeoutError(f"queue.get() timed out after {timeout} seconds")

# polling the process can be slow so we only do it every poll_interval
if not self._p.is_alive():
raise RuntimeError(f"process is not alive {self._p.exitcode}")

if isinstance(v, Exception):
raise v
return v

def put(self, obj: object, timeout: Union[float, timedelta]) -> None:
"""
Put an item into the queue. If the process is not alive, raise RuntimeError.
If the queue is full, wait for up to timeout seconds for an item to be
available. If queue is full after timeout seconds, raise TimeoutError.

If an exception is put into the queue, it will be raised when calling get().

Args:
obj: object to put into the queue
timeout: timeout in seconds
"""
if isinstance(timeout, timedelta):
timeout = timeout.total_seconds()

start = time.perf_counter()
while True:
try:
self._q.put(obj, timeout=self._poll_interval_s)
break
except queue.Full:
pass

elapsed = time.perf_counter() - start
if elapsed > timeout:
raise TimeoutError(f"queue.put() timed out after {timeout} seconds")

# polling the process can be slow so we only do it every poll_interval
if not self._p.is_alive():
raise RuntimeError(f"process is not alive {self._p.exitcode}")
if self._pipe.poll(timeout):
out = self._pipe.recv()
if isinstance(out, Exception):
raise out
return out
else:
raise TimeoutError(f"pipe.recv() timed out after {timeout} seconds")

def close(self) -> None:
self._q.close()
self._pipe.close()

def closed(self) -> bool:
# pyre-ignore[16]: no attribute _closed
return self._q._closed
return self._pipe.closed
51 changes: 29 additions & 22 deletions torchft/multiprocessing_test.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,55 @@
from multiprocessing.connection import Connection
from unittest import TestCase

import torch.multiprocessing as mp

from torchft.multiprocessing import _MonitoredQueue
from torchft.multiprocessing import _MonitoredPipe


def pipe_get(q: "Connection[object, object]") -> None:
    """Child-process target: block until one object arrives, then exit."""
    q.recv()


def pipe_put(q: "Connection[object, object]") -> None:
    """Child-process target: wait for a go signal, send 1 back, then exit."""
    q.recv()
    q.send(1)


class MultiprocessingTest(TestCase):
    """Exercises _MonitoredPipe against a forked child process."""

    def test_monitored_queue_put(self) -> None:
        # Child receives one message then exits; further sends must fail
        # with a connection-reset rather than hanging.
        ctx = mp.get_context("fork")
        local, remote = ctx.Pipe()
        p = ctx.Process(target=pipe_get, args=(remote,), daemon=True)
        p.start()
        # Drop the parent's handle on the remote end so the child's exit
        # actually breaks the pipe.
        del remote

        mq = _MonitoredPipe(local)
        mq.send(1)
        with self.assertRaisesRegex(ConnectionResetError, "Connection reset by peer"):
            while True:
                mq.send(1)

        mq.close()
        assert mq.closed()

    def test_monitored_queue_get(self) -> None:
        # Child waits for a go signal, sends 1 back, then exits; a recv
        # after exit must raise EOFError.
        ctx = mp.get_context("fork")
        local, remote = ctx.Pipe()
        p = ctx.Process(target=pipe_put, args=(remote,), daemon=True)
        p.start()
        del remote

        mq = _MonitoredPipe(local)

        with self.assertRaisesRegex(TimeoutError, "timed out after 0.0 seconds"):
            mq.recv(timeout=0.0)

        # continue
        mq.send(1)

        self.assertEqual(mq.recv(timeout=10), 1)
        with self.assertRaises(EOFError):
            mq.recv(timeout=10)

        mq.close()
        assert mq.closed()
Loading