Skip to content

Commit

Permalink
[dagster-aws] add PipesEMRClient (#23998)
Browse files Browse the repository at this point in the history
## Summary & Motivation
This PR adds `PipesEMRClient` to `dagster-aws`.

It allows running Spark workloads in ephemeral EMR (EC2 flavor)
clusters.

There is no support for submitting steps to existing EMR clusters. It
turned out to be pretty hard to support properly since there is no
native way to distinguish between steps submitted by our client and
other workloads running in the same cluster. We can try to implement
this in the future.

Tasks:
- [x] `PipesEMRClient` implementation
- [x] local testing

Showcase:


![image.png](https://graphite-user-uploaded-assets-prod.s3.amazonaws.com/MvopLwsR8lA4fwP4UJdE/f9299522-d488-4adf-a549-339477fa3498.png)

## How I Tested These Changes

## Changelog [New | Bug | Docs]

[dagster-aws] new AWS EMR Dagster Pipes client
(`dagster_aws.pipes.PipesEMRClient`) for running and monitoring AWS EMR
jobs from Dagster.
  • Loading branch information
danielgafni authored Oct 14, 2024
1 parent 5c014ea commit 6fb82d9
Show file tree
Hide file tree
Showing 8 changed files with 362 additions and 3 deletions.
1 change: 1 addition & 0 deletions pyright/alt-1/requirements-pinned.txt
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ multidict==6.1.0
multimethod==1.10
mypy==1.11.2
mypy-boto3-ecs==1.35.21
mypy-boto3-emr==1.35.18
mypy-boto3-emr-serverless==1.35.25
mypy-boto3-glue==1.35.25
mypy-boto3-s3==1.35.32
Expand Down
1 change: 1 addition & 0 deletions pyright/master/requirements-pinned.txt
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ msgpack==1.1.0
multidict==6.1.0
multimethod==1.10
mypy-boto3-ecs==1.35.21
mypy-boto3-emr==1.35.18
mypy-boto3-emr-serverless==1.35.25
mypy-boto3-glue==1.35.25
mypy-boto3-s3==1.35.32
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dagster_aws.pipes.clients import (
PipesECSClient,
PipesEMRClient,
PipesEMRServerlessClient,
PipesGlueClient,
PipesLambdaClient,
Expand All @@ -19,6 +20,7 @@
"PipesGlueClient",
"PipesLambdaClient",
"PipesECSClient",
"PipesEMRClient",
"PipesS3ContextInjector",
"PipesLambdaEventContextInjector",
"PipesS3MessageReader",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
from dagster_aws.pipes.clients.ecs import PipesECSClient
from dagster_aws.pipes.clients.emr import PipesEMRClient
from dagster_aws.pipes.clients.emr_serverless import PipesEMRServerlessClient
from dagster_aws.pipes.clients.glue import PipesGlueClient
from dagster_aws.pipes.clients.lambda_ import PipesLambdaClient

__all__ = ["PipesGlueClient", "PipesLambdaClient", "PipesECSClient", "PipesEMRServerlessClient"]
__all__ = [
"PipesGlueClient",
"PipesLambdaClient",
"PipesECSClient",
"PipesEMRServerlessClient",
"PipesEMRClient",
]
339 changes: 339 additions & 0 deletions python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/emr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,339 @@
import os
import sys
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast

import boto3
import dagster._check as check
from dagster import PipesClient
from dagster._annotations import public
from dagster._core.definitions.resource_annotation import TreatAsResourceParam
from dagster._core.errors import DagsterExecutionInterruptedError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.pipes.client import (
PipesClientCompletedInvocation,
PipesContextInjector,
PipesMessageReader,
)
from dagster._core.pipes.utils import PipesEnvContextInjector, PipesSession, open_pipes_session

from dagster_aws.emr.emr import EMR_CLUSTER_TERMINATED_STATES
from dagster_aws.pipes.message_readers import (
PipesS3LogReader,
PipesS3MessageReader,
gzip_log_decode_fn,
)

if TYPE_CHECKING:
from mypy_boto3_emr import EMRClient
from mypy_boto3_emr.literals import ClusterStateType
from mypy_boto3_emr.type_defs import (
ConfigurationUnionTypeDef,
DescribeClusterOutputTypeDef,
RunJobFlowInputRequestTypeDef,
RunJobFlowOutputTypeDef,
)


def add_configuration(
    configurations: List["ConfigurationUnionTypeDef"],
    configuration: "ConfigurationUnionTypeDef",
):
    """Add a configuration to a list of EMR configurations, merging configurations with the same classification.

    This is necessary because EMR doesn't accept multiple configurations with the same classification.

    If ``configurations`` already contains an entry whose ``Classification`` equals the one of
    ``configuration`` (and is not ``None``), the two are merged into the existing entry:

    * ``Properties`` are combined; on a key conflict the value from ``configuration`` wins.
    * nested ``Configurations`` are merged recursively using the same rules.

    Otherwise ``configuration`` is appended to ``configurations`` unchanged.

    Args:
        configurations: The list to merge into. It is modified in place, as are the
            matching entry and its nested ``Configurations`` list.
        configuration: The configuration to add or merge.
    """
    classification = configuration.get("Classification")

    for existing_configuration in configurations:
        # Entries without a classification are never merge targets.
        if (
            existing_configuration.get("Classification") is None
            or existing_configuration.get("Classification") != classification
        ):
            continue

        # Incoming properties are layered on top of the existing ones so that the newly
        # added values (e.g. Pipes bootstrap env vars) are not lost.
        properties = {
            **existing_configuration.get("Properties", {}),
            **configuration.get("Properties", {}),
        }

        inner_configurations = cast(
            List["ConfigurationUnionTypeDef"], existing_configuration.get("Configurations", [])
        )

        for inner_configuration in cast(
            List["ConfigurationUnionTypeDef"], configuration.get("Configurations", [])
        ):
            add_configuration(inner_configurations, inner_configuration)

        existing_configuration["Properties"] = properties
        existing_configuration["Configurations"] = inner_configurations  # type: ignore

        break
    else:
        # No entry with the same classification was found.
        configurations.append(configuration)


class PipesEMRClient(PipesClient, TreatAsResourceParam):
    """A pipes client for running jobs on AWS EMR.

    Args:
        message_reader (PipesMessageReader): A message reader to use to read messages
            from the EMR jobs.
            Recommended to use :py:class:`PipesS3MessageReader` with `expect_s3_message_writer` set to `True`.
        client (Optional[boto3.client]): The boto3 EMR client used to interact with AWS EMR.
            Defaults to ``boto3.client("emr")``.
        context_injector (Optional[PipesContextInjector]): A context injector to use to inject
            context into AWS EMR job. Defaults to :py:class:`PipesEnvContextInjector`.
        forward_termination (bool): Whether to cancel the EMR job if the Dagster process receives a termination signal.
        wait_for_s3_logs_seconds (int): The number of seconds to wait for S3 logs to be written after execution completes.
    """

    def __init__(
        self,
        message_reader: PipesMessageReader,
        client=None,
        context_injector: Optional[PipesContextInjector] = None,
        forward_termination: bool = True,
        wait_for_s3_logs_seconds: int = 10,
    ):
        self._client = client or boto3.client("emr")
        self._message_reader = message_reader
        self._context_injector = context_injector or PipesEnvContextInjector()
        self.forward_termination = check.bool_param(forward_termination, "forward_termination")
        self.wait_for_s3_logs_seconds = wait_for_s3_logs_seconds

    @property
    def client(self) -> "EMRClient":
        """The boto3 EMR client used for all EMR API calls."""
        return self._client

    @property
    def context_injector(self) -> PipesContextInjector:
        """The context injector passed to the Pipes session."""
        return self._context_injector

    @property
    def message_reader(self) -> PipesMessageReader:
        """The message reader passed to the Pipes session."""
        return self._message_reader

    @classmethod
    def _is_dagster_maintained(cls) -> bool:
        return True

    @public
    def run(
        self,
        *,
        context: OpExecutionContext,
        run_job_flow_params: "RunJobFlowInputRequestTypeDef",
        extras: Optional[Dict[str, Any]] = None,
    ) -> PipesClientCompletedInvocation:
        """Run a job on AWS EMR, enriched with the pipes protocol.

        Starts a new EMR cluster for each invocation.

        Args:
            context (OpExecutionContext): The context of the currently executing Dagster op or asset.
            run_job_flow_params (dict): Parameters for the ``run_job_flow`` boto3 EMR client call.
                See `Boto3 API Documentation <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr/client/emr.html#emr>`_
                The passed object is not modified; Pipes configuration and tags are added to a copy.
            extras (Optional[Dict[str, Any]]): Additional information to pass to the Pipes session in the external process.

        Returns:
            PipesClientCompletedInvocation: Wrapper containing results reported by the external
            process.
        """
        with open_pipes_session(
            context=context,
            message_reader=self.message_reader,
            context_injector=self.context_injector,
            extras=extras,
        ) as session:
            run_job_flow_params = self._enrich_params(session, run_job_flow_params)
            start_response = self._start(context, session, run_job_flow_params)
            try:
                self._add_log_readers(context, start_response)
                wait_response = self._wait_for_completion(context, start_response)
                self._read_remaining_logs(context, wait_response)
                return PipesClientCompletedInvocation(session)

            except DagsterExecutionInterruptedError:
                if self.forward_termination:
                    context.log.warning(
                        "[pipes] Dagster process interrupted! Will terminate external EMR job."
                    )
                    self._terminate(context, start_response)
                raise

    def _enrich_params(
        self, session: PipesSession, params: "RunJobFlowInputRequestTypeDef"
    ) -> "RunJobFlowInputRequestTypeDef":
        """Return a copy of ``params`` with Pipes env vars and tags added.

        The Pipes bootstrap env vars are injected into the ``spark-defaults``, ``spark-env``,
        ``yarn-env`` and ``hadoop-env`` classifications, and the session's default remote
        invocation info is appended as cluster tags.
        """
        import copy

        # add_configuration merges into existing configuration dicts in place, so work on a
        # deep copy to keep the caller's params (which may be reused across runs) untouched.
        params = copy.deepcopy(params)

        # add Pipes env variables
        pipes_env_vars = session.get_bootstrap_env_vars()

        configurations = cast(List["ConfigurationUnionTypeDef"], params.get("Configurations", []))

        # add all possible env vars to spark-defaults, spark-env, yarn-env, hadoop-env
        # since we can't be sure which one will be used by the job
        add_configuration(
            configurations,
            {
                "Classification": "spark-defaults",
                "Properties": {
                    f"spark.yarn.appMasterEnv.{var}": value for var, value in pipes_env_vars.items()
                },
            },
        )

        for classification in ["spark-env", "yarn-env", "hadoop-env"]:
            add_configuration(
                configurations,
                {
                    "Classification": classification,
                    "Configurations": [
                        {
                            "Classification": "export",
                            "Properties": pipes_env_vars,
                        }
                    ],
                },
            )

        params["Configurations"] = configurations

        tags = list(params.get("Tags", []))

        for key, value in session.default_remote_invocation_info.items():
            tags.append({"Key": key, "Value": value})

        params["Tags"] = tags

        return params

    def _start(
        self,
        context: OpExecutionContext,
        session: PipesSession,
        params: "RunJobFlowInputRequestTypeDef",
    ) -> "RunJobFlowOutputTypeDef":
        """Call ``run_job_flow`` with ``params`` and report the launch to the Pipes session."""
        response = self._client.run_job_flow(**params)

        session.report_launched({"extras": response})

        cluster_id = response["JobFlowId"]

        context.log.info(f"[pipes] EMR steps started in cluster {cluster_id}")
        return response

    def _wait_for_completion(
        self, context: OpExecutionContext, response: "RunJobFlowOutputTypeDef"
    ) -> "DescribeClusterOutputTypeDef":
        """Block until the cluster has terminated and return its final description.

        Raises:
            Exception: If the final cluster state is in ``EMR_CLUSTER_TERMINATED_STATES``.
        """
        cluster_id = response["JobFlowId"]
        self._client.get_waiter("cluster_running").wait(ClusterId=cluster_id)
        context.log.info(f"[pipes] EMR cluster {cluster_id} running")
        # now wait for the job to complete
        self._client.get_waiter("cluster_terminated").wait(ClusterId=cluster_id)

        cluster = self._client.describe_cluster(ClusterId=cluster_id)

        state: ClusterStateType = cluster["Cluster"]["Status"]["State"]

        context.log.info(f"[pipes] EMR cluster {cluster_id} completed with state: {state}")

        # NOTE(review): `state` is the raw string returned by boto3. This check only works as a
        # failure detector if EMR_CLUSTER_TERMINATED_STATES holds values comparable to that
        # string and does not include the state of a successfully finished cluster - confirm
        # against its definition in dagster_aws.emr.emr.
        if state in EMR_CLUSTER_TERMINATED_STATES:
            context.log.error(f"[pipes] EMR job {cluster_id} failed")
            raise Exception(f"[pipes] EMR job {cluster_id} failed:\n{cluster}")

        return cluster

    def _add_log_readers(self, context: OpExecutionContext, response: "RunJobFlowOutputTypeDef"):
        """Register S3 log readers for the stdout/stderr of every step of the started cluster.

        Does nothing unless the message reader is a :py:class:`PipesS3MessageReader`. Logs a
        warning if the cluster has no ``LogUri`` configured.
        """
        cluster = self.client.describe_cluster(ClusterId=response["JobFlowId"])

        cluster_id = cluster["Cluster"]["Id"]  # type: ignore
        # No default here: a missing LogUri must surface as None so the warning below fires.
        logs_uri = cluster.get("Cluster", {}).get("LogUri")

        if isinstance(self.message_reader, PipesS3MessageReader) and logs_uri is None:
            context.log.warning(
                "[pipes] LogUri is not set in the EMR cluster configuration. Won't be able to read logs."
            )
        elif isinstance(self.message_reader, PipesS3MessageReader) and isinstance(logs_uri, str):
            # logs_uri has the form <scheme>://<bucket>/<prefix>
            bucket = logs_uri.split("/")[2]
            prefix = "/".join(logs_uri.split("/")[3:])

            steps = self.client.list_steps(ClusterId=cluster_id)

            # forward stdout and stderr from each step

            for step in steps["Steps"]:
                step_id = step["Id"]  # type: ignore

                for stdio in ["stdout", "stderr"]:
                    # at this stage we can't know if this key will be created
                    # for example, if a step doesn't have any stdout/stderr logs
                    # the PipesS3LogReader won't be able to start
                    # this may result in some unnecessary warnings
                    # there is not much we can do about it except perform step logs reading
                    # after the job is completed, which is not ideal too
                    key = os.path.join(prefix, f"{cluster_id}/steps/{step_id}/{stdio}.gz")

                    self.message_reader.add_log_reader(
                        log_reader=PipesS3LogReader(
                            client=self.message_reader.client,
                            bucket=bucket,
                            key=key,
                            decode_fn=gzip_log_decode_fn,
                            target_stream=sys.stdout if stdio == "stdout" else sys.stderr,
                            debug_info=f"reader for {stdio} of EMR step {step_id}",
                        ),
                    )

    def _read_remaining_logs(
        self, context: OpExecutionContext, response: "DescribeClusterOutputTypeDef"
    ):
        """Discover container (application) logs in S3 and register a log reader for each.

        Does nothing unless the message reader is a :py:class:`PipesS3MessageReader` and the
        cluster has a ``LogUri``. Sleeps for ``wait_for_s3_logs_seconds`` before listing.
        """
        cluster_id = response["Cluster"]["Id"]  # type: ignore
        logs_uri = response.get("Cluster", {}).get("LogUri")

        if isinstance(self.message_reader, PipesS3MessageReader) and isinstance(logs_uri, str):
            # logs_uri has the form <scheme>://<bucket>/<prefix>
            bucket = logs_uri.split("/")[2]
            prefix = "/".join(logs_uri.split("/")[3:])

            # discover container (application) logs (e.g. Python logs) and forward all of them
            # ex. /containers/application_1727881613116_0001/container_1727881613116_0001_01_000001/stdout.gz
            containers_prefix = os.path.join(prefix, f"{cluster_id}/containers/")

            context.log.debug(
                f"[pipes] Waiting for {self.wait_for_s3_logs_seconds} seconds to allow EMR to dump all logs to S3. "
                "Consider increasing this value if some logs are missing."
            )

            time.sleep(self.wait_for_s3_logs_seconds)  # give EMR a chance to dump all logs to S3

            context.log.debug(
                f"[pipes] Looking for application logs in s3://{os.path.join(bucket, containers_prefix)}"
            )

            # "Contents" is absent from the response when nothing matches the prefix
            # (e.g. the job produced no container logs), so don't index it directly.
            # NOTE(review): a single list_objects_v2 call returns one page of results only;
            # consider a paginator if clusters can produce more container log files than that.
            listing = self.message_reader.client.list_objects_v2(
                Bucket=bucket, Prefix=containers_prefix
            )
            all_keys = [obj["Key"] for obj in listing.get("Contents", [])]

            # filter keys which include stdout.gz or stderr.gz

            container_log_keys = {}
            for key in all_keys:
                if "stdout.gz" in key:
                    container_log_keys[key] = "stdout"
                elif "stderr.gz" in key:
                    container_log_keys[key] = "stderr"

            # forward application logs

            for key, stdio in container_log_keys.items():
                # the container id is the directory that holds the log file
                container_id = key.split("/")[-2]
                self.message_reader.add_log_reader(
                    log_reader=PipesS3LogReader(
                        client=self.message_reader.client,
                        bucket=bucket,
                        key=key,
                        decode_fn=gzip_log_decode_fn,
                        target_stream=sys.stdout if stdio == "stdout" else sys.stderr,
                        debug_info=f"log reader for container {container_id} {stdio}",
                    ),
                )

    def _terminate(self, context: OpExecutionContext, start_response: "RunJobFlowOutputTypeDef"):
        """Request termination of the cluster started by this invocation."""
        cluster_id = start_response["JobFlowId"]
        context.log.info(f"[pipes] Terminating EMR job {cluster_id}")
        self._client.terminate_job_flows(JobFlowIds=[cluster_id])
Loading

0 comments on commit 6fb82d9

Please sign in to comment.