diff --git a/docs/content/api/modules.json.gz b/docs/content/api/modules.json.gz index babd11583e374..00caf49d509e0 100644 Binary files a/docs/content/api/modules.json.gz and b/docs/content/api/modules.json.gz differ diff --git a/docs/content/api/searchindex.json.gz b/docs/content/api/searchindex.json.gz index 5ca4d157c09eb..39b9eab48e8bb 100644 Binary files a/docs/content/api/searchindex.json.gz and b/docs/content/api/searchindex.json.gz differ diff --git a/docs/content/api/sections.json.gz b/docs/content/api/sections.json.gz index 1254495534a4f..c970656e88c06 100644 Binary files a/docs/content/api/sections.json.gz and b/docs/content/api/sections.json.gz differ diff --git a/docs/next/public/objects.inv b/docs/next/public/objects.inv index fd50d97032019..dc486e70947de 100644 Binary files a/docs/next/public/objects.inv and b/docs/next/public/objects.inv differ diff --git a/docs/sphinx/sections/api/apidocs/libraries/dagster-aws.rst b/docs/sphinx/sections/api/apidocs/libraries/dagster-aws.rst index e8774ae92ed5d..15d161197def5 100644 --- a/docs/sphinx/sections/api/apidocs/libraries/dagster-aws.rst +++ b/docs/sphinx/sections/api/apidocs/libraries/dagster-aws.rst @@ -49,6 +49,9 @@ ECS .. autoconfigurable:: dagster_aws.ecs.EcsRunLauncher :annotation: RunLauncher +.. autoconfigurable:: dagster_aws.ecs.ecs_executor + :annotation: ExecutorDefinition + Redshift -------- diff --git a/python_modules/libraries/dagster-aws/dagster_aws/ecs/__init__.py b/python_modules/libraries/dagster-aws/dagster_aws/ecs/__init__.py index 1c6711d03db3e..bb36e1a1c554f 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/ecs/__init__.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/ecs/__init__.py @@ -1,2 +1,3 @@ +from dagster_aws.ecs.executor import ecs_executor as ecs_executor from dagster_aws.ecs.launcher import EcsRunLauncher as EcsRunLauncher from dagster_aws.ecs.tasks import EcsEventualConsistencyTimeout as EcsEventualConsistencyTimeout diff --git a/python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py b/python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py new file mode 100644 index 0000000000000..1dc2b41761b0f --- /dev/null +++ b/python_modules/libraries/dagster-aws/dagster_aws/ecs/executor.py @@ -0,0 +1,446 @@ +import json +import os +from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, cast + +import boto3 +from dagster import ( + DagsterInvariantViolationError, + DagsterRun, + Field, + IntSource, + Permissive, + _check as check, + executor, +) +from dagster._annotations import experimental +from dagster._core.definitions.executor_definition import multiple_process_executor_requirements +from dagster._core.definitions.metadata import MetadataValue +from dagster._core.events import DagsterEvent, EngineEventData +from dagster._core.execution.retries import RetryMode, get_retries_config +from dagster._core.execution.tags import get_tag_concurrency_limits_config +from dagster._core.executor.base import Executor +from dagster._core.executor.init import InitExecutorContext +from dagster._core.executor.step_delegating import ( + CheckStepHealthResult, + StepDelegatingExecutor, + StepHandler, + StepHandlerContext, +) +from dagster._utils.backoff import backoff +from dagster._utils.merger import deep_merge_dicts + +from dagster_aws.ecs.container_context import EcsContainerContext +from dagster_aws.ecs.launcher import STOPPED_STATUSES, EcsRunLauncher +from dagster_aws.ecs.tasks import ( + get_current_ecs_task, + get_current_ecs_task_metadata, + get_task_kwargs_from_current_task, +) +from dagster_aws.ecs.utils import RetryableEcsException, run_ecs_task + +DEFAULT_STEP_TASK_RETRIES = "5" + + +_ECS_EXECUTOR_CONFIG_SCHEMA = { + "run_task_kwargs": Field( + Permissive({}), + is_required=False, + description=( + "Additional arguments to which can be set to the boto3 run_task call. Will override values inherited from the ECS run launcher." + " https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ecs.html#ECS.Client.run_task" + " for the available parameters." + ), + ), + "cpu": Field(IntSource, is_required=False), + "memory": Field(IntSource, is_required=False), + "ephemeral_storage": Field(IntSource, is_required=False), + "task_overrides": Field( + Permissive({}), + is_required=False, + ), + "retries": get_retries_config(), + "max_concurrent": Field( + IntSource, + is_required=False, + description=( + "Limit on the number of tasks that will run concurrently within the scope " + "of a Dagster run. Note that this limit is per run, not global." + ), + ), + "tag_concurrency_limits": get_tag_concurrency_limits_config(), +} + + +@executor( + name="ecs", + config_schema=_ECS_EXECUTOR_CONFIG_SCHEMA, + requirements=multiple_process_executor_requirements(), +) +@experimental +def ecs_executor(init_context: InitExecutorContext) -> Executor: + """Executor which launches steps as ECS tasks. + + To use the `ecs_executor`, set it as the `executor_def` when defining a job: + + .. literalinclude:: ../../../../../../python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/executor_tests/test_example_executor_mode_def.py + :start-after: start_marker + :end-before: end_marker + :language: python + + Then you can configure the executor with run config as follows: + + .. code-block:: YAML + + execution: + config: + cpu: 1024 + memory: 2048 + ephemeral_storage: 10 + task_overrides: + containerOverrides: + - name: run + environment: + - name: MY_ENV_VAR + value: "my_value" + + `max_concurrent` limits the number of ECS tasks that will execute concurrently for one run. By default + there is no limit- it will maximally parallel as allowed by the DAG. Note that this is not a + global limit. + + Configuration set on the ECS tasks created by the `ECSRunLauncher` will also be + set on the tasks created by the `ecs_executor`. + + Configuration set using `tags` on a `@job` will only apply to the `run` level. For configuration + to apply at each `step` it must be set using `tags` for each `@op`. + """ + run_launcher = init_context.instance.run_launcher + + check.invariant( + isinstance(run_launcher, EcsRunLauncher), + "Using the ecs_executor currently requires that the run be launched in an ECS task via the EcsRunLauncher.", + ) + + exc_cfg = init_context.executor_config + + return StepDelegatingExecutor( + EcsStepHandler( + run_launcher=run_launcher, # type: ignore + run_task_kwargs=exc_cfg.get("run_task_kwargs"), # type: ignore + cpu=exc_cfg.get("cpu"), # type: ignore + memory=exc_cfg.get("memory"), # type: ignore + ephemeral_storage=exc_cfg.get("ephemeral_storage"), # type: ignore + task_overrides=exc_cfg.get("task_overrides"), # type:ignore + ), + retries=RetryMode.from_config(exc_cfg["retries"]), # type: ignore + max_concurrent=check.opt_int_elem(exc_cfg, "max_concurrent"), + tag_concurrency_limits=check.opt_list_elem(exc_cfg, "tag_concurrency_limits"), + should_verify_step=True, + ) + + +@experimental +class EcsStepHandler(StepHandler): + @property + def name(self): + return "EcsStepHandler" + + def __init__( + self, + run_launcher: EcsRunLauncher, + run_task_kwargs: Optional[Mapping[str, Any]], + cpu: Optional[int], + memory: Optional[int], + ephemeral_storage: Optional[int], + task_overrides: Optional[Mapping[str, Any]], + ): + super().__init__() + + run_task_kwargs = run_task_kwargs or {} + + self.ecs = boto3.client("ecs") + self.ec2 = boto3.resource("ec2") + + # confusingly, run_task expects cpu and memory value as strings + self._cpu = str(cpu) if cpu else None + self._memory = str(memory) if memory else None + + self._ephemeral_storage = ephemeral_storage + self._task_overrides = check.opt_mapping_param(task_overrides, "task_overrides") + + current_task_metadata = get_current_ecs_task_metadata() + current_task = get_current_ecs_task( + self.ecs, current_task_metadata.task_arn, current_task_metadata.cluster + ) + + if run_launcher.use_current_ecs_task_config: + current_task_kwargs = get_task_kwargs_from_current_task( + self.ec2, + current_task_metadata.cluster, + current_task, + ) + else: + current_task_kwargs = {} + + run_launcher_kwargs = {**current_task_kwargs, **run_launcher.run_task_kwargs} + + self._cluster_arn = current_task["clusterArn"] + self._task_definition_arn = current_task["taskDefinitionArn"] + + self._run_task_kwargs = { + "taskDefinition": current_task["taskDefinitionArn"], + **run_launcher_kwargs, + **run_task_kwargs, + } + + # TODO: change launch_step to return task ARN + # this will be a breaking change so we need to wait for a minor release + # to do this + self._launched_tasks = {} + + def _get_run_task_kwargs( + self, + run: DagsterRun, + args: Sequence[str], + step_key: str, + step_tags: Mapping[str, str], + step_handler_context: StepHandlerContext, + container_context: EcsContainerContext, + ): + run_launcher = check.inst( + step_handler_context.instance.run_launcher, + EcsRunLauncher, + "ECS executor can only be enabled with the ECS run launcher", + ) + + run_task_kwargs = self._run_task_kwargs + + kwargs_from_tags = step_tags.get("ecs/run_task_kwargs") + if kwargs_from_tags: + run_task_kwargs = {**run_task_kwargs, **json.loads(kwargs_from_tags)} + + # convert tags to a dictionary for easy value overriding + tags = { + **{tag["key"]: tag["value"] for tag in run_task_kwargs.get("tags", [])}, + **{ + tag["key"]: tag["value"] + for tag in run_launcher.build_ecs_tags_for_run_task(run, container_context) + }, + **step_handler_context.dagster_run.dagster_execution_info, + "dagster/step-key": step_key, + "dagster/step-id": self._get_step_id(step_handler_context), + } + + run_task_kwargs["tags"] = [ + { + "key": key, + "value": value, + } + for key, value in tags.items() + ] + + task_overrides = self._get_task_overrides(step_tags) or {} + + task_overrides["containerOverrides"] = task_overrides.get("containerOverrides", []) + + # container name has to match since we are assuming we are using the same task + executor_container_name = run_launcher.get_container_name(container_context) + executor_env_vars = [ + {"name": env["name"], "value": env["value"]} + for env in step_handler_context.execute_step_args.get_command_env() + ] + + # inject Executor command and env vars into the container overrides + # if they are defined + # otherwise create a new container overrides for the executor container + for container_overrides in task_overrides["containerOverrides"]: + # try to update existing container overrides for the executor container + if container_overrides["name"] == executor_container_name: + if "command" in container_overrides and container_overrides["command"] != args: + raise DagsterInvariantViolationError( + f"The 'command' field for {executor_container_name} container is not allowed in the 'containerOverrides' field of the task overrides." + ) + + # update environment variables & command + container_overrides["command"] = args + container_overrides["environment"] = ( + container_overrides.get("environment", []) + executor_env_vars + ) + break + # if no existing container overrides for the executor container, add new container overrides + else: + task_overrides["containerOverrides"].append( + { + "name": executor_container_name, + "command": args, + "environment": executor_env_vars, + } + ) + + run_task_kwargs["overrides"] = deep_merge_dicts( + run_task_kwargs.get("overrides", {}), task_overrides + ) + + return run_task_kwargs + + def _get_task_overrides(self, step_tags: Mapping[str, str]) -> Dict[str, Any]: + overrides = {**self._task_overrides} + + cpu = step_tags.get("ecs/cpu", self._cpu) + memory = step_tags.get("ecs/memory", self._memory) + + if cpu: + overrides["cpu"] = cpu + if memory: + overrides["memory"] = memory + + ephemeral_storage = step_tags.get("ecs/ephemeral_storage", self._ephemeral_storage) + + if ephemeral_storage: + overrides["ephemeralStorage"] = {"sizeInGiB": int(ephemeral_storage)} + + if tag_overrides := step_tags.get("ecs/task_overrides"): + overrides = deep_merge_dicts(overrides, json.loads(tag_overrides)) + + return overrides + + def _get_step_id(self, step_handler_context: StepHandlerContext): + """Step ID is used to identify the ECS task in the ECS cluster. + It is unique to specific step being executed and takes into account op-level retries. + It's used as a workaround to avoid having to return task ARN from launch_step. + """ + step_key = self._get_step_key(step_handler_context) + + if step_handler_context.execute_step_args.known_state: + retry_count = step_handler_context.execute_step_args.known_state.get_retry_state().get_attempt_count( + step_key + ) + else: + retry_count = 0 + + return "%s-%d" % (step_key, retry_count) + + def _get_step_key(self, step_handler_context: StepHandlerContext) -> str: + step_keys_to_execute = cast( + List[str], step_handler_context.execute_step_args.step_keys_to_execute + ) + assert len(step_keys_to_execute) == 1, "Launching multiple steps is not currently supported" + return step_keys_to_execute[0] + + def _get_container_context( + self, step_handler_context: StepHandlerContext + ) -> EcsContainerContext: + return EcsContainerContext.create_for_run( + step_handler_context.dagster_run, + cast(EcsRunLauncher, step_handler_context.instance.run_launcher), + ) + + def _run_task(self, **run_task_kwargs): + return run_ecs_task(self.ecs, run_task_kwargs) + + def launch_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]: + step_key = self._get_step_key(step_handler_context) + + step_tags = step_handler_context.step_tags[step_key] + + container_context = self._get_container_context(step_handler_context) + + run = step_handler_context.dagster_run + + args = step_handler_context.execute_step_args.get_command_args( + skip_serialized_namedtuple=True + ) + + run_task_kwargs = self._get_run_task_kwargs( + run, + args, + step_key, + step_tags, + step_handler_context=step_handler_context, + container_context=container_context, + ) + + task = backoff( + self._run_task, + retry_on=(RetryableEcsException,), + kwargs=run_task_kwargs, + max_retries=int( + os.getenv("STEP_TASK_RETRIES", DEFAULT_STEP_TASK_RETRIES), + ), + ) + + yield DagsterEvent.step_worker_starting( + step_handler_context.get_step_context(step_key), + message=f'Executing step "{step_key}" in ECS task.', + metadata={ + "Task ARN": MetadataValue.text(task["taskArn"]), + }, + ) + + step_id = self._get_step_id(step_handler_context) + + self._launched_tasks[step_id] = task["taskArn"] + + def check_step_health(self, step_handler_context: StepHandlerContext) -> CheckStepHealthResult: + step_key = self._get_step_key(step_handler_context) + step_id = self._get_step_id(step_handler_context) + + try: + task_arn = self._launched_tasks[step_id] + except KeyError: + return CheckStepHealthResult.unhealthy( + reason=f"Task ARN for step {step_key} could not be found in executor's task map. This is likely a bug." + ) + + cluster_arn = self._cluster_arn + + tasks = self.ecs.describe_tasks(tasks=[task_arn], cluster=cluster_arn).get("tasks") + + if not tasks: + return CheckStepHealthResult.unhealthy( + reason=f"Task {task_arn} for step {step_key} could not be found." + ) + + t = tasks[0] + if t.get("lastStatus") in STOPPED_STATUSES: + failed_containers = [] + for c in t.get("containers"): + if c.get("exitCode") != 0: + failed_containers.append(c) + if len(failed_containers) > 0: + cluster_failure_info = ( + f"Task {t.get('taskArn')} failed.\n" + f"Stop code: {t.get('stopCode')}.\n" + f"Stop reason: {t.get('stoppedReason')}.\n" + ) + for c in failed_containers: + exit_code = c.get("exitCode") + exit_code_msg = f" - exit code {exit_code}" if exit_code is not None else "" + cluster_failure_info += f"Container '{c.get('name')}' failed{exit_code_msg}.\n" + + return CheckStepHealthResult.unhealthy(reason=cluster_failure_info) + + return CheckStepHealthResult.healthy() + + def terminate_step( + self, + step_handler_context: StepHandlerContext, + ) -> None: + step_id = self._get_step_id(step_handler_context) + step_key = self._get_step_key(step_handler_context) + + try: + task_arn = self._launched_tasks[step_id] + except KeyError: + raise DagsterInvariantViolationError( + f"Task ARN for step {step_key} could not be found in executor's task map. This is likely a bug." + ) + + cluster_arn = self._cluster_arn + + DagsterEvent.engine_event( + step_handler_context.get_step_context(step_key), + message=f"Stopping task {task_arn} for step", + event_specific_data=EngineEventData(), + ) + + self.ecs.stop_task(task=task_arn, cluster=cluster_arn) diff --git a/python_modules/libraries/dagster-aws/dagster_aws/ecs/launcher.py b/python_modules/libraries/dagster-aws/dagster_aws/ecs/launcher.py index 11ad43735a8ee..a5b8efb868e23 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/ecs/launcher.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/ecs/launcher.py @@ -48,7 +48,13 @@ get_task_definition_dict_from_current_task, get_task_kwargs_from_current_task, ) -from dagster_aws.ecs.utils import get_task_definition_family, get_task_logs, task_definitions_match +from dagster_aws.ecs.utils import ( + RetryableEcsException, + get_task_definition_family, + get_task_logs, + run_ecs_task, + task_definitions_match, +) from dagster_aws.secretsmanager import get_secrets_from_arns Tags = namedtuple("Tags", ["arn", "cluster", "cpu", "memory"]) @@ -74,9 +80,6 @@ DEFAULT_RUN_TASK_RETRIES = 5 -class RetryableEcsException(Exception): ... - - class EcsRunLauncher(RunLauncher[T_DagsterInstance], ConfigurableClass): """RunLauncher that starts a task in ECS for each Dagster job run.""" @@ -429,39 +432,16 @@ def _get_run_tags(self, run_id: str) -> Tags: def _get_command_args(self, run_args: ExecuteRunArgs, context: LaunchRunContext): return run_args.get_command_args() - def _get_image_for_run(self, context: LaunchRunContext) -> Optional[str]: - job_origin = check.not_none(context.job_code_origin) - return job_origin.repository_origin.container_image + @staticmethod + def get_image_for_run(run: DagsterRun) -> Optional[str]: + return ( + run.job_code_origin.repository_origin.container_image + if run.job_code_origin is not None + else None + ) def _run_task(self, **run_task_kwargs): - response = self.ecs.run_task(**run_task_kwargs) - - tasks = response["tasks"] - - if not tasks: - failures = response["failures"] - failure_messages = [] - for failure in failures: - arn = failure.get("arn") - reason = failure.get("reason") - detail = failure.get("detail") - - failure_message = ( - "Task" - + (f" {arn}" if arn else "") - + " failed." - + (f" Failure reason: {reason}" if reason else "") - + (f" Failure details: {detail}" if detail else "") - ) - failure_messages.append(failure_message) - - failure_message = "\n".join(failure_messages) if failure_messages else "Task failed." - - if "Capacity is unavailable at this time" in failure_message: - raise RetryableEcsException(failure_message) - - raise Exception(failure_message) - return tasks[0] + return run_ecs_task(self.ecs, run_task_kwargs) def launch_run(self, context: LaunchRunContext) -> None: """Launch a run in an ECS task.""" @@ -487,7 +467,7 @@ def launch_run(self, context: LaunchRunContext) -> None: instance_ref=self._instance.get_ref(), ) command = self._get_command_args(args, context) - image = self._get_image_for_run(context) + image = self.get_image_for_run(run) run_task_kwargs = self._run_task_kwargs(run, image, container_context) @@ -499,7 +479,7 @@ def launch_run(self, context: LaunchRunContext) -> None: container_overrides: List[Dict[str, Any]] = [ { - "name": self._get_container_name(container_context), + "name": self.get_container_name(container_context), "command": command, # containerOverrides expects cpu/memory as integers **{k: int(v) for k, v in cpu_and_memory_overrides.items()}, @@ -644,7 +624,7 @@ def _get_current_task(self): def _get_run_task_definition_family(self, run: DagsterRun) -> str: return get_task_definition_family("run", check.not_none(run.remote_job_origin)) - def _get_container_name(self, container_context: EcsContainerContext) -> str: + def get_container_name(self, container_context: EcsContainerContext) -> str: return container_context.container_name or self.container_name def _run_task_kwargs( @@ -675,7 +655,7 @@ def _run_task_kwargs( task_definition_config = DagsterEcsTaskDefinitionConfig( family, image, - self._get_container_name(container_context), + self.get_container_name(container_context), command=None, log_configuration=( { @@ -715,7 +695,7 @@ def _run_task_kwargs( family, self._get_current_task(), image, - self._get_container_name(container_context), + self.get_container_name(container_context), environment=environment, secrets=secrets if secrets else {}, include_sidecars=self.include_sidecars, @@ -733,10 +713,10 @@ def _run_task_kwargs( task_definition_config = DagsterEcsTaskDefinitionConfig.from_task_definition_dict( task_definition_dict, - self._get_container_name(container_context), + self.get_container_name(container_context), ) - container_name = self._get_container_name(container_context) + container_name = self.get_container_name(container_context) backoff( self._reuse_or_register_task_definition, @@ -897,7 +877,7 @@ def check_run_worker_health(self, run: DagsterRun): logs_client=self.logs, cluster=tags.cluster, task_arn=tags.arn, - container_name=self._get_container_name(container_context), + container_name=self.get_container_name(container_context), ) except: logging.exception(f"Error trying to get logs for failed task {tags.arn}") diff --git a/python_modules/libraries/dagster-aws/dagster_aws/ecs/utils.py b/python_modules/libraries/dagster-aws/dagster_aws/ecs/utils.py index ec3d9edade381..63627be44a450 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/ecs/utils.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/ecs/utils.py @@ -19,6 +19,40 @@ def _get_family_hash(name): return f"{name[:55]}_{name_hash}" +class RetryableEcsException(Exception): ... + + +def run_ecs_task(ecs, run_task_kwargs) -> Mapping[str, Any]: + response = ecs.run_task(**run_task_kwargs) + + tasks = response["tasks"] + + if not tasks: + failures = response["failures"] + failure_messages = [] + for failure in failures: + arn = failure.get("arn") + reason = failure.get("reason") + detail = failure.get("detail") + + failure_message = ( + "Task" + + (f" {arn}" if arn else "") + + " failed." + + (f" Failure reason: {reason}" if reason else "") + + (f" Failure details: {detail}" if detail else "") + ) + failure_messages.append(failure_message) + + failure_message = "\n".join(failure_messages) if failure_messages else "Task failed." + + if "Capacity is unavailable at this time" in failure_message: + raise RetryableEcsException(failure_message) + + raise Exception(failure_message) + return tasks[0] + + def get_task_definition_family( prefix: str, job_origin: RemoteJobOrigin, diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/executor_tests/test_example_executor_mode_def.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/executor_tests/test_example_executor_mode_def.py new file mode 100644 index 0000000000000..12aa4f15f3cb6 --- /dev/null +++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/executor_tests/test_example_executor_mode_def.py @@ -0,0 +1,26 @@ +# ruff: isort: skip_file +# fmt: off +# start_marker +from dagster_aws.ecs import ecs_executor + +from dagster import job, op + + +@op( + tags={"ecs/cpu": "256", "ecs/memory": "512"}, +) +def ecs_op(): + pass + + +@job(executor_def=ecs_executor) +def ecs_job(): + ecs_op() + + +# end_marker +# fmt: on + + +def test_mode(): + assert ecs_job diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/executor_tests/test_executor.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/executor_tests/test_executor.py new file mode 100644 index 0000000000000..87e25aeadd598 --- /dev/null +++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/ecs_tests/launcher_tests/executor_tests/test_executor.py @@ -0,0 +1,245 @@ +from typing import Callable, ContextManager + +from dagster import job, op, repository +from dagster._config import process_config, resolve_to_config_type +from dagster._core.definitions.reconstruct import reconstructable +from dagster._core.execution.api import create_execution_plan +from dagster._core.execution.context.system import PlanData, PlanOrchestrationContext +from dagster._core.execution.context_creation_job import create_context_free_log_manager +from dagster._core.execution.retries import RetryMode +from dagster._core.executor.init import InitExecutorContext +from dagster._core.executor.step_delegating.step_handler.base import StepHandlerContext +from dagster._core.instance import DagsterInstance +from dagster._core.remote_representation.handle import RepositoryHandle +from dagster._core.storage.fs_io_manager import fs_io_manager +from dagster._core.test_utils import create_run_for_test, in_process_test_workspace +from dagster._core.types.loadable_target_origin import LoadableTargetOrigin +from dagster._grpc.types import ExecuteStepArgs +from dagster._utils.hosted_user_process import remote_job_from_recon_job + +from dagster_aws.ecs.executor import _ECS_EXECUTOR_CONFIG_SCHEMA, ecs_executor + + +@job( + executor_def=ecs_executor, + resource_defs={"io_manager": fs_io_manager}, +) +def bar(): + @op( + tags={ + "ecs/cpu": "1024", + "ecs/memory": "512", + } + ) + def foo(): + return 1 + + foo() + + +@repository +def bar_repo(): + return [bar] + + +def _get_executor(instance, job_def, executor_config=None): + process_result = process_config( + resolve_to_config_type(_ECS_EXECUTOR_CONFIG_SCHEMA), executor_config or {} + ) + if not process_result.success: + raise AssertionError(f"Process result errors: {process_result.errors}") + + return ecs_executor.executor_creation_fn( # type: ignore + InitExecutorContext( + job=job_def, + executor_def=ecs_executor, + executor_config=process_result.value, # type: ignore + instance=instance, + ) + ) + + +def _step_handler_context(job_def, dagster_run, instance, executor): + execution_plan = create_execution_plan(job_def) + log_manager = create_context_free_log_manager(instance, dagster_run) + + plan_context = PlanOrchestrationContext( + plan_data=PlanData( + job=job_def, + dagster_run=dagster_run, + instance=instance, + execution_plan=execution_plan, + raise_on_error=True, + retry_mode=RetryMode.DISABLED, + ), + log_manager=log_manager, + executor=executor, + output_capture=None, + ) + + execute_step_args = ExecuteStepArgs( + reconstructable(bar).get_python_origin(), + dagster_run.run_id, + ["foo"], + print_serialized_events=False, + ) + + return StepHandlerContext( + instance=instance, + plan_context=plan_context, + steps=execution_plan.steps, # type: ignore + execute_step_args=execute_step_args, + ) + + +def test_executor_init(instance_cm: Callable[..., ContextManager[DagsterInstance]]): + with instance_cm() as instance: + recon_job = reconstructable(bar) + loadable_target_origin = LoadableTargetOrigin(python_file=__file__, attribute="bar_repo") + + memory = 128 + cpu = 500 + env_var = {"key": "OVERRIDE_VAR", "value": "foo"} + executor = _get_executor( + instance, + reconstructable(bar), + { + "cpu": cpu, + "memory": memory, + "task_overrides": { + "containerOverrides": [ + { + "name": "run", + "environment": [env_var], + } + ], + }, + }, + ) + + with in_process_test_workspace( + instance, loadable_target_origin, container_image="testing/dagster" + ) as workspace: + location = workspace.get_code_location(workspace.code_location_names[0]) + repo_handle = RepositoryHandle.from_location( + repository_name="bar_repo", + code_location=location, + ) + fake_remote_job = remote_job_from_recon_job( + recon_job, + op_selection=None, + repository_handle=repo_handle, + ) + + run = create_run_for_test( + instance, + job_name="bar", + remote_job_origin=fake_remote_job.get_remote_origin(), + job_code_origin=recon_job.get_python_origin(), + ) + step_handler_context = _step_handler_context( + job_def=reconstructable(bar), + dagster_run=run, + instance=instance, + executor=executor, + ) + run_task_kwargs = executor._step_handler._get_run_task_kwargs( # type: ignore # noqa: SLF001 + run, + ["my-command"], + "asdasd", + {}, + step_handler_context, + executor._step_handler._get_container_context(step_handler_context), # type: ignore # noqa: SLF001 + ) + + assert run_task_kwargs["launchType"] == "FARGATE" # this comes from the Run Launcher + + overrides = run_task_kwargs["overrides"] + + assert overrides["cpu"] == str(cpu) + assert overrides["memory"] == str(memory) + + run_container_overrides = overrides["containerOverrides"][0] + + assert run_container_overrides["name"] == "run" + assert run_container_overrides["command"] == ["my-command"] + + assert env_var in run_container_overrides["environment"] + + +def test_executor_launch(instance_cm: Callable[..., ContextManager[DagsterInstance]]): + with instance_cm() as instance: + recon_job = reconstructable(bar) + loadable_target_origin = LoadableTargetOrigin(python_file=__file__, attribute="bar_repo") + + with in_process_test_workspace( + instance, loadable_target_origin, container_image="testing/dagster" + ) as workspace: + location = workspace.get_code_location(workspace.code_location_names[0]) + repo_handle = RepositoryHandle.from_location( + repository_name="bar_repo", + code_location=location, + ) + fake_remote_job = remote_job_from_recon_job( + recon_job, + op_selection=None, + repository_handle=repo_handle, + ) + + executor = _get_executor(instance, reconstructable(bar), {}) + run = create_run_for_test( + instance, + job_name="bar", + remote_job_origin=fake_remote_job.get_remote_origin(), + job_code_origin=recon_job.get_python_origin(), + ) + step_handler_context = _step_handler_context( + job_def=reconstructable(bar), + dagster_run=run, + instance=instance, + executor=executor, + ) + from unittest.mock import MagicMock + + executor._step_handler.ecs.run_task = MagicMock( # type: ignore # noqa: SLF001 + return_value={"tasks": [{"taskArn": "arn:123"}]} + ) + + next(iter(executor._step_handler.launch_step(step_handler_context))) # type: ignore # noqa: SLF001 + + run_task_kwargs = executor._step_handler.ecs.run_task.call_args[1] # type: ignore # noqa: SLF001 + + # resources should come from step tags + assert run_task_kwargs["overrides"]["cpu"] == "1024" + assert run_task_kwargs["overrides"]["memory"] == "512" + + tags = run_task_kwargs["tags"] + + assert { + "key": "dagster/run-id", + "value": run.run_id, + } in tags + + assert { + "key": "dagster/job", + "value": run.job_name, + } in tags + + assert { + "key": "dagster/step-key", + "value": "foo", + } in tags + + assert run_task_kwargs["overrides"]["containerOverrides"][0]["command"] == [ + "dagster", + "api", + "execute_step", + ] + + found_executor_args_var = False + for var in run_task_kwargs["overrides"]["containerOverrides"][0]["environment"]: + if var["name"] == "DAGSTER_COMPRESSED_EXECUTE_STEP_ARGS": + found_executor_args_var = True + break + + assert found_executor_args_var