diff --git a/python_modules/dagster/dagster/_core/executor/step_delegating/step_delegating_executor.py b/python_modules/dagster/dagster/_core/executor/step_delegating/step_delegating_executor.py
index 263f20eb1cc72..9fad7f5a620f6 100644
--- a/python_modules/dagster/dagster/_core/executor/step_delegating/step_delegating_executor.py
+++ b/python_modules/dagster/dagster/_core/executor/step_delegating/step_delegating_executor.py
@@ -178,6 +178,7 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
                 instance_concurrency_context=instance_concurrency_context,
             ) as active_execution:
                 running_steps: Dict[str, ExecutionStep] = {}
+                step_worker_handles: Dict[str, Optional[str]] = {}
 
                 if plan_context.resume_from_failure:
                     DagsterEvent.engine_event(
@@ -211,7 +212,8 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
 
                     try:
                         health_check = self._step_handler.check_step_health(
-                            step_handler_context
+                            step_handler_context,
+                            step_worker_handle=None,
                         )
                     except Exception:
                         # For now we assume that an exception indicates that the step should be resumed.
@@ -237,15 +239,15 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
 
                     if should_retry_step:
                         # health check failed, launch the step
-                        list(
-                            self._step_handler.launch_step(
-                                self._get_step_handler_context(
-                                    plan_context, [step], active_execution
-                                )
+                        step_worker_handles[step.key] = self._step_handler.launch_step(
+                            self._get_step_handler_context(
+                                plan_context, [step], active_execution
                             )
                         )
+                    else:
+                        step_worker_handles[step.key] = None
 
                     running_steps[step.key] = step
 
                 last_check_step_health_time = get_current_datetime()
@@ -262,13 +263,12 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
                             "Executor received termination signal, forwarding to steps",
                             EngineEventData.interrupted(list(running_steps.keys())),
                         )
-                        for step in running_steps.values():
-                            list(
-                                self._step_handler.terminate_step(
-                                    self._get_step_handler_context(
-                                        plan_context, [step], active_execution
-                                    )
-                                )
+                        for step_key, step in running_steps.items():
+                            self._step_handler.terminate_step(
+                                self._get_step_handler_context(
+                                    plan_context, [step], active_execution
+                                ),
+                                step_worker_handle=step_worker_handles[step_key],
                             )
                     else:
                         DagsterEvent.engine_event(
@@ -311,6 +311,7 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
                     ):
                         assert isinstance(dagster_event.step_key, str)
                         del running_steps[dagster_event.step_key]
+                        del step_worker_handles[dagster_event.step_key]
 
                         if not dagster_event.is_step_up_for_retry:
                             active_execution.verify_complete(
@@ -325,14 +326,15 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
                     curr_time - last_check_step_health_time
                 ).total_seconds() >= self._check_step_health_interval_seconds:
                     last_check_step_health_time = curr_time
-                    for step in running_steps.values():
+                    for step_key, step in running_steps.items():
                         step_context = plan_context.for_step(step)
 
                         try:
                             health_check_result = self._step_handler.check_step_health(
                                 self._get_step_handler_context(
                                     plan_context, [step], active_execution
-                                )
+                                ),
+                                step_worker_handle=step_worker_handles[step_key],
                             )
                             if not health_check_result.is_healthy:
                                 health_check_error = SerializableErrorInfo(
@@ -374,11 +376,9 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
 
                 for step in active_execution.get_steps_to_execute(max_steps_to_run):
                     running_steps[step.key] = step
-                    list(
-                        self._step_handler.launch_step(
-                            self._get_step_handler_context(
-                                plan_context, [step], active_execution
-                            )
+                    step_worker_handles[step.key] = self._step_handler.launch_step(
+                        self._get_step_handler_context(
+                            plan_context, [step], active_execution
                         )
                     )
@@ -398,12 +398,11 @@ def execute(self, plan_context: PlanOrchestrationContext, execution_plan: Execut
                         error=serializable_error,
                     ),
                 )
-                for step in running_steps.values():
-                    list(
-                        self._step_handler.terminate_step(
-                            self._get_step_handler_context(
-                                plan_context, [step], active_execution
-                            )
-                        )
+                for step_key, step in running_steps.items():
+                    self._step_handler.terminate_step(
+                        self._get_step_handler_context(
+                            plan_context, [step], active_execution
+                        ),
+                        step_worker_handle=step_worker_handles[step_key],
                     )
                 raise
diff --git a/python_modules/dagster/dagster/_core/executor/step_delegating/step_handler/base.py b/python_modules/dagster/dagster/_core/executor/step_delegating/step_handler/base.py
index a40fc76408e25..ed12a386869ea 100644
--- a/python_modules/dagster/dagster/_core/executor/step_delegating/step_handler/base.py
+++ b/python_modules/dagster/dagster/_core/executor/step_delegating/step_handler/base.py
@@ -1,11 +1,10 @@
 from abc import ABC, abstractmethod
-from typing import Iterator, Mapping, NamedTuple, Optional, Sequence
+from typing import Mapping, NamedTuple, Optional, Sequence
 
 from dagster import (
     DagsterInstance,
     _check as check,
 )
-from dagster._core.events import DagsterEvent
 from dagster._core.execution.context.system import IStepContext, PlanOrchestrationContext
 from dagster._core.execution.plan.step import ExecutionStep
 from dagster._core.storage.dagster_run import DagsterRun
@@ -83,13 +82,21 @@ def name(self) -> str:
         pass
 
     @abstractmethod
-    def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
+    def launch_step(self, step_handler_context: StepHandlerContext) -> Optional[str]:
+        """Launch a step worker and return an opaque handle identifying it (for example, an
+        ECS task ARN), or None if the handler can locate the worker from the context alone."""
         pass
 
     @abstractmethod
-    def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult:
+    def check_step_health(
+        self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
+    ) -> CheckStepHealthResult:
+        """Check on the step worker identified by the handle returned from launch_step."""
         pass
 
     @abstractmethod
-    def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
+    def terminate_step(
+        self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
+    ) -> None:
+        """Terminate the step worker identified by the handle returned from launch_step."""
        pass
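Note: the contract above is easiest to see end to end in a toy handler. The following sketch is illustrative only and not part of this diff; "LocalProcessStepHandler" is a hypothetical name. It shows that launch_step returns an opaque step_worker_handle (here a PID string) and that the executor hands that same handle back into check_step_health and terminate_step.

# Illustrative only -- not part of this diff. A minimal in-process handler
# demonstrating the new StepHandler contract introduced in base.py above.
import os
import signal
import subprocess
from typing import Optional

from dagster._core.executor.step_delegating import (
    CheckStepHealthResult,
    StepHandler,
    StepHandlerContext,
)


class LocalProcessStepHandler(StepHandler):
    @property
    def name(self) -> str:
        return "LocalProcessStepHandler"

    def launch_step(self, step_handler_context: StepHandlerContext) -> Optional[str]:
        args = step_handler_context.execute_step_args.get_command_args()
        proc = subprocess.Popen(args)
        return str(proc.pid)  # becomes the step_worker_handle

    def check_step_health(
        self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
    ) -> CheckStepHealthResult:
        if not step_worker_handle:
            return CheckStepHealthResult.unhealthy(reason="No worker handle was recorded.")
        try:
            os.kill(int(step_worker_handle), 0)  # signal 0 only checks process existence
        except OSError:
            return CheckStepHealthResult.unhealthy(
                reason=f"Process {step_worker_handle} is no longer running."
            )
        return CheckStepHealthResult.healthy()

    def terminate_step(
        self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
    ) -> None:
        if step_worker_handle:
            os.kill(int(step_worker_handle), signal.SIGTERM)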
diff --git a/python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py b/python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py
new file mode 100644
index 0000000000000..1a11d668c2ff7
--- /dev/null
+++ b/python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py
@@ -0,0 +1,314 @@
+import json
+import os
+from typing import Any, List, Mapping, Optional, cast
+
+import boto3
+from dagster import (
+    Field,
+    IntSource,
+    Permissive,
+    _check as check,
+    executor,
+)
+from dagster._core.definitions.executor_definition import multiple_process_executor_requirements
+from dagster._core.definitions.metadata import MetadataValue
+from dagster._core.events import DagsterEvent, EngineEventData
+from dagster._core.execution.retries import RetryMode, get_retries_config
+from dagster._core.execution.tags import get_tag_concurrency_limits_config
+from dagster._core.executor.base import Executor
+from dagster._core.executor.init import InitExecutorContext
+from dagster._core.executor.step_delegating import (
+    CheckStepHealthResult,
+    StepDelegatingExecutor,
+    StepHandler,
+    StepHandlerContext,
+)
+from dagster._utils.backoff import backoff
+
+from dagster_aws.ecs.container_context import EcsContainerContext
+from dagster_aws.ecs.launcher import STOPPED_STATUSES, EcsRunLauncher
+from dagster_aws.ecs.tasks import get_current_ecs_task, get_current_ecs_task_metadata
+from dagster_aws.ecs.utils import RetryableEcsException, run_ecs_task
+
+DEFAULT_STEP_TASK_RETRIES = "5"
+
+
+@executor(
+    name="ecs",
+    config_schema={
+        "run_task_kwargs": Field(
+            Permissive({}),
+            is_required=False,
+            description=(
+                "Additional arguments to include while running the task. See"
+                " https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task"
+                " for the available parameters. The overrides and taskDefinition arguments will"
+                " always be set by the run launcher."
+            ),
+        ),
+        "cpu": Field(IntSource, is_required=False),
+        "memory": Field(IntSource, is_required=False),
+        "ephemeral_storage": Field(IntSource, is_required=False),
+        "task_overrides": Field(
+            Permissive({}),
+            is_required=False,
+        ),
+        "retries": get_retries_config(),
+        "max_concurrent": Field(
+            IntSource,
+            is_required=False,
+            description=(
+                "Limit on the number of ECS tasks that will run concurrently within the scope "
+                "of a Dagster run. Note that this limit is per run, not global."
+            ),
+        ),
+        "tag_concurrency_limits": get_tag_concurrency_limits_config(),
+    },
+    requirements=multiple_process_executor_requirements(),
+)
+def ecs_executor(init_context: InitExecutorContext) -> Executor:
+    """Executor which launches steps as ECS tasks."""
+    run_launcher = init_context.instance.run_launcher
+
+    check.invariant(
+        isinstance(run_launcher, EcsRunLauncher),
+        "Using the ecs_executor currently requires that the run be launched in an ECS task via the EcsRunLauncher.",
+    )
+
+    exc_cfg = init_context.executor_config
+
+    return StepDelegatingExecutor(
+        EcsStepHandler(
+            run_task_kwargs=exc_cfg.get("run_task_kwargs"),  # type: ignore
+            cpu=exc_cfg.get("cpu"),  # type: ignore
+            memory=exc_cfg.get("memory"),  # type: ignore
+            ephemeral_storage=exc_cfg.get("ephemeral_storage"),  # type: ignore
+            task_overrides=exc_cfg.get("task_overrides"),  # type: ignore
+        ),
+        retries=RetryMode.from_config(exc_cfg["retries"]),  # type: ignore
+        max_concurrent=check.opt_int_elem(exc_cfg, "max_concurrent"),
+        tag_concurrency_limits=check.opt_list_elem(exc_cfg, "tag_concurrency_limits"),
+        should_verify_step=True,
+    )
+
+
+class EcsStepHandler(StepHandler):
+    @property
+    def name(self):
+        return "EcsStepHandler"
+
+    def __init__(
+        self,
+        run_task_kwargs: Optional[Mapping[str, Any]],
+        cpu: Optional[int],
+        memory: Optional[int],
+        ephemeral_storage: Optional[int],
+        task_overrides: Optional[Mapping[str, Any]],
+    ):
+        super().__init__()
+
+        self.ecs = boto3.client("ecs")
+
+        # confusingly, run_task expects cpu and memory values as strings
+        self._cpu = str(cpu) if cpu else None
+        self._memory = str(memory) if memory else None
+
+        self._ephemeral_storage = ephemeral_storage
+        self._task_overrides = check.opt_mapping_param(task_overrides, "task_overrides")
+
+        current_task_metadata = get_current_ecs_task_metadata()
+        current_task = get_current_ecs_task(
+            self.ecs, current_task_metadata.task_arn, current_task_metadata.cluster
+        )
+        self._cluster_arn = current_task["clusterArn"]
+        self._task_definition_arn = current_task["taskDefinitionArn"]
+        self._run_task_kwargs = {
+            **(run_task_kwargs or {}),
+            "taskDefinition": current_task["taskDefinitionArn"],
+        }
+
+    def _get_run_task_kwargs(
+        self,
+        run,
+        args,
+        step_key: str,
+        step_tags: Mapping[str, str],
+        run_launcher: EcsRunLauncher,
+        container_context: EcsContainerContext,
+    ):
+        # copy so that per-step mutations below do not leak into the shared kwargs
+        run_task_kwargs = dict(self._run_task_kwargs)
+
+        kwargs_from_tags = step_tags.get("ecs/run_task_kwargs")
+        if kwargs_from_tags:
+            run_task_kwargs = {**run_task_kwargs, **json.loads(kwargs_from_tags)}
+
+        run_task_kwargs["tags"] = [
+            *run_task_kwargs.get("tags", []),
+            {"key": "dagster/run_id", "value": run.run_id},
+            {"key": "dagster/job_name", "value": run.job_name},
+            {"key": "dagster/step_key", "value": step_key},
+        ]
+
+        if run.external_job_origin:
+            run_task_kwargs["tags"] = [
+                *run_task_kwargs["tags"],
+                {
+                    "key": "dagster/code-location",
+                    "value": run.external_job_origin.repository_origin.code_location_origin.location_name,
+                },
+            ]
+
+        overrides = {
+            # container name has to match since we are assuming we are using the same task
+            # definition as the run
+            "containerOverrides": [
+                {"name": run_launcher.get_container_name(container_context), "command": args}
+            ],
+            **self._get_task_overrides(step_tags),
+        }
+
+        run_task_kwargs["overrides"] = overrides
+
+        return run_task_kwargs
+
+    def _get_task_overrides(self, step_tags: Mapping[str, str]) -> Mapping[str, Any]:
+        overrides = {}
+
+        cpu = step_tags.get("ecs/cpu", self._cpu)
+        memory = step_tags.get("ecs/memory", self._memory)
+
+        if cpu:
+            overrides["cpu"] = cpu
+        if memory:
+            overrides["memory"] = memory
+
+        ephemeral_storage = step_tags.get("ecs/ephemeral_storage", self._ephemeral_storage)
+
+        if ephemeral_storage:
+            overrides["ephemeralStorage"] = {"sizeInGiB": int(ephemeral_storage)}
+
+        tag_overrides = step_tags.get("ecs/task_overrides")
+        if tag_overrides:
+            overrides = {**self._task_overrides, **overrides, **json.loads(tag_overrides)}
+
+        return overrides
+
+    def _get_step_key(self, step_handler_context: StepHandlerContext) -> str:
+        step_keys_to_execute = cast(
+            List[str], step_handler_context.execute_step_args.step_keys_to_execute
+        )
+        assert len(step_keys_to_execute) == 1, "Launching multiple steps is not currently supported"
+        return step_keys_to_execute[0]
+
+    def _get_container_context(
+        self, step_handler_context: StepHandlerContext
+    ) -> EcsContainerContext:
+        return EcsContainerContext.create_for_run(
+            step_handler_context.dagster_run,
+            cast(EcsRunLauncher, step_handler_context.instance.run_launcher),
+        )
+
+    def _run_task(self, **run_task_kwargs):
+        return run_ecs_task(self.ecs, run_task_kwargs)
+
+    def launch_step(self, step_handler_context: StepHandlerContext) -> Optional[str]:
+        step_key = self._get_step_key(step_handler_context)
+
+        step_tags = step_handler_context.step_tags[step_key]
+
+        container_context = self._get_container_context(step_handler_context)
+
+        run = step_handler_context.dagster_run
+
+        args = step_handler_context.execute_step_args.get_command_args(
+            skip_serialized_namedtuple=True
+        )
+
+        run_task_kwargs = self._get_run_task_kwargs(
+            run,
+            args,
+            step_key,
+            step_tags,
+            cast(EcsRunLauncher, step_handler_context.instance.run_launcher),
+            container_context,
+        )
+
+        task = backoff(
+            self._run_task,
+            retry_on=(RetryableEcsException,),
+            kwargs=run_task_kwargs,
+            max_retries=int(
+                os.getenv("STEP_TASK_RETRIES", DEFAULT_STEP_TASK_RETRIES),
+            ),
+        )
+
+        DagsterEvent.step_worker_starting(
+            step_handler_context.get_step_context(step_key),
+            message=f'Executing step "{step_key}" in ECS task.',
+            metadata={
+                "Task ARN": MetadataValue.text(task["taskArn"]),
+            },
+        )
+
+        return task["taskArn"]
+
+    def check_step_health(
+        self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
+    ) -> CheckStepHealthResult:
+        step_key = self._get_step_key(step_handler_context)
+        if not step_worker_handle:
+            return CheckStepHealthResult.unhealthy(
+                reason=f"No ECS task ARN is recorded for step {step_key}."
+            )
+
+        task_arn = step_worker_handle
+        cluster_arn = self._cluster_arn
+
+        tasks = self.ecs.describe_tasks(tasks=[task_arn], cluster=cluster_arn).get("tasks")
+
+        if not tasks:
+            return CheckStepHealthResult.unhealthy(
+                reason=f"Task {task_arn} for step {step_key} could not be found."
+            )
+
+        t = tasks[0]
+        if t.get("lastStatus") in STOPPED_STATUSES:
+            failed_containers = []
+            for c in t.get("containers"):
+                if c.get("exitCode") != 0:
+                    failed_containers.append(c)
+            if len(failed_containers) > 0:
+                cluster_failure_info = (
+                    f"Task {t.get('taskArn')} failed.\n"
+                    f"Stop code: {t.get('stopCode')}.\n"
+                    f"Stop reason: {t.get('stoppedReason')}.\n"
+                )
+                for c in failed_containers:
+                    exit_code = c.get("exitCode")
+                    exit_code_msg = f" - exit code {exit_code}" if exit_code is not None else ""
+                    cluster_failure_info += f"Container '{c.get('name')}' failed{exit_code_msg}.\n"
+
+                return CheckStepHealthResult.unhealthy(reason=cluster_failure_info)
+
+        return CheckStepHealthResult.healthy()
+
+    def terminate_step(
+        self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
+    ) -> None:
+        task_arn = step_worker_handle
+        cluster_arn = self._cluster_arn
+        step_key = self._get_step_key(step_handler_context)
+
+        if not task_arn:
+            # defensive: no task was ever launched for this step
+            return
+
+        DagsterEvent.engine_event(
+            step_handler_context.get_step_context(step_key),
+            message=f"Stopping task {task_arn} for step",
+            event_specific_data=EngineEventData(),
+        )
+
+        self.ecs.stop_task(task=task_arn, cluster=cluster_arn)
diff --git a/python_modules/libraries/dagster-aws/dagster_aws/ecs/launcher.py b/python_modules/libraries/dagster-aws/dagster_aws/ecs/launcher.py
index d5990e82c8bef..932100f8769b8 100644
--- a/python_modules/libraries/dagster-aws/dagster_aws/ecs/launcher.py
+++ b/python_modules/libraries/dagster-aws/dagster_aws/ecs/launcher.py
@@ -47,7 +47,13 @@
     get_task_definition_dict_from_current_task,
     get_task_kwargs_from_current_task,
 )
-from dagster_aws.ecs.utils import get_task_definition_family, get_task_logs, task_definitions_match
+from dagster_aws.ecs.utils import (
+    RetryableEcsException,
+    get_task_definition_family,
+    get_task_logs,
+    run_ecs_task,
+    task_definitions_match,
+)
 from dagster_aws.secretsmanager import get_secrets_from_arns
 
 Tags = namedtuple("Tags", ["arn", "cluster", "cpu", "memory"])
@@ -73,9 +79,6 @@
 DEFAULT_RUN_TASK_RETRIES = 5
 
 
-class RetryableEcsException(Exception): ...
-
-
 class EcsRunLauncher(RunLauncher[T_DagsterInstance], ConfigurableClass):
     """RunLauncher that starts a task in ECS for each Dagster job run."""
 
@@ -433,34 +436,7 @@ def _get_image_for_run(self, context: LaunchRunContext) -> Optional[str]:
         return job_origin.repository_origin.container_image
 
     def _run_task(self, **run_task_kwargs):
-        response = self.ecs.run_task(**run_task_kwargs)
-
-        tasks = response["tasks"]
-
-        if not tasks:
-            failures = response["failures"]
-            failure_messages = []
-            for failure in failures:
-                arn = failure.get("arn")
-                reason = failure.get("reason")
-                detail = failure.get("detail")
-
-                failure_message = (
-                    "Task"
-                    + (f" {arn}" if arn else "")
-                    + " failed."
-                    + (f" Failure reason: {reason}" if reason else "")
-                    + (f" Failure details: {detail}" if detail else "")
-                )
-                failure_messages.append(failure_message)
-
-            failure_message = "\n".join(failure_messages) if failure_messages else "Task failed."
-
-            if "Capacity is unavailable at this time" in failure_message:
-                raise RetryableEcsException(failure_message)
-
-            raise Exception(failure_message)
-        return tasks[0]
+        return run_ecs_task(self.ecs, run_task_kwargs)
 
     def launch_run(self, context: LaunchRunContext) -> None:
         """Launch a run in an ECS task."""
@@ -500,7 +476,7 @@ def launch_run(self, context: LaunchRunContext) -> None:
 
         container_overrides: List[Dict[str, Any]] = [
             {
-                "name": self._get_container_name(container_context),
+                "name": self.get_container_name(container_context),
                 "command": command,
                 # containerOverrides expects cpu/memory as integers
                 **{k: int(v) for k, v in cpu_and_memory_overrides.items()},
@@ -676,7 +652,7 @@ def _run_task_kwargs(
         task_definition_config = DagsterEcsTaskDefinitionConfig(
             family,
             image,
-            self._get_container_name(container_context),
+            self.get_container_name(container_context),
             command=None,
             log_configuration=(
                 {
@@ -716,7 +692,7 @@ def _run_task_kwargs(
             family,
             self._get_current_task(),
             image,
-            self._get_container_name(container_context),
+            self.get_container_name(container_context),
             environment=environment,
             secrets=secrets if secrets else {},
             include_sidecars=self.include_sidecars,
@@ -734,10 +710,10 @@ def _run_task_kwargs(
 
         task_definition_config = DagsterEcsTaskDefinitionConfig.from_task_definition_dict(
             task_definition_dict,
-            self._get_container_name(container_context),
+            self.get_container_name(container_context),
         )
 
-        container_name = self._get_container_name(container_context)
+        container_name = self.get_container_name(container_context)
 
         backoff(
             self._reuse_or_register_task_definition,
@@ -893,7 +869,7 @@ def check_run_worker_health(self, run: DagsterRun):
                     logs_client=self.logs,
                     cluster=tags.cluster,
                     task_arn=tags.arn,
-                    container_name=self._get_container_name(container_context),
+                    container_name=self.get_container_name(container_context),
                 )
             except:
                 logging.exception(f"Error trying to get logs for failed task {tags.arn}")
diff --git a/python_modules/libraries/dagster-aws/dagster_aws/ecs/utils.py b/python_modules/libraries/dagster-aws/dagster_aws/ecs/utils.py
index ec3d9edade381..63627be44a450 100644
--- a/python_modules/libraries/dagster-aws/dagster_aws/ecs/utils.py
+++ b/python_modules/libraries/dagster-aws/dagster_aws/ecs/utils.py
@@ -19,6 +19,40 @@ def _get_family_hash(name):
     return f"{name[:55]}_{name_hash}"
 
 
+class RetryableEcsException(Exception): ...
+
+
+def run_ecs_task(ecs, run_task_kwargs) -> Mapping[str, Any]:
+    response = ecs.run_task(**run_task_kwargs)
+
+    tasks = response["tasks"]
+
+    if not tasks:
+        failures = response["failures"]
+        failure_messages = []
+        for failure in failures:
+            arn = failure.get("arn")
+            reason = failure.get("reason")
+            detail = failure.get("detail")
+
+            failure_message = (
+                "Task"
+                + (f" {arn}" if arn else "")
+                + " failed."
+                + (f" Failure reason: {reason}" if reason else "")
+                + (f" Failure details: {detail}" if detail else "")
+            )
+            failure_messages.append(failure_message)
+
+        failure_message = "\n".join(failure_messages) if failure_messages else "Task failed."
+ + if "Capacity is unavailable at this time" in failure_message: + raise RetryableEcsException(failure_message) + + raise Exception(failure_message) + return tasks[0] + + def get_task_definition_family( prefix: str, job_origin: RemoteJobOrigin, diff --git a/python_modules/libraries/dagster-docker/dagster_docker/docker_executor.py b/python_modules/libraries/dagster-docker/dagster_docker/docker_executor.py index ab60303b0f8a7..28bcd061b9363 100644 --- a/python_modules/libraries/dagster-docker/dagster_docker/docker_executor.py +++ b/python_modules/libraries/dagster-docker/dagster_docker/docker_executor.py @@ -1,4 +1,4 @@ -from typing import Iterator, List, Optional, cast +from typing import List, Optional, cast import dagster._check as check import docker @@ -222,7 +222,7 @@ def _create_step_container( **container_kwargs, ) - def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: + def launch_step(self, step_handler_context: StepHandlerContext) -> None: container_context = self._get_docker_container_context(step_handler_context) client = self._get_client(container_context) @@ -251,7 +251,7 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[Dags assert len(step_keys_to_execute) == 1, "Launching multiple steps is not currently supported" step_key = step_keys_to_execute[0] - yield DagsterEvent.step_worker_starting( + DagsterEvent.step_worker_starting( step_handler_context.get_step_context(step_key), message="Launching step in Docker container.", metadata={ @@ -294,7 +294,7 @@ def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckSt reason=f"Container status is {container.status}. Return code is {ret_code}." ) - def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: + def terminate_step(self, step_handler_context: StepHandlerContext) -> None: container_context = self._get_docker_container_context(step_handler_context) step_keys_to_execute = check.not_none( @@ -307,7 +307,7 @@ def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[D container_name = self._get_container_name(step_handler_context) - yield DagsterEvent.engine_event( + DagsterEvent.engine_event( step_handler_context.get_step_context(step_key), message=f"Stopping Docker container {container_name} for step.", event_specific_data=EngineEventData(), diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py b/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py index c8b0d7289d4dc..64bc8d44f8d2c 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py @@ -1,4 +1,4 @@ -from typing import Iterator, List, Optional, cast +from typing import List, Optional, cast import kubernetes.config from dagster import ( @@ -261,7 +261,7 @@ def _get_k8s_step_job_name(self, step_handler_context: StepHandlerContext): return "dagster-step-%s" % (name_key) - def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: + def launch_step(self, step_handler_context: StepHandlerContext) -> None: step_key = self._get_step_key(step_handler_context) job_name = self._get_k8s_step_job_name(step_handler_context) @@ -313,7 +313,7 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[Dags ], ) - yield DagsterEvent.step_worker_starting( + DagsterEvent.step_worker_starting( step_handler_context.get_step_context(step_key), message=f'Executing step "{step_key}" in 
             metadata={
@@ -324,7 +324,11 @@ def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[Dags
 
         namespace = check.not_none(container_context.namespace)
         self._api_client.create_namespaced_job_with_retries(body=job, namespace=namespace)
 
-    def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult:
+        return None
+
+    def check_step_health(
+        self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
+    ) -> CheckStepHealthResult:
         step_key = self._get_step_key(step_handler_context)
 
         job_name = self._get_k8s_step_job_name(step_handler_context)
@@ -346,13 +350,15 @@ def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckSt
 
         return CheckStepHealthResult.healthy()
 
-    def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
+    def terminate_step(
+        self, step_handler_context: StepHandlerContext, step_worker_handle: Optional[str]
+    ) -> None:
         step_key = self._get_step_key(step_handler_context)
 
         job_name = self._get_k8s_step_job_name(step_handler_context)
         container_context = self._get_container_context(step_handler_context)
 
-        yield DagsterEvent.engine_event(
+        DagsterEvent.engine_event(
             step_handler_context.get_step_context(step_key),
             message=f"Deleting Kubernetes job {job_name} for step",
             event_specific_data=EngineEventData(),
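For reference, wiring a job up to the new executor might look like the sketch below. This is illustrative only and not part of this diff: the op and job names are hypothetical, while the module path and the ecs/cpu and ecs/memory tag names come from the executor.py file added above (see _get_task_overrides).

# Illustrative usage sketch -- not part of this diff.
from dagster import job, op

from dagster_aws.ecs.executor import ecs_executor


@op(
    tags={
        # per-step resource overrides read by EcsStepHandler._get_task_overrides;
        # ECS expects cpu/memory override values as strings
        "ecs/cpu": "1024",
        "ecs/memory": "4096",
    }
)
def heavy_op():
    ...


@job(executor_def=ecs_executor)
def my_ecs_job():
    heavy_op()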