Skip to content

Commit

Permalink
reverting framework-level executor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgafni committed Dec 17, 2024
1 parent e77addc commit 15a6977
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
instance_concurrency_context=instance_concurrency_context,
) as active_execution:
running_steps: Dict[str, ExecutionStep] = {}
step_worker_handles: Dict[str, Optional[str]] = {}

if plan_context.resume_from_failure:
DagsterEvent.engine_event(
Expand Down Expand Up @@ -212,8 +211,7 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut

try:
health_check = self._step_handler.check_step_health(
step_handler_context,
step_worker_handle=None,
step_handler_context
)
except Exception:
# For now we assume that an exception indicates that the step should be resumed.
Expand All @@ -239,14 +237,15 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut

if should_retry_step:
# health check failed, launch the step
self._step_handler.launch_step(
self._get_step_handler_context(
plan_context, [step], active_execution
list(
self._step_handler.launch_step(
self._get_step_handler_context(
plan_context, [step], active_execution
)
)
)

running_steps[step.key] = step
step_worker_handles[step.key] = None

last_check_step_health_time = get_current_datetime()

Expand All @@ -263,12 +262,13 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
"Executor received termination signal, forwarding to steps",
EngineEventData.interrupted(list(running_steps.keys())),
)
for step_key, step in running_steps.items():
self._step_handler.terminate_step(
self._get_step_handler_context(
plan_context, [step], active_execution
),
step_worker_handle=step_worker_handles[step_key],
for step in running_steps.values():
list(
self._step_handler.terminate_step(
self._get_step_handler_context(
plan_context, [step], active_execution
)
)
)
else:
DagsterEvent.engine_event(
Expand Down Expand Up @@ -311,7 +311,6 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
):
assert isinstance(dagster_event.step_key, str)
del running_steps[dagster_event.step_key]
del step_worker_handles[dagster_event.step_key]

if not dagster_event.is_step_up_for_retry:
active_execution.verify_complete(
Expand All @@ -326,15 +325,14 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
curr_time - last_check_step_health_time
).total_seconds() >= self._check_step_health_interval_seconds:
last_check_step_health_time = curr_time
for step_key, step in running_steps.items():
for step in running_steps.values():
step_context = plan_context.for_step(step)

try:
health_check_result = self._step_handler.check_step_health(
self._get_step_handler_context(
plan_context, [step], active_execution
),
step_worker_handle=step_worker_handles[step_key],
)
)
if not health_check_result.is_healthy:
health_check_error = SerializableErrorInfo(
Expand Down Expand Up @@ -376,9 +374,11 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut

for step in active_execution.get_steps_to_execute(max_steps_to_run):
running_steps[step.key] = step
step_worker_handles[step.key] = self._step_handler.launch_step(
self._get_step_handler_context(
plan_context, [step], active_execution
list(
self._step_handler.launch_step(
self._get_step_handler_context(
plan_context, [step], active_execution
)
)
)

Expand All @@ -398,11 +398,12 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
error=serializable_error,
),
)
for step_key, step in running_steps.items():
self._step_handler.terminate_step(
self._get_step_handler_context(
plan_context, [step], active_execution
),
step_worker_handle=step_worker_handles[step_key],
for step in running_steps.values():
list(
self._step_handler.terminate_step(
self._get_step_handler_context(
plan_context, [step], active_execution
)
)
)
raise
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from abc import ABC, abstractmethod
from typing import Mapping, NamedTuple, Optional, Sequence
from typing import Iterator, Mapping, NamedTuple, Optional, Sequence

from dagster import (
DagsterInstance,
_check as check,
)
from dagster._core.events import DagsterEvent
from dagster._core.execution.context.system import IStepContext, PlanOrchestrationContext
from dagster._core.execution.plan.step import ExecutionStep
from dagster._core.storage.dagster_run import DagsterRun
Expand Down Expand Up @@ -82,17 +83,13 @@ def name(self) -> str:
pass

@abstractmethod
def launch_step(self, step_handler_context: StepHandlerContext) -> Optional[str]:
def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
pass

@abstractmethod
def check_step_health(
self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
) -> CheckStepHealthResult:
def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult:
pass

@abstractmethod
def terminate_step(
self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
) -> None:
def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
pass
108 changes: 92 additions & 16 deletions python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import os
from typing import Any, Dict, List, Mapping, Optional, Sequence, cast
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, cast

import boto3
from dagster import (
Expand Down Expand Up @@ -77,7 +77,41 @@
requirements=multiple_process_executor_requirements(),
)
def ecs_executor(init_context: InitExecutorContext) -> Executor:
"""Executor which launches steps as ECS tasks."""
"""Executor which launches steps as ECS tasks.
To use the `ecs_executor`, set it as the `executor_def` when defining a job:
.. literalinclude:: ../../../../../../python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/run_launcher_tests/executor_tests/test_example_executor_mode_def.py
:start-after: start_marker
:end-before: end_marker
:language: python
Then you can configure the executor with run config as follows:
.. code-block:: YAML
execution:
config:
cpu: 1024
memory: 2048
ephemeral_storage: 10
task_overrides:
containerOverrides:
- name: run
environment:
- name: MY_ENV_VAR
value: "my_value"
`max_concurrent` limits the number of ECS tasks that will execute concurrently for one run. By default
there is no limit - steps run with as much parallelism as the DAG allows. Note that this is not a
global limit.
Configuration set on the ECS tasks created by the `ECSRunLauncher` will also be
set on the tasks created by the `ecs_executor`.
Configuration set using `tags` on a `@job` will only apply to the `run` level. For configuration
to apply at each `step` it must be set using `tags` for each `@op`.
"""
run_launcher = init_context.instance.run_launcher

check.invariant(
Expand Down Expand Up @@ -151,14 +185,17 @@ def __init__(
self._cluster_arn = current_task["clusterArn"]
self._task_definition_arn = current_task["taskDefinitionArn"]

# note: more kwargs will be pulled from the run launcher
# during launch_step
self._run_task_kwargs = {
"taskDefinition": current_task["taskDefinitionArn"],
**run_launcher_kwargs,
**run_task_kwargs,
"taskDefinition": current_task["taskDefinitionArn"],
}

# TODO: change launch_step to return task ARN
# this will be a breaking change so we need to wait for a minor release
# to do this
self._launched_tasks = {}

def _get_run_task_kwargs(
self,
run: DagsterRun,
Expand All @@ -176,6 +213,12 @@ def _get_run_task_kwargs(

run_task_kwargs = self._run_task_kwargs

run_task_kwargs["tags"] = [
*run_task_kwargs.get("tags", []),
# add RunLauncher tags
*run_launcher.build_ecs_tags_for_run_task(run, container_context),
]

kwargs_from_tags = step_tags.get("ecs/run_task_kwargs")
if kwargs_from_tags:
run_task_kwargs = {**run_task_kwargs, **json.loads(kwargs_from_tags)}
Expand All @@ -187,6 +230,7 @@ def _get_run_task_kwargs(
for k, v in step_handler_context.dagster_run.dagster_execution_info.items()
],
{"key": "dagster/step-key", "value": step_key},
{"key": "dagster/step-id", "value": self._get_step_id(step_handler_context)},
]

if run.remote_job_origin:
Expand Down Expand Up @@ -263,6 +307,22 @@ def _get_task_overrides(self, step_tags: Mapping[str, str]) -> Dict[str, Any]:

return overrides

def _get_step_id(self, step_handler_context: StepHandlerContext) -> str:
    """Return the ID that identifies one attempt of the step being handled.

    The ID has the form ``"<step_key>-<attempt_count>"``. Because the attempt
    count is part of it, each op-level retry of a step gets a distinct ID.
    It is a workaround that avoids having to return the task ARN from
    ``launch_step``.

    Args:
        step_handler_context: Context describing the step being handled.

    Returns:
        The step key followed by ``-`` and the attempt count; the count is 0
        when the context carries no known state.
    """
    step_key = self._get_step_key(step_handler_context)

    known_state = step_handler_context.execute_step_args.known_state
    # No known state means no retry information is available: count it as attempt 0.
    retry_count = (
        known_state.get_retry_state().get_attempt_count(step_key) if known_state else 0
    )

    return f"{step_key}-{retry_count}"

def _get_step_key(self, step_handler_context: StepHandlerContext) -> str:
step_keys_to_execute = cast(
List[str], step_handler_context.execute_step_args.step_keys_to_execute
Expand All @@ -281,7 +341,7 @@ def _get_container_context(
def _run_task(self, **run_task_kwargs):
    """Call ``run_ecs_task`` with this handler's ECS client.

    The keyword arguments are collected and forwarded as a single dict;
    the return value of ``run_ecs_task`` is passed back unchanged.
    """
    ecs_client = self.ecs
    return run_ecs_task(ecs_client, run_task_kwargs)

def launch_step(self, step_handler_context: StepHandlerContext) -> Optional[str]:
def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
step_key = self._get_step_key(step_handler_context)

step_tags = step_handler_context.step_tags[step_key]
Expand Down Expand Up @@ -312,22 +372,29 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> Optional[str]
),
)

DagsterEvent.step_worker_starting(
yield DagsterEvent.step_worker_starting(
step_handler_context.get_step_context(step_key),
message=f'Executing step "{step_key}" in ECS task.',
metadata={
"Task ARN": MetadataValue.text(task["taskArn"]),
},
)

return task["taskArn"]
step_id = self._get_step_id(step_handler_context)

self._launched_tasks[step_id] = task["taskArn"]

def check_step_health(
self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
) -> CheckStepHealthResult:
def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult:
step_key = self._get_step_key(step_handler_context)
step_id = self._get_step_id(step_handler_context)

try:
task_arn = self._launched_tasks[step_id]
except KeyError:
return CheckStepHealthResult.unhealthy(
reason=f"Task ARN for step {step_key} could not be found in executor's task map. This is likely a bug."
)

task_arn = step_worker_handle
cluster_arn = self._cluster_arn

tasks = self.ecs.describe_tasks(tasks=[task_arn], cluster=cluster_arn).get("tasks")
Expand Down Expand Up @@ -359,15 +426,24 @@ def check_step_health(
return CheckStepHealthResult.healthy()

def terminate_step(
self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
self,
step_handler_context: StepHandlerContext,
) -> None:
task_arn = step_worker_handle
cluster_arn = self._cluster_arn
step_id = self._get_step_id(step_handler_context)
step_key = self._get_step_key(step_handler_context)

try:
task_arn = self._launched_tasks[step_id]
except KeyError:
raise DagsterInvariantViolationError(
f"Task ARN for step {step_key} could not be found in executor's task map. This is likely a bug."
)

cluster_arn = self._cluster_arn

DagsterEvent.engine_event(
step_handler_context.get_step_context(step_key),
message=f"Deleting task {task_arn} for step",
message=f"Stopping task {task_arn} for step",
event_specific_data=EngineEventData(),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def launch_run(self, context: LaunchRunContext) -> None:
command = self._get_command_args(args, context)
image = self.get_image_for_run(run)

run_task_kwargs = self.get_run_task_kwargs(run, image, container_context)
run_task_kwargs = self._run_task_kwargs(run, image, container_context)

# Set cpu or memory overrides
# https://docs.aws.amazon.com/AmazonECS/latest/developerguide/task-cpu-memory-error.html
Expand Down Expand Up @@ -627,7 +627,7 @@ def _get_run_task_definition_family(self, run: DagsterRun) -> str:
def get_container_name(self, container_context: EcsContainerContext) -> str:
return container_context.container_name or self.container_name

def get_run_task_kwargs(
def _run_task_kwargs(
self, run: DagsterRun, image: Optional[str], container_context: EcsContainerContext
) -> Dict[str, Any]:
"""Return a dictionary of args to launch the ECS task, registering a new task
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# ruff: isort: skip_file
# fmt: off
# start_marker
from dagster_aws.ecs import ecs_executor

from dagster import job, op


# Example op carrying op-level "ecs/cpu" and "ecs/memory" tags.
# NOTE(review): presumably the ECS executor reads these tags to size the
# step's task - confirm against the executor's tag handling.
@op(
    tags={"ecs/cpu": "256", "ecs/memory": "512"},
)
def ecs_op():
    pass


# Example job that sets `ecs_executor` as its `executor_def` and invokes `ecs_op`.
@job(executor_def=ecs_executor)
def ecs_job():
    ecs_op()


# end_marker
# fmt: on


def test_mode():
    # Smoke check only: the example job above was defined and is truthy.
    assert ecs_job
Loading

0 comments on commit 15a6977

Please sign in to comment.