-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
734d723
commit db03609
Showing
6 changed files
with
296 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,9 @@ | ||
from dagster_ray._base.resources import BaseRayResource | ||
from dagster_ray.executor import ray_executor | ||
from dagster_ray.io_manager import RayIOManager | ||
from dagster_ray.run_launcher import RayRunLauncher | ||
|
||
RayResource = BaseRayResource | ||
|
||
|
||
__all__ = ["RayResource", "RayRunLauncher", "ray_executor"] | ||
__all__ = ["RayResource", "RayRunLauncher", "RayIOManager", "ray_executor"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
from typing import TYPE_CHECKING, Dict, Optional, Union | ||
|
||
from dagster import ConfigurableIOManager, ConfigurableResource, InputContext, OutputContext | ||
|
||
DAGSTER_RAY_OBJECT_MAP_NAME = "DagsterRayObjectMap" | ||
DAGSTER_RAY_NAMESPACE = "dagster-ray" | ||
|
||
# we need to create a global Ray actor which will store all the refs to all objects | ||
|
||
if TYPE_CHECKING: | ||
import ray | ||
|
||
|
||
class RayObjectMap:
    """In-memory key/value registry meant to be run as a named Ray actor.

    Entries handed to :meth:`set` (annotated as ``ray.ObjectRef``) are kept in
    a plain dict keyed by string. :meth:`get_or_create` wraps this class in a
    detached, named actor so that callers share a single map.
    """

    # TODO: implement some eventual cleanup mechanism
    # idea: save creation timestamp and periodically check for old refs
    # or add some integration with the RunLauncher/Executor
    def __init__(self):
        self._object_map: Dict[str, "ray.ObjectRef"] = {}

    def set(self, key: str, ref: "ray.ObjectRef"):
        """Store ``ref`` under ``key``, replacing any previous entry."""
        self._object_map[key] = ref

    def get(self, key: str) -> Optional["ray.ObjectRef"]:
        """Return the entry stored under ``key``, or ``None`` when there is none."""
        return self._object_map.get(key)

    def delete(self, key: str):
        """Remove ``key`` from the map; unknown keys are ignored."""
        self._object_map.pop(key, None)

    def keys(self):
        """Return a list snapshot of the keys currently stored.

        A list (rather than a live ``dict_keys`` view) is returned so the
        result is detached from later mutations of the map and is a plain,
        serializable value when the method is invoked through the actor.
        """
        return list(self._object_map)

    def ping(self):
        """Return ``"pong"``; used by :meth:`get_or_create` as a readiness probe."""
        return "pong"

    @staticmethod
    def get_or_create():
        """Return a handle to the shared ``RayObjectMap`` actor, creating it if needed.

        The actor is looked up (or created) by name and namespace with
        ``get_if_exists=True`` and ``lifetime="detached"``. The call blocks
        until the actor has answered a ``ping``.
        """
        import ray

        actor = (
            ray.remote(RayObjectMap)
            .options(  # type: ignore
                name=DAGSTER_RAY_OBJECT_MAP_NAME,
                namespace=DAGSTER_RAY_NAMESPACE,
                get_if_exists=True,
                lifetime="detached",
                # max_restarts=-1,
                max_concurrency=1000,  # TODO: make this configurable,
                # environment variables must be nested under "env_vars";
                # top-level runtime_env keys are reserved for Ray's own fields
                runtime_env={"env_vars": {"RAY_ENABLE_RECORD_ACTOR_TASK_LOGGING": "1"}},
            )
            .remote()
        )

        # make sure the actor is created
        ray.get(actor.ping.remote())

        return actor
|
||
|
||
class RayIOManager(ConfigurableIOManager, ConfigurableResource):
    """Dagster IO manager that passes step outputs through Ray.

    Outputs are placed in Ray with ``ray.put`` (owned by the shared
    ``RayObjectMap`` actor) and registered in that actor under a key derived
    from the Dagster context identifier. Inputs are looked up in the same
    actor by the same key scheme.
    """

    # presumably the Ray cluster address; currently unused because the
    # ray.init calls that would consume it are commented out below
    address: Optional[str] = None

    def handle_output(self, context: OutputContext, obj):
        """Put ``obj`` into Ray and register it in the object map under the context's key."""
        import ray

        # if self.address:  # TODO: should this really be done here?
        #     ray.init(self.address, ignore_reinit_error=True)

        object_map = RayObjectMap.get_or_create()
        storage_key = self._get_single_key(context)

        # TODO: understand if Ray will automatically move the object from dying nodes
        # what if not?
        ref = ray.put(obj, _owner=object_map)

        # NOTE(review): Ray resolves ObjectRefs passed as top-level task
        # arguments, so the actor presumably receives the value rather than
        # the ref itself - confirm that this is intended.
        #
        # Wait for the actor to acknowledge the write: a fire-and-forget call
        # would hide registration failures and would not guarantee the entry
        # is visible once this step is reported as finished.
        ray.get(object_map.set.remote(storage_key, ref))

        context.log.debug(f"Stored object with key {storage_key} as {ref}")

    def load_input(self, context: InputContext):
        """Load the upstream value(s) for ``context`` from the object map.

        Returns:
            A dict mapping partition key to loaded value when the input spans
            more than one asset partition, otherwise the single loaded value.

        Raises:
            RuntimeError: if a single-key lookup finds no entry for the key.
        """
        import ray

        # if self.address:  # TODO: should this really be done here?
        #     ray.init(self.address, ignore_reinit_error=True)

        object_map = RayObjectMap.get_or_create()

        if context.has_asset_partitions and len(context.asset_partition_keys) > 1:
            # load multiple partitions at once:
            # issue every lookup first, then wait for all of them together
            storage_keys = self._get_multiple_keys(context)
            refs = [object_map.get.remote(key) for key in storage_keys.values()]
            values = ray.get(refs)
            return dict(zip(storage_keys, values))

        storage_key = self._get_single_key(context)

        context.log.debug(f"Loading object with key {storage_key}")

        value = ray.get(object_map.get.remote(storage_key))

        # The handle returned by ``.remote()`` is never None, so absence can
        # only be detected on the fetched result. A stored None is legitimate,
        # hence the membership check before reporting the key as missing.
        if value is None and storage_key not in ray.get(object_map.keys.remote()):
            raise RuntimeError(f"Object with key {storage_key} not found in RayObjectMap")

        return value

    def _get_single_key(self, context: Union[InputContext, OutputContext]) -> str:
        """Build the storage key by joining the context's identifier parts with ``/``."""
        if context.has_asset_key:
            identifier = context.get_asset_identifier()
        else:
            identifier = context.get_identifier()
        return "/".join(identifier)

    def _get_multiple_keys(self, context: InputContext) -> Dict[str, str]:
        """Map each asset partition key of ``context`` to its storage key.

        Raises:
            RuntimeError: if the context has no asset key.
        """
        if not context.has_asset_key:
            raise RuntimeError("This method can only be called with an InputContext that has multiple partitions")

        asset_path = list(context.asset_key.path)
        return {
            partition_key: "/".join(asset_path + [partition_key]) for partition_key in context.asset_partition_keys
        }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from typing import Dict | ||
|
||
from dagster import AssetExecutionContext, StaticPartitionsDefinition, asset, materialize | ||
|
||
from dagster_ray import RayIOManager | ||
|
||
|
||
def test_ray_io_manager():
    """Materialize a two-asset chain with RayIOManager and check the value passed downstream."""

    @asset
    def upstream():
        return 1

    @asset
    def downstream(upstream) -> None:
        assert upstream == 1

    assets = [upstream, downstream]
    resources = {"io_manager": RayIOManager()}

    materialize(assets, resources=resources)
|
||
|
||
def test_ray_io_manager_partitioned():
    """Materialize partitioned upstream/downstream assets one partition at a time.

    Each downstream partition asserts that it received the lower-cased key of
    its own partition from the upstream asset.
    """
    partitions_def = StaticPartitionsDefinition(partition_keys=["A", "B", "C"])

    @asset(partitions_def=partitions_def)
    def upsteram_partitioned(context: AssetExecutionContext) -> str:
        return context.partition_key.lower()

    @asset(partitions_def=partitions_def)
    def downstream_partitioned(context: AssetExecutionContext, upsteram_partitioned: str) -> None:
        assert upsteram_partitioned == context.partition_key.lower()

    assets = [upsteram_partitioned, downstream_partitioned]

    # a fresh IO manager instance is built for every run
    for key in ("A", "B", "C"):
        materialize(assets, resources={"io_manager": RayIOManager()}, partition_key=key)
|
||
|
||
def test_ray_io_manager_partition_mapping():
    """Load every partition of an upstream asset into one non-partitioned downstream asset.

    The upstream partitions are materialized first, one run per partition;
    the downstream asset then expects a dict keyed by partition key.
    """
    partitions_def = StaticPartitionsDefinition(partition_keys=["A", "B", "C"])

    @asset(partitions_def=partitions_def)
    def upsteram_partitioned(context: AssetExecutionContext) -> str:
        return context.partition_key.lower()

    @asset
    def downstream_non_partitioned(upsteram_partitioned: Dict[str, str]) -> None:
        assert upsteram_partitioned == {"A": "a", "B": "b", "C": "c"}

    # step 1: populate every upstream partition
    for key in ("A", "B", "C"):
        materialize(
            [upsteram_partitioned],
            resources={"io_manager": RayIOManager()},
            partition_key=key,
        )

    # step 2: run only the downstream asset against the already stored partitions
    downstream_assets = [upsteram_partitioned.to_source_asset(), downstream_non_partitioned]
    materialize(downstream_assets, resources={"io_manager": RayIOManager()})