From 3d4f6507a9626ff8326bef51c03ca88698a1a654 Mon Sep 17 00:00:00 2001 From: Daniel Gafni Date: Tue, 17 Dec 2024 16:51:56 +0000 Subject: [PATCH] reverting framework-level executor changes --- .../step_delegating_executor.py | 53 ++++++++++--------- .../step_delegating/step_handler/base.py | 13 ++--- .../dagster-aws/dagster_aws/ecs/executor.py | 40 ++++++++++---- .../executor_tests/test_executor.py | 2 +- .../dagster_docker/docker_executor.py | 10 ++-- .../dagster-k8s/dagster_k8s/executor.py | 18 +++---- 6 files changed, 74 insertions(+), 62 deletions(-) diff --git a/python_modules/dagster/dagster/_core/executor/step_delegating/step_delegating_executor.py b/python_modules/dagster/dagster/_core/executor/step_delegating/step_delegating_executor.py index 9fad7f5a620f6..263f20eb1cc72 100644 --- a/python_modules/dagster/dagster/_core/executor/step_delegating/step_delegating_executor.py +++ b/python_modules/dagster/dagster/_core/executor/step_delegating/step_delegating_executor.py @@ -178,7 +178,6 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut instance_concurrency_context=instance_concurrency_context, ) as active_execution: running_steps: Dict[str, ExecutionStep] = {} - step_worker_handles: Dict[str, Optional[str]] = {} if plan_context.resume_from_failure: DagsterEvent.engine_event( @@ -212,8 +211,7 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut try: health_check = self._step_handler.check_step_health( - step_handler_context, - step_worker_handle=None, + step_handler_context ) except Exception: # For now we assume that an exception indicates that the step should be resumed. @@ -239,14 +237,15 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut if should_retry_step: # health check failed, launch the step - self._step_handler.launch_step( - self._get_step_handler_context( - plan_context, [step], active_execution + list( + self._step_handler.launch_step( + self._get_step_handler_context( + plan_context, [step], active_execution + ) ) ) running_steps[step.key] = step - step_worker_handles[step.key] = None last_check_step_health_time = get_current_datetime() @@ -263,12 +262,13 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut "Executor received termination signal, forwarding to steps", EngineEventData.interrupted(list(running_steps.keys())), ) - for step_key, step in running_steps.items(): - self._step_handler.terminate_step( - self._get_step_handler_context( - plan_context, [step], active_execution - ), - step_worker_handle=step_worker_handles[step_key], + for step in running_steps.values(): + list( + self._step_handler.terminate_step( + self._get_step_handler_context( + plan_context, [step], active_execution + ) + ) ) else: DagsterEvent.engine_event( @@ -311,7 +311,6 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut ): assert isinstance(dagster_event.step_key, str) del running_steps[dagster_event.step_key] - del step_worker_handles[dagster_event.step_key] if not dagster_event.is_step_up_for_retry: active_execution.verify_complete( @@ -326,15 +325,14 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut curr_time - last_check_step_health_time ).total_seconds() >= self._check_step_health_interval_seconds: last_check_step_health_time = curr_time - for step_key, step in running_steps.items(): + for step in running_steps.values(): step_context = plan_context.for_step(step) try: health_check_result = 
self._step_handler.check_step_health( self._get_step_handler_context( plan_context, [step], active_execution - ), - step_worker_handle=step_worker_handles[step_key], + ) ) if not health_check_result.is_healthy: health_check_error = SerializableErrorInfo( @@ -376,9 +374,11 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut for step in active_execution.get_steps_to_execute(max_steps_to_run): running_steps[step.key] = step - step_worker_handles[step.key] = self._step_handler.launch_step( - self._get_step_handler_context( - plan_context, [step], active_execution + list( + self._step_handler.launch_step( + self._get_step_handler_context( + plan_context, [step], active_execution + ) ) ) @@ -398,11 +398,12 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut error=serializable_error, ), ) - for step_key, step in running_steps.items(): - self._step_handler.terminate_step( - self._get_step_handler_context( - plan_context, [step], active_execution - ), - step_worker_handle=step_worker_handles[step_key], + for step in running_steps.values(): + list( + self._step_handler.terminate_step( + self._get_step_handler_context( + plan_context, [step], active_execution + ) + ) ) raise diff --git a/python_modules/dagster/dagster/_core/executor/step_delegating/step_handler/base.py b/python_modules/dagster/dagster/_core/executor/step_delegating/step_handler/base.py index ed12a386869ea..a40fc76408e25 100644 --- a/python_modules/dagster/dagster/_core/executor/step_delegating/step_handler/base.py +++ b/python_modules/dagster/dagster/_core/executor/step_delegating/step_handler/base.py @@ -1,10 +1,11 @@ from abc import ABC, abstractmethod -from typing import Mapping, NamedTuple, Optional, Sequence +from typing import Iterator, Mapping, NamedTuple, Optional, Sequence from dagster import ( DagsterInstance, _check as check, ) +from dagster._core.events import DagsterEvent from dagster._core.execution.context.system import IStepContext, PlanOrchestrationContext from dagster._core.execution.plan.step import ExecutionStep from dagster._core.storage.dagster_run import DagsterRun @@ -82,17 +83,13 @@ def name(self) -> str: pass @abstractmethod - def launch_step(self, step_handler_context: StepHandlerContext) -> Optional[str]: + def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: pass @abstractmethod - def check_step_health( - self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str] - ) -> CheckStepHealthResult: + def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult: pass @abstractmethod - def terminate_step( - self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str] - ) -> None: + def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: pass diff --git a/python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py b/python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py index 8911e9009e66a..dc37850f9b46f 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py @@ -1,6 +1,6 @@ import json import os -from typing import Any, Dict, List, Mapping, Optional, Sequence, cast +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, cast import boto3 from dagster import ( @@ -159,6 +159,11 @@ def __init__( "taskDefinition": current_task["taskDefinitionArn"], } + # TODO: change 
launch_step to return task ARN
+        # this will be a breaking change so we need to wait for a minor release
+        # to do this
+        self._launched_tasks = {}
+
     def _get_run_task_kwargs(
         self,
         run: DagsterRun,
@@ -187,6 +192,7 @@ def _get_run_task_kwargs(
                 for k, v in step_handler_context.dagster_run.dagster_execution_info.items()
             ],
             {"key": "dagster/step-key", "value": step_key},
+            {"key": "dagster/step-id", "value": self._get_step_id(step_handler_context)},
         ]
 
         if run.remote_job_origin:
@@ -263,6 +269,19 @@ def _get_task_overrides(self, step_tags: Mapping[str, str]) -> Dict[str, Any]:
 
         return overrides
 
+    def _get_step_id(self, step_handler_context: StepHandlerContext):
+        """The step ID is used to identify the ECS task in the ECS cluster. It is unique to the specific step being executed and takes into account op-level retries. It's used as a workaround to avoid having to return the task ARN from launch_step."""
+        step_key = self._get_step_key(step_handler_context)
+
+        if step_handler_context.execute_step_args.known_state:
+            retry_count = step_handler_context.execute_step_args.known_state.get_retry_state().get_attempt_count(
+                step_key
+            )
+        else:
+            retry_count = 0
+
+        return "%s-%d" % (step_key, retry_count)
+
     def _get_step_key(self, step_handler_context: StepHandlerContext) -> str:
         step_keys_to_execute = cast(
             List[str], step_handler_context.execute_step_args.step_keys_to_execute
@@ -281,7 +300,7 @@ def _get_container_context(
     def _run_task(self, **run_task_kwargs):
         return run_ecs_task(self.ecs, run_task_kwargs)
 
-    def launch_step(self, step_handler_context: StepHandlerContext) -> Optional[str]:
+    def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
         step_key = self._get_step_key(step_handler_context)
 
         step_tags = step_handler_context.step_tags[step_key]
@@ -312,7 +331,7 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> Optional[str]
             ),
         )
 
-        DagsterEvent.step_worker_starting(
+        yield DagsterEvent.step_worker_starting(
             step_handler_context.get_step_context(step_key),
             message=f'Executing step "{step_key}" in ECS task.',
             metadata={
@@ -320,14 +339,14 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> Optional[str]
             },
         )
 
-        return task["taskArn"]
+        step_id = self._get_step_id(step_handler_context)
+
+        self._launched_tasks[step_id] = task["taskArn"]
 
-    def check_step_health(
-        self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
-    ) -> CheckStepHealthResult:
+    def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult:
         step_key = self._get_step_key(step_handler_context)
 
-        task_arn = step_worker_handle
+        task_arn = self._launched_tasks.get(self._get_step_id(step_handler_context))
         cluster_arn = self._cluster_arn
 
         tasks = self.ecs.describe_tasks(tasks=[task_arn], cluster=cluster_arn).get("tasks")
@@ -359,9 +378,10 @@ def check_step_health(
         return CheckStepHealthResult.healthy()
 
     def terminate_step(
-        self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
+        self,
+        step_handler_context: StepHandlerContext,
     ) -> None:
-        task_arn = step_worker_handle
+        task_arn = self._launched_tasks.pop(self._get_step_id(step_handler_context))
         cluster_arn = self._cluster_arn
 
         step_key = self._get_step_key(step_handler_context)
diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/executor_tests/test_executor.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/executor_tests/test_executor.py
index 
df13f1bed8ae6..ee20ea57a2e90 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/executor_tests/test_executor.py +++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/executor_tests/test_executor.py @@ -200,7 +200,7 @@ def test_executor_launch(instance_cm: Callable[..., ContextManager[DagsterInstan return_value={"tasks": [{"taskArn": "arn:123"}]} ) - executor._step_handler.launch_step(step_handler_context) # noqa: SLF001 + next(iter(executor._step_handler.launch_step(step_handler_context))) # noqa: SLF001 run_task_kwargs = executor._step_handler.ecs.run_task.call_args[1] # noqa: SLF001 diff --git a/python_modules/libraries/dagster-docker/dagster_docker/docker_executor.py b/python_modules/libraries/dagster-docker/dagster_docker/docker_executor.py index 28bcd061b9363..ab60303b0f8a7 100644 --- a/python_modules/libraries/dagster-docker/dagster_docker/docker_executor.py +++ b/python_modules/libraries/dagster-docker/dagster_docker/docker_executor.py @@ -1,4 +1,4 @@ -from typing import List, Optional, cast +from typing import Iterator, List, Optional, cast import dagster._check as check import docker @@ -222,7 +222,7 @@ def _create_step_container( **container_kwargs, ) - def launch_step(self, step_handler_context: StepHandlerContext) -> None: + def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: container_context = self._get_docker_container_context(step_handler_context) client = self._get_client(container_context) @@ -251,7 +251,7 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> None: assert len(step_keys_to_execute) == 1, "Launching multiple steps is not currently supported" step_key = step_keys_to_execute[0] - DagsterEvent.step_worker_starting( + yield DagsterEvent.step_worker_starting( step_handler_context.get_step_context(step_key), message="Launching step in Docker container.", metadata={ @@ -294,7 +294,7 @@ def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckSt reason=f"Container status is {container.status}. Return code is {ret_code}." 
) - def terminate_step(self, step_handler_context: StepHandlerContext) -> None: + def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: container_context = self._get_docker_container_context(step_handler_context) step_keys_to_execute = check.not_none( @@ -307,7 +307,7 @@ def terminate_step(self, step_handler_context: StepHandlerContext) -> None: container_name = self._get_container_name(step_handler_context) - DagsterEvent.engine_event( + yield DagsterEvent.engine_event( step_handler_context.get_step_context(step_key), message=f"Stopping Docker container {container_name} for step.", event_specific_data=EngineEventData(), diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py b/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py index 1fe9ce5398a72..daffa21caadb1 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py @@ -1,4 +1,4 @@ -from typing import List, Optional, cast +from typing import Iterator, List, Optional, cast import kubernetes.config from dagster import ( @@ -261,7 +261,7 @@ def _get_k8s_step_job_name(self, step_handler_context: StepHandlerContext): return "dagster-step-%s" % (name_key) - def launch_step(self, step_handler_context: StepHandlerContext) -> None: + def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: step_key = self._get_step_key(step_handler_context) job_name = self._get_k8s_step_job_name(step_handler_context) @@ -313,7 +313,7 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> None: ], ) - DagsterEvent.step_worker_starting( + yield DagsterEvent.step_worker_starting( step_handler_context.get_step_context(step_key), message=f'Executing step "{step_key}" in Kubernetes job {job_name}.', metadata={ @@ -324,11 +324,7 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> None: namespace = check.not_none(container_context.namespace) self._api_client.create_namespaced_job_with_retries(body=job, namespace=namespace) - return None - - def check_step_health( - self, step_handler_context: StepHandlerContext, step_identifier: Optional[str] - ) -> CheckStepHealthResult: + def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult: step_key = self._get_step_key(step_handler_context) job_name = self._get_k8s_step_job_name(step_handler_context) @@ -350,15 +346,13 @@ def check_step_health( return CheckStepHealthResult.healthy() - def terminate_step( - self, step_handler_context: StepHandlerContext, step_identifier: str - ) -> None: + def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: step_key = self._get_step_key(step_handler_context) job_name = self._get_k8s_step_job_name(step_handler_context) container_context = self._get_container_context(step_handler_context) - DagsterEvent.engine_event( + yield DagsterEvent.engine_event( step_handler_context.get_step_context(step_key), message=f"Deleting Kubernetes job {job_name} for step", event_specific_data=EngineEventData(),
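
Note on the restored interface: after this revert, launch_step and terminate_step in step_handler/base.py once again yield DagsterEvents directly, and check_step_health takes only the StepHandlerContext, with no step_worker_handle argument. The sketch below shows a custom handler written against that restored interface. It is illustrative only: the InMemoryStepHandler class, its _launched dict, and the _single_step_key helper are hypothetical and not part of this patch; the imports and the DagsterEvent/CheckStepHealthResult calls mirror the ones used by the executors above (adjust the import path to your Dagster version if needed).

# Minimal sketch against the restored StepHandler interface; names below are
# illustrative assumptions, not code from this patch.
import uuid
from typing import Dict, Iterator

from dagster._core.events import DagsterEvent, EngineEventData
from dagster._core.executor.step_delegating import (
    CheckStepHealthResult,
    StepHandler,
    StepHandlerContext,
)


class InMemoryStepHandler(StepHandler):
    """Hypothetical handler that keeps per-step state on the handler itself,
    since launch_step no longer returns a handle."""

    def __init__(self) -> None:
        super().__init__()
        self._launched: Dict[str, str] = {}  # step key -> fake worker id

    @property
    def name(self) -> str:
        return "InMemoryStepHandler"

    def _single_step_key(self, step_handler_context: StepHandlerContext) -> str:
        step_keys = step_handler_context.execute_step_args.step_keys_to_execute
        assert step_keys and len(step_keys) == 1, "This sketch assumes one step per launch"
        return step_keys[0]

    def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
        step_key = self._single_step_key(step_handler_context)
        # Record the "worker" locally instead of returning a handle.
        self._launched[step_key] = str(uuid.uuid4())
        yield DagsterEvent.step_worker_starting(
            step_handler_context.get_step_context(step_key),
            message=f'Pretending to launch step "{step_key}".',
            metadata={},
        )

    def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult:
        step_key = self._single_step_key(step_handler_context)
        if step_key not in self._launched:
            return CheckStepHealthResult.unhealthy(reason=f"Step {step_key} was never launched.")
        return CheckStepHealthResult.healthy()

    def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
        step_key = self._single_step_key(step_handler_context)
        self._launched.pop(step_key, None)
        yield DagsterEvent.engine_event(
            step_handler_context.get_step_context(step_key),
            message=f"Stopping in-memory worker for step {step_key}.",
            event_specific_data=EngineEventData(),
        )

The _launched dict plays the same role as the _launched_tasks mapping the ECS executor adopts in this patch: because launch_step no longer returns a handle, any per-step state needed later for health checks or termination has to live on the handler instance.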