From 6fb82d91be4a4e02efc627e0a9d31b72bd9f2923 Mon Sep 17 00:00:00 2001 From: Daniel Gafni Date: Mon, 14 Oct 2024 22:31:10 +0200 Subject: [PATCH] [dagster-aws] add `PipesEMRClient` (#23998) ## Summary & Motivation This PR adds `PipesEMRClient` to `dagster-aws`. It allows running Spark workloads in ephemeral EMR (EC2 flavor) clusters. There is no support for submitting steps to existing EMR clusters. It turned out to be pretty hard to support properly since there is no native way to distinguish between steps submitted by our client and other workloads running in the same cluster. We can try to implement this in the future. Tasks: - [x] `PipesEMRClient` implementation - [x] local testing Showcase: ![image.png](https://graphite-user-uploaded-assets-prod.s3.amazonaws.com/MvopLwsR8lA4fwP4UJdE/f9299522-d488-4adf-a549-339477fa3498.png) ## How I Tested These Changes ## Changelog [New | Bug | Docs] [dagster-aws] new AWS EMR Dagster Pipes client (`dagster_aws.pipes.PipesEMRClient`) for running and monitoring AWS EMR jobs from Dagster. 
--- pyright/alt-1/requirements-pinned.txt | 1 + pyright/master/requirements-pinned.txt | 1 + .../dagster-aws/dagster_aws/pipes/__init__.py | 2 + .../dagster_aws/pipes/clients/__init__.py | 9 +- .../dagster_aws/pipes/clients/emr.py | 339 ++++++++++++++++++ .../dagster_aws/pipes/message_readers.py | 5 + .../libraries/dagster-aws/ruff.toml | 5 +- python_modules/libraries/dagster-aws/setup.py | 3 +- 8 files changed, 362 insertions(+), 3 deletions(-) create mode 100644 python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/emr.py diff --git a/pyright/alt-1/requirements-pinned.txt b/pyright/alt-1/requirements-pinned.txt index 9adabce9b43d1..194895a13f60f 100644 --- a/pyright/alt-1/requirements-pinned.txt +++ b/pyright/alt-1/requirements-pinned.txt @@ -171,6 +171,7 @@ multidict==6.1.0 multimethod==1.10 mypy==1.11.2 mypy-boto3-ecs==1.35.21 +mypy-boto3-emr==1.35.18 mypy-boto3-emr-serverless==1.35.25 mypy-boto3-glue==1.35.25 mypy-boto3-s3==1.35.32 diff --git a/pyright/master/requirements-pinned.txt b/pyright/master/requirements-pinned.txt index 549546fa6b1a3..b7c91d5997efb 100644 --- a/pyright/master/requirements-pinned.txt +++ b/pyright/master/requirements-pinned.txt @@ -360,6 +360,7 @@ msgpack==1.1.0 multidict==6.1.0 multimethod==1.10 mypy-boto3-ecs==1.35.21 +mypy-boto3-emr==1.35.18 mypy-boto3-emr-serverless==1.35.25 mypy-boto3-glue==1.35.25 mypy-boto3-s3==1.35.32 diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes/__init__.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes/__init__.py index 3902a5a642f31..ef95c70ae21d3 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/pipes/__init__.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes/__init__.py @@ -1,5 +1,6 @@ from dagster_aws.pipes.clients import ( PipesECSClient, + PipesEMRClient, PipesEMRServerlessClient, PipesGlueClient, PipesLambdaClient, @@ -19,6 +20,7 @@ "PipesGlueClient", "PipesLambdaClient", "PipesECSClient", + "PipesEMRClient", "PipesS3ContextInjector", 
"PipesLambdaEventContextInjector", "PipesS3MessageReader", diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/__init__.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/__init__.py index a6711bbf1ce82..e3895ac00d101 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/__init__.py +++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/__init__.py @@ -1,6 +1,13 @@ from dagster_aws.pipes.clients.ecs import PipesECSClient +from dagster_aws.pipes.clients.emr import PipesEMRClient from dagster_aws.pipes.clients.emr_serverless import PipesEMRServerlessClient from dagster_aws.pipes.clients.glue import PipesGlueClient from dagster_aws.pipes.clients.lambda_ import PipesLambdaClient -__all__ = ["PipesGlueClient", "PipesLambdaClient", "PipesECSClient", "PipesEMRServerlessClient"] +__all__ = [ + "PipesGlueClient", + "PipesLambdaClient", + "PipesECSClient", + "PipesEMRServerlessClient", + "PipesEMRClient", +] diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/emr.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/emr.py new file mode 100644 index 0000000000000..417c13fca2c7d --- /dev/null +++ b/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/emr.py @@ -0,0 +1,339 @@ +import os +import sys +import time +from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast + +import boto3 +import dagster._check as check +from dagster import PipesClient +from dagster._annotations import public +from dagster._core.definitions.resource_annotation import TreatAsResourceParam +from dagster._core.errors import DagsterExecutionInterruptedError +from dagster._core.execution.context.compute import OpExecutionContext +from dagster._core.pipes.client import ( + PipesClientCompletedInvocation, + PipesContextInjector, + PipesMessageReader, +) +from dagster._core.pipes.utils import PipesEnvContextInjector, PipesSession, open_pipes_session + +from 
def add_configuration(
    configurations: List["ConfigurationUnionTypeDef"],
    configuration: "ConfigurationUnionTypeDef",
):
    """Add a configuration to a list of EMR configurations, merging by classification.

    EMR doesn't accept multiple configurations with the same classification, so when
    ``configurations`` already contains an entry whose ``Classification`` equals the one
    of ``configuration``, the two are merged in place instead of being appended:

    * ``Properties`` are combined; on key conflicts the values from ``configuration`` win.
    * Nested ``Configurations`` are merged recursively using the same rules.

    Only the first matching entry is merged. If no entry matches (or the classification is
    missing), ``configuration`` is appended to ``configurations`` as is.

    Args:
        configurations: The list of EMR configurations to update. Mutated in place.
        configuration: The configuration to add or merge into ``configurations``.
    """
    for existing_configuration in configurations:
        if existing_configuration.get("Classification") is not None and existing_configuration.get(
            "Classification"
        ) == configuration.get("Classification"):
            # start from the existing properties and overlay the incoming ones
            # (the incoming configuration takes precedence on key conflicts)
            properties = {
                **existing_configuration.get("Properties", {}),
                **configuration.get("Properties", {}),
            }

            inner_configurations = cast(
                List["ConfigurationUnionTypeDef"], existing_configuration.get("Configurations", [])
            )

            # nested configurations obey the same "one entry per classification" rule
            for inner_configuration in cast(
                List["ConfigurationUnionTypeDef"], configuration.get("Configurations", [])
            ):
                add_configuration(inner_configurations, inner_configuration)

            existing_configuration["Properties"] = properties
            existing_configuration["Configurations"] = inner_configurations  # type: ignore

            break
    else:
        # no configuration with this classification exists yet
        configurations.append(configuration)
class PipesEMRClient(PipesClient, TreatAsResourceParam):
    """A pipes client for running jobs on AWS EMR.

    Args:
        message_reader (PipesMessageReader): A message reader to use to read messages
            from the EMR jobs.
            Recommended to use :py:class:`PipesS3MessageReader` with `expect_s3_message_writer` set to `True`.
        client (Optional[boto3.client]): The boto3 EMR client used to interact with AWS EMR.
            Defaults to ``boto3.client("emr")``.
        context_injector (Optional[PipesContextInjector]): A context injector to use to inject
            context into AWS EMR job. Defaults to :py:class:`PipesEnvContextInjector`.
        forward_termination (bool): Whether to cancel the EMR job if the Dagster process receives a termination signal.
        wait_for_s3_logs_seconds (int): The number of seconds to wait for S3 logs to be written after execution completes.
    """

    def __init__(
        self,
        message_reader: PipesMessageReader,
        client=None,
        context_injector: Optional[PipesContextInjector] = None,
        forward_termination: bool = True,
        wait_for_s3_logs_seconds: int = 10,
    ):
        self._client = client or boto3.client("emr")
        self._message_reader = message_reader
        self._context_injector = context_injector or PipesEnvContextInjector()
        self.forward_termination = check.bool_param(forward_termination, "forward_termination")
        self.wait_for_s3_logs_seconds = wait_for_s3_logs_seconds

    @property
    def client(self) -> "EMRClient":
        return self._client

    @property
    def context_injector(self) -> PipesContextInjector:
        return self._context_injector

    @property
    def message_reader(self) -> PipesMessageReader:
        return self._message_reader

    @classmethod
    def _is_dagster_maintained(cls) -> bool:
        return True

    @staticmethod
    def _join_s3_key(prefix: str, suffix: str) -> str:
        """Join two S3 key fragments with exactly one ``/`` between them.

        S3 keys always use forward slashes, so ``os.path.join`` must not be used here
        (it would produce backslashes on Windows).
        """
        if not prefix or prefix.endswith("/"):
            return prefix + suffix
        return f"{prefix}/{suffix}"

    @public
    def run(
        self,
        *,
        context: OpExecutionContext,
        run_job_flow_params: "RunJobFlowInputRequestTypeDef",
        extras: Optional[Dict[str, Any]] = None,
    ) -> PipesClientCompletedInvocation:
        """Run a job on AWS EMR, enriched with the pipes protocol.

        Starts a new EMR cluster for each invocation.

        Args:
            context (OpExecutionContext): The context of the currently executing Dagster op or asset.
            run_job_flow_params (dict): Parameters for the ``run_job_flow`` boto3 EMR client call.
                See `Boto3 API Documentation <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr/client/run_job_flow.html>`_
            extras (Optional[Dict[str, Any]]): Additional information to pass to the Pipes session in the external process.

        Returns:
            PipesClientCompletedInvocation: Wrapper containing results reported by the external
            process.
        """
        with open_pipes_session(
            context=context,
            message_reader=self.message_reader,
            context_injector=self.context_injector,
            extras=extras,
        ) as session:
            run_job_flow_params = self._enrich_params(session, run_job_flow_params)
            start_response = self._start(context, session, run_job_flow_params)
            try:
                self._add_log_readers(context, start_response)
                wait_response = self._wait_for_completion(context, start_response)
                self._read_remaining_logs(context, wait_response)
                return PipesClientCompletedInvocation(session)

            except DagsterExecutionInterruptedError:
                if self.forward_termination:
                    context.log.warning(
                        "[pipes] Dagster process interrupted! Will terminate external EMR job."
                    )
                    self._terminate(context, start_response)
                raise

    def _enrich_params(
        self, session: PipesSession, params: "RunJobFlowInputRequestTypeDef"
    ) -> "RunJobFlowInputRequestTypeDef":
        """Return a copy of ``params`` with the Pipes env vars and tags added.

        The input dictionary (and its nested configurations) is left untouched so that
        the same ``run_job_flow_params`` object can be reused across invocations without
        accumulating tags or configuration entries from previous runs.
        """
        from copy import deepcopy

        params = cast("RunJobFlowInputRequestTypeDef", {**params})

        # add Pipes env variables
        pipes_env_vars = session.get_bootstrap_env_vars()

        # deep copy: add_configuration merges into the existing configuration dicts in place
        configurations = cast(
            List["ConfigurationUnionTypeDef"], deepcopy(list(params.get("Configurations", [])))
        )

        # add all possible env vars to spark-defaults, spark-env, yarn-env, hadoop-env
        # since we can't be sure which one will be used by the job
        add_configuration(
            configurations,
            {
                "Classification": "spark-defaults",
                "Properties": {
                    f"spark.yarn.appMasterEnv.{var}": value for var, value in pipes_env_vars.items()
                },
            },
        )

        for classification in ["spark-env", "yarn-env", "hadoop-env"]:
            add_configuration(
                configurations,
                {
                    "Classification": classification,
                    "Configurations": [
                        {
                            "Classification": "export",
                            "Properties": pipes_env_vars,
                        }
                    ],
                },
            )

        params["Configurations"] = configurations

        tags = list(params.get("Tags", []))

        for key, value in session.default_remote_invocation_info.items():
            tags.append({"Key": key, "Value": value})

        params["Tags"] = tags

        return params

    def _start(
        self,
        context: OpExecutionContext,
        session: PipesSession,
        params: "RunJobFlowInputRequestTypeDef",
    ) -> "RunJobFlowOutputTypeDef":
        """Call ``run_job_flow`` and report the launch to the Pipes session."""
        response = self._client.run_job_flow(**params)

        session.report_launched({"extras": response})

        cluster_id = response["JobFlowId"]

        context.log.info(f"[pipes] EMR steps started in cluster {cluster_id}")
        return response

    def _wait_for_completion(
        self, context: OpExecutionContext, response: "RunJobFlowOutputTypeDef"
    ) -> "DescribeClusterOutputTypeDef":
        """Block until the cluster has terminated and return its final description.

        Raises:
            Exception: if the final cluster state is in ``EMR_CLUSTER_TERMINATED_STATES``.
        """
        cluster_id = response["JobFlowId"]
        self._client.get_waiter("cluster_running").wait(ClusterId=cluster_id)
        context.log.info(f"[pipes] EMR cluster {cluster_id} running")
        # now wait for the job to complete
        self._client.get_waiter("cluster_terminated").wait(ClusterId=cluster_id)

        cluster = self._client.describe_cluster(ClusterId=cluster_id)

        state: ClusterStateType = cluster["Cluster"]["Status"]["State"]

        context.log.info(f"[pipes] EMR cluster {cluster_id} completed with state: {state}")

        # NOTE(review): `state` is a plain string returned by boto3. Confirm that
        # EMR_CLUSTER_TERMINATED_STATES holds comparable values and only the states that
        # should count as a failure - at this point the cluster is always terminated.
        if state in EMR_CLUSTER_TERMINATED_STATES:
            context.log.error(f"[pipes] EMR job {cluster_id} failed")
            raise Exception(f"[pipes] EMR job {cluster_id} failed:\n{cluster}")

        return cluster

    def _add_log_readers(self, context: OpExecutionContext, response: "RunJobFlowOutputTypeDef"):
        """Register S3 log readers for the stdout/stderr of every step of the cluster.

        Only has an effect when the message reader is a :py:class:`PipesS3MessageReader`.
        Emits a warning if the cluster has no ``LogUri`` configured.
        """
        cluster = self.client.describe_cluster(ClusterId=response["JobFlowId"])

        cluster_id = cluster["Cluster"]["Id"]  # type: ignore
        # no default here: a missing LogUri must be seen as None to trigger the warning below
        logs_uri = cluster.get("Cluster", {}).get("LogUri")

        if isinstance(self.message_reader, PipesS3MessageReader) and logs_uri is None:
            context.log.warning(
                "[pipes] LogUri is not set in the EMR cluster configuration. Won't be able to read logs."
            )
        elif isinstance(self.message_reader, PipesS3MessageReader) and isinstance(logs_uri, str):
            # logs_uri looks like s3://<bucket>/<prefix>
            bucket = logs_uri.split("/")[2]
            prefix = "/".join(logs_uri.split("/")[3:])

            steps = self.client.list_steps(ClusterId=cluster_id)

            # forward stdout and stderr from each step

            for step in steps["Steps"]:
                step_id = step["Id"]  # type: ignore

                for stdio in ["stdout", "stderr"]:
                    # at this stage we can't know if this key will be created
                    # for example, if a step doesn't have any stdout/stderr logs
                    # the PipesS3LogReader won't be able to start
                    # this may result in some unnecessary warnings
                    # there is not much we can do about it except perform step logs reading
                    # after the job is completed, which is not ideal too
                    key = self._join_s3_key(prefix, f"{cluster_id}/steps/{step_id}/{stdio}.gz")

                    self.message_reader.add_log_reader(
                        log_reader=PipesS3LogReader(
                            client=self.message_reader.client,
                            bucket=bucket,
                            key=key,
                            decode_fn=gzip_log_decode_fn,
                            target_stream=sys.stdout if stdio == "stdout" else sys.stderr,
                            debug_info=f"reader for {stdio} of EMR step {step_id}",
                        ),
                    )

    def _read_remaining_logs(
        self, context: OpExecutionContext, response: "DescribeClusterOutputTypeDef"
    ):
        """Discover container (application) logs in S3 and register log readers for them.

        Only has an effect when the message reader is a :py:class:`PipesS3MessageReader`
        and the cluster has a ``LogUri``. Sleeps for ``wait_for_s3_logs_seconds`` first.
        """
        cluster_id = response["Cluster"]["Id"]  # type: ignore
        logs_uri = response.get("Cluster", {}).get("LogUri")

        if isinstance(self.message_reader, PipesS3MessageReader) and isinstance(logs_uri, str):
            bucket = logs_uri.split("/")[2]
            prefix = "/".join(logs_uri.split("/")[3:])

            # discover container (application) logs (e.g. Python logs) and forward all of them
            # ex. /containers/application_1727881613116_0001/container_1727881613116_0001_01_000001/stdout.gz
            containers_prefix = self._join_s3_key(prefix, f"{cluster_id}/containers/")

            context.log.debug(
                f"[pipes] Waiting for {self.wait_for_s3_logs_seconds} seconds to allow EMR to dump all logs to S3. "
                "Consider increasing this value if some logs are missing."
            )

            time.sleep(self.wait_for_s3_logs_seconds)  # give EMR a chance to dump all logs to S3

            context.log.debug(
                f"[pipes] Looking for application logs in s3://{bucket}/{containers_prefix}"
            )

            # paginate: a single list_objects_v2 call returns at most 1000 keys, and the
            # "Contents" field is absent from the response when nothing matches the prefix
            paginator = self.message_reader.client.get_paginator("list_objects_v2")
            all_keys = [
                obj["Key"]
                for page in paginator.paginate(Bucket=bucket, Prefix=containers_prefix)
                for obj in page.get("Contents", [])
            ]

            # filter keys which include stdout.gz or stderr.gz

            container_log_keys = {}
            for key in all_keys:
                if "stdout.gz" in key:
                    container_log_keys[key] = "stdout"
                elif "stderr.gz" in key:
                    container_log_keys[key] = "stderr"

            # forward application logs

            for key, stdio in container_log_keys.items():
                container_id = key.split("/")[-2]
                self.message_reader.add_log_reader(
                    log_reader=PipesS3LogReader(
                        client=self.message_reader.client,
                        bucket=bucket,
                        key=key,
                        decode_fn=gzip_log_decode_fn,
                        target_stream=sys.stdout if stdio == "stdout" else sys.stderr,
                        debug_info=f"log reader for container {container_id} {stdio}",
                    ),
                )

    def _terminate(self, context: OpExecutionContext, start_response: "RunJobFlowOutputTypeDef"):
        """Request termination of the cluster started by this invocation."""
        cluster_id = start_response["JobFlowId"]
        context.log.info(f"[pipes] Terminating EMR job {cluster_id}")
        self._client.terminate_job_flows(JobFlowIds=[cluster_id])
diff --git a/python_modules/libraries/dagster-aws/ruff.toml b/python_modules/libraries/dagster-aws/ruff.toml index 244f818c1fc43..64c0e85362012 100644 --- a/python_modules/libraries/dagster-aws/ruff.toml +++ b/python_modules/libraries/dagster-aws/ruff.toml @@ -8,7 +8,10 @@ extend-select = [ [lint.flake8-tidy-imports] banned-module-level-imports = [ + "mypy_boto3_s3", + "mypy_boto3_logs", "mypy_boto3_ecs", "mypy_boto3_glue", - "mypy_boto3_emr_serverless" + "mypy_boto3_emr_serverless", + "mypy_boto3_emr" ] diff --git a/python_modules/libraries/dagster-aws/setup.py b/python_modules/libraries/dagster-aws/setup.py index 85ccfd520abc7..9923d70197e31 100644 --- a/python_modules/libraries/dagster-aws/setup.py +++ b/python_modules/libraries/dagster-aws/setup.py @@ -37,6 +37,7 @@ def get_version() -> str: python_requires=">=3.8,<3.13", install_requires=[ "boto3", + "boto3-stubs-lite[ecs,glue,emr,emr-serverless]", f"dagster{pin}", "packaging", "requests", @@ -45,7 +46,7 @@ def get_version() -> str: "redshift": ["psycopg2-binary"], "pyspark": ["dagster-pyspark"], "stubs": [ - "boto3-stubs-lite[ecs,glue,emr-serverless,s3]", + "boto3-stubs-lite[ecs,glue,emr-serverless,s3,emr]", ], "test": [ "botocore!=1.32.1",