Skip to content

Commit

Permalink
Retry certain kinds of run task failures (#24691)
Browse files Browse the repository at this point in the history
For example, task placement failures.

## Changelog

[dagster-aws] The ECS launcher now automatically retries transient ECS
RunTask failures (like capacity placement failures).

- [x] `NEW`
  • Loading branch information
jmsanders authored Sep 27, 2024
1 parent 1c54a8a commit 0bb7b26
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 25 deletions.
76 changes: 51 additions & 25 deletions python_modules/libraries/dagster-aws/dagster_aws/ecs/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@

DEFAULT_LINUX_RESOURCES = {"cpu": "256", "memory": "512"}

# Default values passed as `max_retries` for the retried ECS calls. Each can be overridden
# at runtime via the REGISTER_TASK_DEFINITION_RETRIES / RUN_TASK_RETRIES environment
# variables, which are read with os.getenv and converted with int().
DEFAULT_REGISTER_TASK_DEFINITION_RETRIES = 5
DEFAULT_RUN_TASK_RETRIES = 5


class RetryableEcsException(Exception):
    """Signals an ECS failure that is worth retrying.

    Raised when an ECS RunTask failure message indicates that capacity is
    unavailable, so the launcher can retry the call instead of failing the run.
    """


class EcsRunLauncher(RunLauncher[T_DagsterInstance], ConfigurableClass):
"""RunLauncher that starts a task in ECS for each Dagster job run."""
Expand Down Expand Up @@ -373,6 +379,36 @@ def _get_image_for_run(self, context: LaunchRunContext) -> Optional[str]:
job_origin = check.not_none(context.job_code_origin)
return job_origin.repository_origin.container_image

def _run_task(self, **run_task_kwargs):
response = self.ecs.run_task(**run_task_kwargs)

tasks = response["tasks"]

if not tasks:
failures = response["failures"]
failure_messages = []
for failure in failures:
arn = failure.get("arn")
reason = failure.get("reason")
detail = failure.get("detail")

failure_message = (
"Task"
+ (f" {arn}" if arn else "")
+ " failed."
+ (f" Failure reason: {reason}" if reason else "")
+ (f" Failure details: {detail}" if detail else "")
)
failure_messages.append(failure_message)

failure_message = "\n".join(failure_messages) if failure_messages else "Task failed."

if "Capacity is unavailable at this time" in failure_message:
raise RetryableEcsException(failure_message)

raise Exception(failure_message)
return tasks[0]

def launch_run(self, context: LaunchRunContext) -> None:
"""Launch a run in an ECS task."""
run = context.dagster_run
Expand Down Expand Up @@ -435,31 +471,17 @@ def launch_run(self, context: LaunchRunContext) -> None:
del run_task_kwargs["launchType"]

# Run a task using the same network configuration as this process's task.
response = self.ecs.run_task(**run_task_kwargs)

tasks = response["tasks"]

if not tasks:
failures = response["failures"]
failure_messages = []
for failure in failures:
arn = failure.get("arn")
reason = failure.get("reason")
detail = failure.get("detail")

failure_message = (
"Task"
+ (f" {arn}" if arn else "")
+ " failed."
+ (f" Failure reason: {reason}" if reason else "")
+ (f" Failure details: {detail}" if detail else "")
)
failure_messages.append(failure_message)

raise Exception("\n".join(failure_messages) if failure_messages else "Task failed.")
task = backoff(
self._run_task,
retry_on=(RetryableEcsException,),
kwargs=run_task_kwargs,
max_retries=int(
os.getenv("RUN_TASK_RETRIES", DEFAULT_RUN_TASK_RETRIES),
),
)

arn = tasks[0]["taskArn"]
cluster_arn = tasks[0]["clusterArn"]
arn = task["taskArn"]
cluster_arn = task["clusterArn"]
self._set_run_tags(run.run_id, cluster=cluster_arn, task_arn=arn)
self.report_launch_events(run, arn, cluster_arn)

Expand Down Expand Up @@ -661,7 +683,11 @@ def _run_task_kwargs(self, run, image, container_context) -> Dict[str, Any]:
"container_name": container_name,
"task_definition_dict": task_definition_dict,
},
max_retries=5,
max_retries=int(
os.getenv(
"REGISTER_TASK_DEFINITION_RETRIES", DEFAULT_REGISTER_TASK_DEFINITION_RETRIES
),
),
)

task_definition = family
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,40 @@ def run_task(self=ecs, **kwargs):
)
assert ex.match("\nTask missing-detail failed. Failure reason: too succinct\n")
assert ex.match("Task failed. Failure reason: ran out of arns")


def test_run_task_retryrable_failure(ecs, instance, workspace, run, other_run, monkeypatch):
    """A capacity failure from ECS RunTask is retried, unless retries are disabled.

    The first launch sees one out-of-capacity response followed by the real
    ``run_task`` and is expected to succeed. The second launch runs with
    ``RUN_TASK_RETRIES=0`` and is expected to surface the capacity failure.
    """
    original = ecs.run_task

    out_of_capacity_response = {
        "tasks": [],
        "failures": [
            {
                "arn": "missing-capacity",
                "reason": "Capacity is unavailable at this time. Please try again later or in a different availability zone",
                "detail": "boom",
            },
        ],
    }

    retryable_failures = iter([out_of_capacity_response])

    def run_task(*args, **kwargs):
        # Serve the queued failure responses first, then defer to the original
        # client method. `retryable_failures` is looked up at call time, so
        # rebinding it below re-arms the stub.
        try:
            return next(retryable_failures)
        except StopIteration:
            return original(*args, **kwargs)

    # Patch through monkeypatch (rather than plain assignment) so the stub is
    # undone at teardown and cannot leak onto the client object.
    monkeypatch.setattr(instance.run_launcher.ecs, "run_task", run_task)

    instance.launch_run(run.run_id, workspace)

    # Re-arm the stub and test again with 0 retries.
    retryable_failures = iter([out_of_capacity_response])

    monkeypatch.setenv("RUN_TASK_RETRIES", "0")

    with pytest.raises(Exception, match="Capacity is unavailable"):
        instance.launch_run(other_run.run_id, workspace)

0 comments on commit 0bb7b26

Please sign in to comment.