Skip to content

Commit

Permalink
✨ add RayIOManager (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgafni authored Oct 10, 2024
1 parent 734d723 commit db03609
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 3 deletions.
86 changes: 85 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

- `ray_executor` - an `Executor` which submits individual Dagster steps as isolated Ray jobs (in cluster mode) to a Ray cluster.

- `RayIOManager` - an `IOManager` which allows storing and retrieving intermediate values in Ray's object store. Ideal in conjunction with `RayRunLauncher` and `ray_executor`.

- `PipesKubeRayJobClient`, a [Dagster Pipes](https://docs.dagster.io/concepts/dagster-pipes) client for launching and monitoring [KubeRay](https://github.com/ray-project/kuberay)'s `RayJob` CR in Kubernetes. Typically used with external Pythons scripts. Allows receiving rich logs, events and metadata from the job.

- `RayResource`, a resource representing a Ray cluster. Interactions with Ray are performed in **client mode** (requires stable persistent connection), so it's most suitable for relatively short-lived jobs. It has implementations for `KubeRay` and local (mostly for testing) backends. `dagster_ray.RayResource` defines the common interface shared by all backends and can be used for backend-agnostic type annotations.
Expand Down Expand Up @@ -99,7 +101,6 @@ def my_job():
return my_op()
```
# Executor
> [!WARNING]
Expand Down Expand Up @@ -144,6 +145,89 @@ def my_job():

Fields in the `dagster-ray/config` tag **override** corresponding fields in the Executor config.


## IOManager

`RayIOManager` allows storing and retrieving intermediate values in Ray's object store. It can be used in conjunction with `RayRunLauncher` and `ray_executor` to store and retrieve intermediate values in a Ray cluster.

It works by storing Dagster step keys in a global Ray actor. This actor contains a mapping between step keys and Ray `ObjectRef`s. It can be used with any pickable Python objects.




```python
from dagster import asset, Definitions
from dagster_ray import RayIOManager
@asset(io_manager_key="ray_io_manager")
def upstream() -> int:
return 42
@asset
def downstream(upstream: int):
return 0
definitions = Definitions(
assets=[upstream, downstream], resources={"ray_io_manager": RayIOManager()}
)
```

It supports partitioned assets.


```python
from dagster import (
asset,
Definitions,
StaticPartitionsDefinition,
AssetExecutionContext,
)
from dagster_ray import RayIOManager
partitions_def = StaticPartitionsDefinition(["a", "b", "c"])
@asset(io_manager_key="ray_io_manager", partitions_def=partitions_def)
def upstream(context: AssetExecutionContext):
return context.partition_key
@asset(partitions_def=partitions_def)
def downstream(context: AssetExecutionContext, upstream: str) -> None:
assert context.partition_key == upstream
```


It supports partition mappings. When loading **multiple** upstream partitions, they should be annotated with a `Dict[str, ...]`, `dict[str, ...]`, or `Mapping[str, ...]` type hint.


```python
from dagster import (
asset,
Definitions,
StaticPartitionsDefinition,
AssetExecutionContext,
)
from dagster_ray import RayIOManager
partitions_def = StaticPartitionsDefinition(["A", "B", "C"])
@asset(io_manager_key="ray_io_manager", partitions_def=partitions_def)
def upstream(context: AssetExecutionContext):
return context.partition_key.lower()
@asset
def downstream_unpartitioned(upstream: Dict[str, str]) -> None:
assert upstream == {"A": "a", "B": "b", "C": "c"}
```

# Backends

## KubeRay
Expand Down
3 changes: 2 additions & 1 deletion dagster_ray/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from dagster_ray._base.resources import BaseRayResource
from dagster_ray.executor import ray_executor
from dagster_ray.io_manager import RayIOManager
from dagster_ray.run_launcher import RayRunLauncher

RayResource = BaseRayResource


__all__ = ["RayResource", "RayRunLauncher", "ray_executor"]
__all__ = ["RayResource", "RayRunLauncher", "RayIOManager", "ray_executor"]
123 changes: 123 additions & 0 deletions dagster_ray/io_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
from typing import TYPE_CHECKING, Dict, Optional, Union

from dagster import ConfigurableIOManager, ConfigurableResource, InputContext, OutputContext

DAGSTER_RAY_OBJECT_MAP_NAME = "DagsterRayObjectMap"
DAGSTER_RAY_NAMESPACE = "dagster-ray"

# we need to create a global Ray actor which will store all the refs to all objcets

if TYPE_CHECKING:
import ray


class RayObjectMap:
# TODO: implement some eventual cleanup mechanism
# idea: save creation timestamp and periodically check for old refs
# or add some integration with the RunLauncher/Executor
def __init__(self):
self._object_map: Dict[str, "ray.ObjectRef"] = {}

def set(self, key: str, ref: "ray.ObjectRef"):
self._object_map[key] = ref

def get(self, key: str) -> Optional["ray.ObjectRef"]:
return self._object_map.get(key)

def delete(self, key: str):
if key in self._object_map:
del self._object_map[key]

def keys(self):
return self._object_map.keys()

def ping(self):
return "pong"

@staticmethod
def get_or_create():
import ray

actor = (
ray.remote(RayObjectMap)
.options( # type: ignore
name=DAGSTER_RAY_OBJECT_MAP_NAME,
namespace=DAGSTER_RAY_NAMESPACE,
get_if_exists=True,
lifetime="detached",
# max_restarts=-1,
max_concurrency=1000, # TODO: make this configurable,
runtime_env={"RAY_ENABLE_RECORD_ACTOR_TASK_LOGGING": "1"},
)
.remote()
)

# make sure the actor is created
ray.get(actor.ping.remote())

return actor


class RayIOManager(ConfigurableIOManager, ConfigurableResource):
address: Optional[str] = None

def handle_output(self, context: OutputContext, obj):
import ray

# if self.address: # TODO: should this really be done here?
# ray.init(self.address, ignore_reinit_error=True)

object_map = RayObjectMap.get_or_create()

storage_key = self._get_single_key(context)

# TODO: understand if Ray will automatically move the object from dying nodes
# what if not?

ref = ray.put(obj, _owner=object_map)

object_map.set.remote(storage_key, ref)

context.log.debug(f"Stored object with key {storage_key} as {ref}")

def load_input(self, context: InputContext):
import ray

# if self.address: # TODO: should this really be done here?
# ray.init(self.address, ignore_reinit_error=True)

object_map = RayObjectMap.get_or_create()

if context.has_asset_partitions and len(context.asset_partition_keys) > 1:
# load multiple partitions as once
# first, get the refs

storage_keys = self._get_multiple_keys(context)
refs = [object_map.get.remote(key) for key in storage_keys.values()]
values = ray.get(refs)
return {partition_key: value for partition_key, value in zip(storage_keys.keys(), values)}

else:
storage_key = self._get_single_key(context)

context.log.debug(f"Loading object with key {storage_key}")

ref = object_map.get.remote(storage_key)

assert ref is not None, f"Object with key {storage_key} not found in RayObjectMap"

return ray.get(ref)

def _get_single_key(self, context: Union[InputContext, OutputContext]) -> str:
identifier = context.get_identifier() if not context.has_asset_key else context.get_asset_identifier()
return "/".join(identifier)

def _get_multiple_keys(self, context: InputContext) -> Dict[str, str]:
if context.has_asset_key:
asset_path = list(context.asset_key.path)

return {
partition_key: "/".join(asset_path + [partition_key]) for partition_key in context.asset_partition_keys
}
else:
raise RuntimeError("This method can only be called with an InputContext that has multiple partitions")
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def dagster_instance(tmp_path_factory: TempPathFactory) -> DagsterInstance:
def local_ray_address() -> Iterator[str]:
import ray

context = ray.init()
context = ray.init(runtime_env={"env_vars": {"RAY_ENABLE_RECORD_ACTOR_TASK_LOGGING": "1"}})

yield "auto"

Expand Down
27 changes: 27 additions & 0 deletions tests/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
reconstructable,
)

from dagster_ray import RayIOManager
from dagster_ray.executor import ray_executor


Expand Down Expand Up @@ -82,6 +83,32 @@ def test_ray_executor(local_ray_address: str, dagster_instance: DagsterInstance)
assert result.success, result.get_step_failure_events()[0].event_specific_data


ray_io_manager = RayIOManager()


@job(executor_def=ray_executor, resource_defs={"io_manager": ray_io_manager})
def my_job_with_ray_io_manager():
return_two_result = return_two()
return_one_result = return_one()
sum_one_and_two(return_one_result, return_two_result)


def test_ray_executor_with_ray_io_manager(local_ray_address: str, dagster_instance: DagsterInstance):
result = execute_job(
job=reconstructable(my_job_with_ray_io_manager),
instance=dagster_instance,
run_config={
"execution": {
"config": {
"ray": {"address": local_ray_address},
}
}
},
)

assert result.success, result.get_step_failure_events()[0].event_specific_data


def test_ray_executor_local_failing(local_ray_address: str, dagster_instance: DagsterInstance):
result = execute_job(
job=reconstructable(my_failing_job),
Expand Down
58 changes: 58 additions & 0 deletions tests/test_io_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from typing import Dict

from dagster import AssetExecutionContext, StaticPartitionsDefinition, asset, materialize

from dagster_ray import RayIOManager


def test_ray_io_manager():
@asset
def upstream():
return 1

@asset
def downstream(upstream) -> None:
assert upstream == 1

materialize(
[upstream, downstream],
resources={"io_manager": RayIOManager()},
)


def test_ray_io_manager_partitioned():
partitions_def = StaticPartitionsDefinition(partition_keys=["A", "B", "C"])

@asset(partitions_def=partitions_def)
def upsteram_partitioned(context: AssetExecutionContext) -> str:
return context.partition_key.lower()

@asset(partitions_def=partitions_def)
def downstream_partitioned(context: AssetExecutionContext, upsteram_partitioned: str) -> None:
assert upsteram_partitioned == context.partition_key.lower()

for partition_key in ["A", "B", "C"]:
materialize(
[upsteram_partitioned, downstream_partitioned],
resources={"io_manager": RayIOManager()},
partition_key=partition_key,
)


def test_ray_io_manager_partition_mapping():
partitions_def = StaticPartitionsDefinition(partition_keys=["A", "B", "C"])

@asset(partitions_def=partitions_def)
def upsteram_partitioned(context: AssetExecutionContext) -> str:
return context.partition_key.lower()

@asset
def downstream_non_partitioned(upsteram_partitioned: Dict[str, str]) -> None:
assert upsteram_partitioned == {"A": "a", "B": "b", "C": "c"}

for partition_key in ["A", "B", "C"]:
materialize([upsteram_partitioned], resources={"io_manager": RayIOManager()}, partition_key=partition_key)

materialize(
[upsteram_partitioned.to_source_asset(), downstream_non_partitioned], resources={"io_manager": RayIOManager()}
)

0 comments on commit db03609

Please sign in to comment.