Skip to content

Commit

Permalink
reverting framework-level executor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgafni committed Dec 17, 2024
1 parent e77addc commit 3d4f650
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 62 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,6 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
instance_concurrency_context=instance_concurrency_context,
) as active_execution:
running_steps: Dict[str, ExecutionStep] = {}
step_worker_handles: Dict[str, Optional[str]] = {}

if plan_context.resume_from_failure:
DagsterEvent.engine_event(
Expand Down Expand Up @@ -212,8 +211,7 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut

try:
health_check = self._step_handler.check_step_health(
step_handler_context,
step_worker_handle=None,
step_handler_context
)
except Exception:
# For now we assume that an exception indicates that the step should be resumed.
Expand All @@ -239,14 +237,15 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut

if should_retry_step:
# health check failed, launch the step
self._step_handler.launch_step(
self._get_step_handler_context(
plan_context, [step], active_execution
list(
self._step_handler.launch_step(
self._get_step_handler_context(
plan_context, [step], active_execution
)
)
)

running_steps[step.key] = step
step_worker_handles[step.key] = None

last_check_step_health_time = get_current_datetime()

Expand All @@ -263,12 +262,13 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
"Executor received termination signal, forwarding to steps",
EngineEventData.interrupted(list(running_steps.keys())),
)
for step_key, step in running_steps.items():
self._step_handler.terminate_step(
self._get_step_handler_context(
plan_context, [step], active_execution
),
step_worker_handle=step_worker_handles[step_key],
for step in running_steps.values():
list(
self._step_handler.terminate_step(
self._get_step_handler_context(
plan_context, [step], active_execution
)
)
)
else:
DagsterEvent.engine_event(
Expand Down Expand Up @@ -311,7 +311,6 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
):
assert isinstance(dagster_event.step_key, str)
del running_steps[dagster_event.step_key]
del step_worker_handles[dagster_event.step_key]

if not dagster_event.is_step_up_for_retry:
active_execution.verify_complete(
Expand All @@ -326,15 +325,14 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
curr_time - last_check_step_health_time
).total_seconds() >= self._check_step_health_interval_seconds:
last_check_step_health_time = curr_time
for step_key, step in running_steps.items():
for step in running_steps.values():
step_context = plan_context.for_step(step)

try:
health_check_result = self._step_handler.check_step_health(
self._get_step_handler_context(
plan_context, [step], active_execution
),
step_worker_handle=step_worker_handles[step_key],
)
)
if not health_check_result.is_healthy:
health_check_error = SerializableErrorInfo(
Expand Down Expand Up @@ -376,9 +374,11 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut

for step in active_execution.get_steps_to_execute(max_steps_to_run):
running_steps[step.key] = step
step_worker_handles[step.key] = self._step_handler.launch_step(
self._get_step_handler_context(
plan_context, [step], active_execution
list(
self._step_handler.launch_step(
self._get_step_handler_context(
plan_context, [step], active_execution
)
)
)

Expand All @@ -398,11 +398,12 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
error=serializable_error,
),
)
for step_key, step in running_steps.items():
self._step_handler.terminate_step(
self._get_step_handler_context(
plan_context, [step], active_execution
),
step_worker_handle=step_worker_handles[step_key],
for step in running_steps.values():
list(
self._step_handler.terminate_step(
self._get_step_handler_context(
plan_context, [step], active_execution
)
)
)
raise
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from abc import ABC, abstractmethod
from typing import Mapping, NamedTuple, Optional, Sequence
from typing import Iterator, Mapping, NamedTuple, Optional, Sequence

from dagster import (
DagsterInstance,
_check as check,
)
from dagster._core.events import DagsterEvent
from dagster._core.execution.context.system import IStepContext, PlanOrchestrationContext
from dagster._core.execution.plan.step import ExecutionStep
from dagster._core.storage.dagster_run import DagsterRun
Expand Down Expand Up @@ -82,17 +83,13 @@ def name(self) -> str:
pass

@abstractmethod
def launch_step(self, step_handler_context: StepHandlerContext) -> Optional[str]:
def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
pass

@abstractmethod
def check_step_health(
self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
) -> CheckStepHealthResult:
def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult:
pass

@abstractmethod
def terminate_step(
self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
) -> None:
def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
pass
40 changes: 30 additions & 10 deletions python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
import os
from typing import Any, Dict, List, Mapping, Optional, Sequence, cast
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, cast

import boto3
from dagster import (
Expand Down Expand Up @@ -159,6 +159,11 @@ def __init__(
"taskDefinition": current_task["taskDefinitionArn"],
}

# TODO: change launch_step to return task ARN
# this will be a breaking change so we need to wait fore a minor release
# to do this
self._launched_tasks = {}

def _get_run_task_kwargs(
self,
run: DagsterRun,
Expand Down Expand Up @@ -187,6 +192,7 @@ def _get_run_task_kwargs(
for k, v in step_handler_context.dagster_run.dagster_execution_info.items()
],
{"key": "dagster/step-key", "value": step_key},
{"key": "dagster/step-id", "value": self._get_step_id(step_handler_context)},
]

if run.remote_job_origin:
Expand Down Expand Up @@ -263,6 +269,19 @@ def _get_task_overrides(self, step_tags: Mapping[str, str]) -> Dict[str, Any]:

return overrides

def _get_step_id(self, step_handler_context: StepHandlerContext):
"""Step ID is used to identify the ECS task in the ECS cluster. It is unique to specific step being executed and takes into account op-level retries. It's used as a workardoung to avoid having to return task ARN from launch_step."""
step_key = self._get_step_key(step_handler_context)

if step_handler_context.execute_step_args.known_state:
retry_count = step_handler_context.execute_step_args.known_state.get_retry_state().get_attempt_count(
step_key
)
else:
retry_count = 0

return "%s-%d" % (step_key, retry_count)

def _get_step_key(self, step_handler_context: StepHandlerContext) -> str:
step_keys_to_execute = cast(
List[str], step_handler_context.execute_step_args.step_keys_to_execute
Expand All @@ -281,7 +300,7 @@ def _get_container_context(
def _run_task(self, **run_task_kwargs):
return run_ecs_task(self.ecs, run_task_kwargs)

def launch_step(self, step_handler_context: StepHandlerContext) -> Optional[str]:
def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
step_key = self._get_step_key(step_handler_context)

step_tags = step_handler_context.step_tags[step_key]
Expand Down Expand Up @@ -312,22 +331,22 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> Optional[str]
),
)

DagsterEvent.step_worker_starting(
yield DagsterEvent.step_worker_starting(
step_handler_context.get_step_context(step_key),
message=f'Executing step "{step_key}" in ECS task.',
metadata={
"Task ARN": MetadataValue.text(task["taskArn"]),
},
)

return task["taskArn"]
step_id = self._get_step_id(step_handler_context)

self._launched_tasks[step_id] = task["taskArn"]

def check_step_health(
self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
) -> CheckStepHealthResult:
def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult:
step_key = self._get_step_key(step_handler_context)

task_arn = step_worker_handle
task_arn = self._launched_tasks.get(self._get_step_id(step_handler_context))
cluster_arn = self._cluster_arn

tasks = self.ecs.describe_tasks(tasks=[task_arn], cluster=cluster_arn).get("tasks")
Expand Down Expand Up @@ -359,9 +378,10 @@ def check_step_health(
return CheckStepHealthResult.healthy()

def terminate_step(
self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
self,
step_handler_context: StepHandlerContext,
) -> None:
task_arn = step_worker_handle
task_arn = self._launched_tasks.pop(self._get_step_id(step_handler_context))
cluster_arn = self._cluster_arn
step_key = self._get_step_key(step_handler_context)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def test_executor_launch(instance_cm: Callable[..., ContextManager[DagsterInstan
return_value={"tasks": [{"taskArn": "arn:123"}]}
)

executor._step_handler.launch_step(step_handler_context) # noqa: SLF001
next(iter(executor._step_handler.launch_step(step_handler_context))) # noqa: SLF001

run_task_kwargs = executor._step_handler.ecs.run_task.call_args[1] # noqa: SLF001

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, cast
from typing import Iterator, List, Optional, cast

import dagster._check as check
import docker
Expand Down Expand Up @@ -222,7 +222,7 @@ def _create_step_container(
**container_kwargs,
)

def launch_step(self, step_handler_context: StepHandlerContext) -> None:
def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
container_context = self._get_docker_container_context(step_handler_context)

client = self._get_client(container_context)
Expand Down Expand Up @@ -251,7 +251,7 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> None:
assert len(step_keys_to_execute) == 1, "Launching multiple steps is not currently supported"
step_key = step_keys_to_execute[0]

DagsterEvent.step_worker_starting(
yield DagsterEvent.step_worker_starting(
step_handler_context.get_step_context(step_key),
message="Launching step in Docker container.",
metadata={
Expand Down Expand Up @@ -294,7 +294,7 @@ def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckSt
reason=f"Container status is {container.status}. Return code is {ret_code}."
)

def terminate_step(self, step_handler_context: StepHandlerContext) -> None:
def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
container_context = self._get_docker_container_context(step_handler_context)

step_keys_to_execute = check.not_none(
Expand All @@ -307,7 +307,7 @@ def terminate_step(self, step_handler_context: StepHandlerContext) -> None:

container_name = self._get_container_name(step_handler_context)

DagsterEvent.engine_event(
yield DagsterEvent.engine_event(
step_handler_context.get_step_context(step_key),
message=f"Stopping Docker container {container_name} for step.",
event_specific_data=EngineEventData(),
Expand Down
18 changes: 6 additions & 12 deletions python_modules/libraries/dagster-k8s/dagster_k8s/executor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import List, Optional, cast
from typing import Iterator, List, Optional, cast

import kubernetes.config
from dagster import (
Expand Down Expand Up @@ -261,7 +261,7 @@ def _get_k8s_step_job_name(self, step_handler_context: StepHandlerContext):

return "dagster-step-%s" % (name_key)

def launch_step(self, step_handler_context: StepHandlerContext) -> None:
def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
step_key = self._get_step_key(step_handler_context)

job_name = self._get_k8s_step_job_name(step_handler_context)
Expand Down Expand Up @@ -313,7 +313,7 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> None:
],
)

DagsterEvent.step_worker_starting(
yield DagsterEvent.step_worker_starting(
step_handler_context.get_step_context(step_key),
message=f'Executing step "{step_key}" in Kubernetes job {job_name}.',
metadata={
Expand All @@ -324,11 +324,7 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> None:
namespace = check.not_none(container_context.namespace)
self._api_client.create_namespaced_job_with_retries(body=job, namespace=namespace)

return None

def check_step_health(
self, step_handler_context: StepHandlerContext, step_identifier: Optional[str]
) -> CheckStepHealthResult:
def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult:
step_key = self._get_step_key(step_handler_context)

job_name = self._get_k8s_step_job_name(step_handler_context)
Expand All @@ -350,15 +346,13 @@ def check_step_health(

return CheckStepHealthResult.healthy()

def terminate_step(
self, step_handler_context: StepHandlerContext, step_identifier: str
) -> None:
def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
step_key = self._get_step_key(step_handler_context)

job_name = self._get_k8s_step_job_name(step_handler_context)
container_context = self._get_container_context(step_handler_context)

DagsterEvent.engine_event(
yield DagsterEvent.engine_event(
step_handler_context.get_step_context(step_key),
message=f"Deleting Kubernetes job {job_name} for step",
event_specific_data=EngineEventData(),
Expand Down

0 comments on commit 3d4f650

Please sign in to comment.