Refactor run methods more into abstract method (#4353)
* First refactor step to make _run() implementations more similar

Signed-off-by: Merel Theisen <[email protected]>

* Move common logic to abstract _run method, use executor for sequential runner as well

Signed-off-by: Merel Theisen <[email protected]>

* Refactor max workers logic into shared helper method

Signed-off-by: Merel Theisen <[email protected]>

* Add resume scenario logic

Signed-off-by: Merel Theisen <[email protected]>

* Small cleanup

Signed-off-by: Merel Theisen <[email protected]>

* Clean up

Signed-off-by: Merel Theisen <[email protected]>

* Fix mypy checks

Signed-off-by: Merel Theisen <[email protected]>

* Fix sequential runner test

Signed-off-by: Merel Theisen <[email protected]>

* Fix thread runner

Signed-off-by: Merel Theisen <[email protected]>

* Ignore coverage for abstract method

Signed-off-by: Merel Theisen <[email protected]>

* Try to fix thread runner test on Python 3.13

Signed-off-by: Merel Theisen <[email protected]>

* Fix thread runner test

Signed-off-by: Merel Theisen <[email protected]>

* Fix sequential runner test on windows

Signed-off-by: Merel Theisen <[email protected]>

* More flexible options for resume suggestion in thread runner tests

Signed-off-by: Merel Theisen <[email protected]>

* Clean up + make resume tests the same across runners

Signed-off-by: Merel Theisen <[email protected]>

* Update tests/runner/test_sequential_runner.py

Signed-off-by: Merel Theisen <[email protected]>

* Clean up

Signed-off-by: Merel Theisen <[email protected]>

* Address review comments

Signed-off-by: Merel Theisen <[email protected]>

* Apply suggestions from code review

Co-authored-by: Ivan Danov <[email protected]>
Signed-off-by: Merel Theisen <[email protected]>

* Fix lint

Signed-off-by: Merel Theisen <[email protected]>

---------

Signed-off-by: Merel Theisen <[email protected]>
Signed-off-by: Merel Theisen <[email protected]>
Co-authored-by: Ivan Danov <[email protected]>
merelcht and idanov authored Dec 10, 2024
1 parent 50e0cb5 commit f6ecca5
Showing 7 changed files with 247 additions and 158 deletions.
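
In short: the executor-based scheduling loop that previously lived only in `ParallelRunner._run` moves into `AbstractRunner._run`, and each runner now only overrides `_get_executor`. The sketch below is a simplified illustration of that template-method shape, not Kedro code; `AbstractDemoRunner`, `DemoThreadRunner`, and the toy tasks are invented for illustration.

```python
# Simplified sketch (not Kedro code) of the template-method pattern this
# refactor applies: the shared run loop lives once in the abstract base,
# and each concrete runner only supplies its executor.
from abc import ABC, abstractmethod
from concurrent.futures import Executor, ThreadPoolExecutor


class AbstractDemoRunner(ABC):
    @abstractmethod
    def _get_executor(self, max_workers: int) -> Executor:
        """Concrete runners decide *how* to execute; the base decides *what*."""

    def _run(self, tasks: list) -> list:
        # Shared scheduling logic: submit everything, collect results in order.
        with self._get_executor(max_workers=2) as pool:
            futures = [pool.submit(task) for task in tasks]
            return [future.result() for future in futures]


class DemoThreadRunner(AbstractDemoRunner):
    def _get_executor(self, max_workers: int) -> Executor:
        return ThreadPoolExecutor(max_workers=max_workers)


print(DemoThreadRunner()._run([lambda: 1, lambda: 2]))  # [1, 2]
```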
80 changes: 12 additions & 68 deletions kedro/runner/parallel_runner.py
@@ -4,11 +4,7 @@

from __future__ import annotations

import os
import sys
from collections import Counter
from concurrent.futures import FIRST_COMPLETED, ProcessPoolExecutor, wait
from itertools import chain
from concurrent.futures import Executor, ProcessPoolExecutor
from multiprocessing.managers import BaseProxy, SyncManager
from multiprocessing.reduction import ForkingPickler
from pickle import PicklingError
@@ -21,7 +17,6 @@
SharedMemoryDataset,
)
from kedro.runner.runner import AbstractRunner
from kedro.runner.task import Task

if TYPE_CHECKING:
from collections.abc import Iterable
@@ -31,9 +26,6 @@
from kedro.pipeline import Pipeline
from kedro.pipeline.node import Node

# see https://github.com/python/cpython/blob/master/Lib/concurrent/futures/process.py#L114
_MAX_WINDOWS_WORKERS = 61


class ParallelRunnerManager(SyncManager):
"""``ParallelRunnerManager`` is used to create shared ``MemoryDataset``
@@ -83,16 +75,7 @@ def __init__(
self._manager = ParallelRunnerManager()
self._manager.start()

# This code comes from the concurrent.futures library
# https://github.com/python/cpython/blob/master/Lib/concurrent/futures/process.py#L588
if max_workers is None:
# NOTE: `os.cpu_count` might return None in some weird cases.
# https://github.com/python/cpython/blob/3.7/Modules/posixmodule.c#L11431
max_workers = os.cpu_count() or 1
if sys.platform == "win32":
max_workers = min(_MAX_WINDOWS_WORKERS, max_workers)

self._max_workers = max_workers
self._max_workers = self._validate_max_workers(max_workers)

def __del__(self) -> None:
self._manager.shutdown()
@@ -189,14 +172,17 @@ def _get_required_workers_count(self, pipeline: Pipeline) -> int:

return min(required_processes, self._max_workers)

def _get_executor(self, max_workers: int) -> Executor:
return ProcessPoolExecutor(max_workers=max_workers)

def _run(
self,
pipeline: Pipeline,
catalog: CatalogProtocol,
hook_manager: PluginManager,
hook_manager: PluginManager | None = None,
session_id: str | None = None,
) -> None:
"""The abstract interface for running pipelines.
"""The method implementing parallel pipeline running.
Args:
pipeline: The ``Pipeline`` to run.
@@ -218,50 +204,8 @@ def _run(
"for potential performance gains. https://docs.kedro.org/en/stable/nodes_and_pipelines/run_a_pipeline.html#load-and-save-asynchronously"
)

nodes = pipeline.nodes
self._validate_catalog(catalog, pipeline)
self._validate_nodes(nodes)
self._set_manager_datasets(catalog, pipeline)
load_counts = Counter(chain.from_iterable(n.inputs for n in nodes))
node_dependencies = pipeline.node_dependencies
todo_nodes = set(node_dependencies.keys())
done_nodes: set[Node] = set()
futures = set()
done = None
max_workers = self._get_required_workers_count(pipeline)

with ProcessPoolExecutor(max_workers=max_workers) as pool:
while True:
ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes}
todo_nodes -= ready
for node in ready:
task = Task(
node=node,
catalog=catalog,
is_async=self._is_async,
session_id=session_id,
parallel=True,
)
futures.add(pool.submit(task))
if not futures:
if todo_nodes:
debug_data = {
"todo_nodes": todo_nodes,
"done_nodes": done_nodes,
"ready_nodes": ready,
"done_futures": done,
}
debug_data_str = "\n".join(
f"{k} = {v}" for k, v in debug_data.items()
)
raise RuntimeError(
f"Unable to schedule new tasks although some nodes "
f"have not been run:\n{debug_data_str}"
)
break # pragma: no cover
done, futures = wait(futures, return_when=FIRST_COMPLETED)
for future in done:
node = future.result()
done_nodes.add(node)

self._release_datasets(node, catalog, load_counts, pipeline)
super()._run(
pipeline=pipeline,
catalog=catalog,
session_id=session_id,
)
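
Note how `ParallelRunner._run` now reduces to a few parallel-specific checks plus a `super()._run(...)` call, while the shared loop (in runner.py below) flags a task as `parallel` only when the pool is a `ProcessPoolExecutor`: work sent to a process pool must be picklable, which is also why parallel_runner.py keeps its `ForkingPickler` and `PicklingError` imports. A minimal standalone sketch of that constraint, using a toy `square` function rather than Kedro's `Task`:

```python
# Minimal sketch of why process pools need special handling: submitted
# callables and their arguments are pickled over to worker processes.
# `square` is a toy function; Kedro's Task performs its own pickling checks.
from concurrent.futures import ProcessPoolExecutor


def square(x: int) -> int:
    return x * x


if __name__ == "__main__":  # guard needed on platforms that spawn processes
    with ProcessPoolExecutor(max_workers=2) as pool:
        print(list(pool.map(square, range(5))))  # [0, 1, 4, 9, 16]
```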
130 changes: 125 additions & 5 deletions kedro/runner/runner.py
@@ -6,17 +6,26 @@

import inspect
import logging
import os
import sys
import warnings
from abc import ABC, abstractmethod
from collections import deque
from collections import Counter, deque
from concurrent.futures import FIRST_COMPLETED, Executor, ProcessPoolExecutor, wait
from itertools import chain
from typing import TYPE_CHECKING, Any

from pluggy import PluginManager

from kedro import KedroDeprecationWarning
from kedro.framework.hooks.manager import _NullPluginManager
from kedro.io import CatalogProtocol, MemoryDataset, SharedMemoryDataset
from kedro.pipeline import Pipeline
from kedro.runner.task import Task

# see https://github.com/python/cpython/blob/master/Lib/concurrent/futures/process.py#L114
_MAX_WINDOWS_WORKERS = 61

if TYPE_CHECKING:
from collections.abc import Collection, Iterable

@@ -166,25 +175,95 @@ def run_only_missing(

return self.run(to_rerun, catalog, hook_manager)

@abstractmethod # pragma: no cover
def _get_executor(self, max_workers: int) -> Executor:
"""Abstract method to provide the correct executor (e.g., ThreadPoolExecutor or ProcessPoolExecutor)."""
pass

@abstractmethod # pragma: no cover
def _run(
self,
pipeline: Pipeline,
catalog: CatalogProtocol,
hook_manager: PluginManager,
hook_manager: PluginManager | None = None,
session_id: str | None = None,
) -> None:
"""The abstract interface for running pipelines, assuming that the
inputs have already been checked and normalized by run().
inputs have already been checked and normalized by run().
This contains the common pipeline execution logic using an executor.
Args:
pipeline: The ``Pipeline`` to run.
catalog: An implemented instance of ``CatalogProtocol`` from which to fetch data.
hook_manager: The ``PluginManager`` to activate hooks.
session_id: The id of the session.
"""
pass

nodes = pipeline.nodes

self._validate_catalog(catalog, pipeline)
self._validate_nodes(nodes)
self._set_manager_datasets(catalog, pipeline)

load_counts = Counter(chain.from_iterable(n.inputs for n in pipeline.nodes))
node_dependencies = pipeline.node_dependencies
todo_nodes = set(node_dependencies.keys())
done_nodes: set[Node] = set()
futures = set()
done = None
max_workers = self._get_required_workers_count(pipeline)

with self._get_executor(max_workers) as pool:
while True:
ready = {n for n in todo_nodes if node_dependencies[n] <= done_nodes}
todo_nodes -= ready
for node in ready:
task = Task(
node=node,
catalog=catalog,
hook_manager=hook_manager,
is_async=self._is_async,
session_id=session_id,
)
if isinstance(pool, ProcessPoolExecutor):
task.parallel = True
futures.add(pool.submit(task))
if not futures:
if todo_nodes:
self._raise_runtime_error(todo_nodes, done_nodes, ready, done)
break
done, futures = wait(futures, return_when=FIRST_COMPLETED)
for future in done:
try:
node = future.result()
except Exception:
self._suggest_resume_scenario(pipeline, done_nodes, catalog)
raise
done_nodes.add(node)
self._logger.info("Completed node: %s", node.name)
self._logger.info(
"Completed %d out of %d tasks", len(done_nodes), len(nodes)
)
self._release_datasets(node, catalog, load_counts, pipeline)

@staticmethod
def _raise_runtime_error(
todo_nodes: set[Node],
done_nodes: set[Node],
ready: set[Node],
done: set[Node] | None,
) -> None:
debug_data = {
"todo_nodes": todo_nodes,
"done_nodes": done_nodes,
"ready_nodes": ready,
"done_futures": done,
}
debug_data_str = "\n".join(f"{k} = {v}" for k, v in debug_data.items())
raise RuntimeError(
f"Unable to schedule new tasks although some nodes "
f"have not been run:\n{debug_data_str}"
)

def _suggest_resume_scenario(
self,
@@ -240,6 +319,47 @@ def _release_datasets(
if load_counts[dataset] < 1 and dataset not in pipeline.outputs():
catalog.release(dataset)

def _validate_catalog(self, catalog: CatalogProtocol, pipeline: Pipeline) -> None:
# Add catalog validation logic here if needed
pass

def _validate_nodes(self, node: Iterable[Node]) -> None:
# Add node validation logic here if needed
pass

def _set_manager_datasets(
self, catalog: CatalogProtocol, pipeline: Pipeline
) -> None:
# Set up any necessary manager datasets here
pass

def _get_required_workers_count(self, pipeline: Pipeline) -> int:
return 1

@classmethod
def _validate_max_workers(cls, max_workers: int | None) -> int:
"""
Validates and returns the number of workers. Sets to os.cpu_count() or 1 if max_workers is None,
and limits max_workers to 61 on Windows.
Args:
max_workers: Desired number of workers. If None, defaults to os.cpu_count() or 1.
Returns:
A valid number of workers to use.
Raises:
ValueError: If max_workers is set and is not positive.
"""
if max_workers is None:
max_workers = os.cpu_count() or 1
if sys.platform == "win32":
max_workers = min(_MAX_WINDOWS_WORKERS, max_workers)
elif max_workers <= 0:
raise ValueError("max_workers should be positive")

return max_workers


def _find_nodes_to_resume_from(
pipeline: Pipeline, unfinished_nodes: Collection[Node], catalog: CatalogProtocol
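
The new `_validate_max_workers` classmethod centralises the worker-count rules that previously sat in `ParallelRunner.__init__`. As written, the Windows cap of 61 applies only when `max_workers` is `None`; an explicit positive value passes through unchanged. A standalone sketch of the same rules follows (the free function here is hypothetical and simply mirrors the diff):

```python
# Standalone sketch of the validation rules encoded by the new classmethod;
# the free function is for illustration only.
from __future__ import annotations

import os
import sys

_MAX_WINDOWS_WORKERS = 61  # concurrent.futures cap on Windows


def validate_max_workers(max_workers: int | None) -> int:
    if max_workers is None:
        # os.cpu_count() can return None, hence the `or 1` fallback.
        max_workers = os.cpu_count() or 1
        if sys.platform == "win32":
            max_workers = min(_MAX_WINDOWS_WORKERS, max_workers)
    elif max_workers <= 0:
        raise ValueError("max_workers should be positive")
    return max_workers


print(validate_max_workers(None))  # e.g. 8 on an eight-core machine
print(validate_max_workers(4))     # 4
```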
44 changes: 16 additions & 28 deletions kedro/runner/sequential_runner.py
@@ -5,12 +5,13 @@

from __future__ import annotations

from collections import Counter
from itertools import chain
from concurrent.futures import (
Executor,
ThreadPoolExecutor,
)
from typing import TYPE_CHECKING, Any

from kedro.runner.runner import AbstractRunner
from kedro.runner.task import Task

if TYPE_CHECKING:
from pluggy import PluginManager
@@ -46,11 +47,16 @@ def __init__(
is_async=is_async, extra_dataset_patterns=self._extra_dataset_patterns
)

def _get_executor(self, max_workers: int) -> Executor:
return ThreadPoolExecutor(
max_workers=1
) # Single-threaded for sequential execution

def _run(
self,
pipeline: Pipeline,
catalog: CatalogProtocol,
hook_manager: PluginManager,
hook_manager: PluginManager | None = None,
session_id: str | None = None,
) -> None:
"""The method implementing sequential pipeline running.
@@ -69,27 +75,9 @@ def _run(
"Using synchronous mode for loading and saving data. Use the --async flag "
"for potential performance gains. https://docs.kedro.org/en/stable/nodes_and_pipelines/run_a_pipeline.html#load-and-save-asynchronously"
)
nodes = pipeline.nodes
done_nodes = set()

load_counts = Counter(chain.from_iterable(n.inputs for n in nodes))

for exec_index, node in enumerate(nodes):
try:
Task(
node=node,
catalog=catalog,
hook_manager=hook_manager,
is_async=self._is_async,
session_id=session_id,
).execute()
done_nodes.add(node)
except Exception:
self._suggest_resume_scenario(pipeline, done_nodes, catalog)
raise

self._release_datasets(node, catalog, load_counts, pipeline)

self._logger.info(
"Completed %d out of %d tasks", len(done_nodes), len(nodes)
)
super()._run(
pipeline=pipeline,
catalog=catalog,
hook_manager=hook_manager,
session_id=session_id,
)
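
Design note on the sequential case: a `ThreadPoolExecutor` with a single worker drains its queue one task at a time, in submission order, so handing it to the shared executor-based loop preserves sequential semantics without a separate code path. A toy demonstration of that property (not Kedro code):

```python
# A one-worker ThreadPoolExecutor runs submissions strictly one at a time,
# in submission order, which is what lets the shared loop also cover
# sequential execution. Toy example only.
from concurrent.futures import ThreadPoolExecutor

order: list[int] = []
with ThreadPoolExecutor(max_workers=1) as pool:
    futures = [pool.submit(order.append, i) for i in range(5)]
for future in futures:
    future.result()  # re-raise any task exception
print(order)  # [0, 1, 2, 3, 4]
```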
(Diffs for the remaining four changed files are not shown.)
