From f4cff858595d4597f9dbe633d5d7d25db90e2eb6 Mon Sep 17 00:00:00 2001
From: Daniel Gafni
Date: Fri, 13 Sep 2024 11:28:25 +0300
Subject: [PATCH] [dagster-aws] add `PipesEMRServerlessClient` (#24318)

## Summary & Motivation

resolve [DS-443](https://linear.app/dagster-labs/issue/DS-443/implement-pipes-emr-serverless)

Implemented message reading for the CloudWatch logging configuration (using `PipesCloudWatchMessageReader`), mapping the Spark driver's stdout/stderr to Dagster's stdout/stderr.

S3-backed logs are currently blocked by https://github.com/dagster-io/dagster/pull/24098
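Example usage (a minimal sketch; the application ID, role ARN, and entry point are placeholders, and `get_results()` is the standard `PipesClientCompletedInvocation` accessor):

```python
import boto3
from dagster import AssetExecutionContext, Definitions, asset

from dagster_aws.pipes import PipesEMRServerlessClient


@asset
def emr_serverless_asset(
    context: AssetExecutionContext, emr_serverless_client: PipesEMRServerlessClient
):
    # run() starts the job run, polls it until it reaches a terminal state,
    # and forwards Pipes messages and driver logs along the way
    return emr_serverless_client.run(
        context=context,
        start_job_run_params={
            "applicationId": "<application-id>",
            "executionRoleArn": "arn:aws:iam::123456789012:role/EMRServerlessRole",
            "jobDriver": {
                "sparkSubmit": {"entryPoint": "s3://my-bucket/my-script.py"},
            },
        },
    ).get_results()


defs = Definitions(
    assets=[emr_serverless_asset],
    resources={
        "emr_serverless_client": PipesEMRServerlessClient(
            client=boto3.client("emr-serverless"),
        )
    },
)
```
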
## How I Tested These Changes

1. Manually tested with real AWS infra (personal account)
2. Added a very simple test which checks that context env vars are set correctly. This is as far as we can get with `moto`, since it doesn't have `start_job_run` implemented. I'm planning to add integration tests in CI with real AWS infra.

## Changelog

- [x] `NEW` _([dagster-aws] added `PipesEMRServerlessClient`; it's not fully tested yet, so please don't include it in the release notes until we have tests)_
- [ ] `BUGFIX` _(fixed a bug)_
- [ ] `DOCS` _(added or updated documentation)_
---
 docs/content/api/modules.json.gz                  | Bin 1321179 -> 1321179 bytes
 docs/content/api/searchindex.json.gz              | Bin 80883 -> 80883 bytes
 docs/content/api/sections.json.gz                 | Bin 466546 -> 466546 bytes
 pyright/alt-1/requirements-pinned.txt             |   7 +-
 pyright/master/requirements-pinned.txt            |  19 +-
 .../dagster-aws/dagster_aws/pipes/__init__.py     |   8 +-
 .../dagster_aws/pipes/clients/__init__.py         |   3 +-
 .../pipes/clients/emr_serverless.py               | 304 ++++++++++++++++++
 .../pipes_tests/test_pipes.py                     |  75 ++++-
 .../libraries/dagster-aws/ruff.toml               |   1 +
 python_modules/libraries/dagster-aws/setup.py     |   4 +-
 11 files changed, 404 insertions(+), 17 deletions(-)
 create mode 100644 python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/emr_serverless.py

diff --git a/docs/content/api/modules.json.gz b/docs/content/api/modules.json.gz
index 20a84ccb2c8c881d799223c1d641fc86680dc036..249c79e11d6bb63776eb9309ab55386f6f9ec1ef 100644
GIT binary patch
(binary delta omitted)

diff --git a/docs/content/api/searchindex.json.gz b/docs/content/api/searchindex.json.gz
index 238bda0393437c5a3e89280559ae0b8a9293ac72..1c0f33537da4078356d83d4279b472f06afdaa85 100644
GIT binary patch
(binary delta omitted)

diff --git a/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/emr_serverless.py b/python_modules/libraries/dagster-aws/dagster_aws/pipes/clients/emr_serverless.py
new file mode 100644
+    @property
+    def client(self) -> "EMRServerlessClient":
+        return self._client
+
+    @property
+    def context_injector(self) -> PipesContextInjector:
+        return self._context_injector
+
+    @property
+    def message_reader(self) -> PipesMessageReader:
+        return self._message_reader
+
+    @classmethod
+    def _is_dagster_maintained(cls) -> bool:
+        return True
+
+    @public
+    def run(
+        self,
+        *,
+        context: OpExecutionContext,
+        start_job_run_params: "StartJobRunRequestRequestTypeDef",
+        extras: Optional[Dict[str, Any]] = None,
+    ) -> PipesClientCompletedInvocation:
+        """Run a workload on AWS EMR Serverless, enriched with the pipes protocol.
+
+        Args:
+            context (OpExecutionContext): The context of the currently executing Dagster op or asset.
+            start_job_run_params (dict): Parameters for the ``start_job_run`` boto3 AWS EMR Serverless client call.
+                See `Boto3 API Documentation <https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/emr-serverless/client/start_job_run.html>`_
+            extras (Optional[Dict[str, Any]]): Additional information to pass to the Pipes session in the external process.
+
+        Returns:
+            PipesClientCompletedInvocation: Wrapper containing results reported by the external
+                process.
+        """
+        with open_pipes_session(
+            context=context,
+            message_reader=self.message_reader,
+            context_injector=self.context_injector,
+            extras=extras,
+        ) as session:
+            start_job_run_params = self._enrich_start_params(context, session, start_job_run_params)
+            start_response = self._start(context, start_job_run_params)
+            try:
+                completion_response = self._wait_for_completion(context, start_response)
+                context.log.info(f"[pipes] {self.AWS_SERVICE_NAME} workload is complete!")
+                self._read_messages(context, completion_response)
+                return PipesClientCompletedInvocation(session)
+
+            except DagsterExecutionInterruptedError:
+                if self.forward_termination:
+                    context.log.warning(
+                        f"[pipes] Dagster process interrupted! Will terminate external {self.AWS_SERVICE_NAME} workload."
+                    )
+                    self._terminate(context, start_response)
+                raise
+
+    def _enrich_start_params(
+        self,
+        context: OpExecutionContext,
+        session: PipesSession,
+        params: "StartJobRunRequestRequestTypeDef",
+    ) -> "StartJobRunRequestRequestTypeDef":
+        # inject Dagster tags
+        tags = params.get("tags", {})
+        tags = {
+            **tags,
+            "dagster/run_id": context.run_id,
+        }
+        params["tags"] = tags
+
+        # inject Pipes env variables via --conf spark.emr-serverless.driverEnv.<key>=<value>
+        dagster_env_vars = {}
+        dagster_env_vars.update(session.get_bootstrap_env_vars())
+
+        if "jobDriver" not in params:
+            params["jobDriver"] = {}
+
+        if "sparkSubmit" not in params["jobDriver"]:
+            params["jobDriver"]["sparkSubmit"] = {}  # pyright: ignore[reportGeneralTypeIssues]
+
+        params["jobDriver"]["sparkSubmit"]["sparkSubmitParameters"] = params.get(
+            "jobDriver", {}
+        ).get("sparkSubmit", {}).get("sparkSubmitParameters", "") + "".join(
+            [
+                f" --conf spark.emr-serverless.driverEnv.{key}={value}"
+                for key, value in dagster_env_vars.items()
+            ]
+        )
+
+        return cast("StartJobRunRequestRequestTypeDef", params)
+
+    def _start(
+        self, context: OpExecutionContext, params: "StartJobRunRequestRequestTypeDef"
+    ) -> "StartJobRunResponseTypeDef":
+        response = self.client.start_job_run(**params)
+        job_run_id = response["jobRunId"]
+        context.log.info(
+            f"[pipes] {self.AWS_SERVICE_NAME} job started with job_run_id {job_run_id}"
+        )
+        return response
+
+    def _wait_for_completion(
+        self, context: OpExecutionContext, start_response: "StartJobRunResponseTypeDef"
+    ) -> "GetJobRunResponseTypeDef":  # pyright: ignore[reportReturnType]
+        job_run_id = start_response["jobRunId"]
+
+        # poll the job run until it reaches a terminal state
+        while response := self.client.get_job_run(
+            applicationId=start_response["applicationId"],
+            jobRunId=job_run_id,
+        ):
+            state: "JobRunStateType" = response["jobRun"]["state"]
+
+            if state in ["FAILED", "CANCELLED", "CANCELLING"]:
+                context.log.error(
+                    f"[pipes] {self.AWS_SERVICE_NAME} job {job_run_id} terminated with state: {state}. Details:\n{response['jobRun'].get('stateDetails')}"
+                )
+                raise RuntimeError(
+                    f"{self.AWS_SERVICE_NAME} job failed"
+                )  # TODO: introduce something like DagsterPipesRemoteExecutionError
+            elif state == "SUCCESS":
+                context.log.info(
+                    f"[pipes] {self.AWS_SERVICE_NAME} job {job_run_id} completed with state: {state}"
+                )
+                return response
+            elif state in ["PENDING", "SUBMITTED", "SCHEDULED", "RUNNING"]:
+                time.sleep(self.poll_interval)
+                continue
+            else:
+                raise DagsterInvariantViolationError(
+                    f"Unexpected state for AWS EMR Serverless job {job_run_id}: {state}"
+                )
+
+    def _read_messages(self, context: OpExecutionContext, response: "GetJobRunResponseTypeDef"):
+        application_id = response["jobRun"]["applicationId"]
+        job_id = response["jobRun"]["jobRunId"]
+
+        application = self.client.get_application(applicationId=application_id)["application"]
+
+        # merge the base monitoring configuration from the application
+        # with potential overrides from the job run
+        application_monitoring_configuration = application.get("monitoringConfiguration", {})
+        job_monitoring_configuration = (
+            response["jobRun"].get("configurationOverrides", {}).get("monitoringConfiguration", {})
+        )
+        monitoring_configuration = cast(
+            "MonitoringConfigurationTypeDef",
+            deep_merge_dicts(application_monitoring_configuration, job_monitoring_configuration),
+        )
+
+        application_type = application["type"]
+
+        if application_type == "Spark":
+            worker_type = "SPARK_DRIVER"
+        elif application_type == "Hive":
+            worker_type = "HIVE_DRIVER"
+        else:
+            raise NotImplementedError(f"Application type {application_type} is not supported")
+
+        if not isinstance(self.message_reader, PipesCloudWatchMessageReader):
+            context.log.warning(
+                f"[pipes] {self.message_reader} is not supported for {self.AWS_SERVICE_NAME}. Dagster won't be able to receive logs and messages from the job."
+            )
+            return
+
+        # https://docs.aws.amazon.com/emr/latest/EMR-Serverless-UserGuide/logging.html#jobs-log-storage-cw
+        # we can get CloudWatch logs from the known log group
+        if (
+            monitoring_configuration.get("cloudWatchLoggingConfiguration", {}).get("enabled")
+            is not True
+        ):
+            context.log.warning(
+                f"[pipes] Received {self.message_reader}, but CloudWatch logging is not enabled for {self.AWS_SERVICE_NAME} job. Dagster won't be able to receive logs and messages from the job."
+ ) + return + + if log_types := monitoring_configuration.get("cloudWatchLoggingConfiguration", {}).get( + "logTypes" + ): + # get the configured output streams + # but limit them with "stdout" and "stderr" + output_streams = list( + map( + lambda x: x.lower(), + set(log_types.get(worker_type, ["STDOUT", "STDERR"])) & {"stdout", "stderr"}, + ) + ) + else: + output_streams = ["stdout", "stderr"] + + log_group = monitoring_configuration.get("logGroupName") or "/aws/emr-serverless" + + attempt = response["jobRun"].get("attempt") + + if attempt is not None and attempt > 1: + log_stream = ( + f"/applications/{application_id}/jobs/{job_id}/attempts/{attempt}/{worker_type}" + ) + else: + log_stream = f"/applications/{application_id}/jobs/{job_id}/{worker_type}" + + if log_stream_prefix := monitoring_configuration.get( + "cloudWatchLoggingConfiguration", {} + ).get("logStreamNamePrefix"): + log_stream = f"{log_stream_prefix}{log_stream}" + + output_files = { + "stdout": sys.stdout, + "stderr": sys.stderr, + } + + # TODO: do this in a background thread in real-time once https://github.com/dagster-io/dagster/pull/24098 is merged + for output_stream in output_streams: + output_file = output_files[output_stream] + context.log.debug( + f"[pipes] Reading AWS CloudWatch logs from group {log_group} stream {log_stream}/{output_stream}" + ) + self.message_reader.consume_cloudwatch_logs( + log_group, + f"{log_stream}/{output_stream}", + start_time=int( + response["jobRun"] + .get("attemptCreatedAt", response["jobRun"]["createdAt"]) + .timestamp() + * 1000 + ), + output_file=output_file, + ) + + def _terminate(self, context: OpExecutionContext, start_response: "StartJobRunResponseTypeDef"): + job_run_id = start_response["jobRunId"] + application_id = start_response["applicationId"] + context.log.info(f"[pipes] Terminating {self.AWS_SERVICE_NAME} job run {job_run_id}") + self.client.cancel_job_run(applicationId=application_id, jobRunId=job_run_id) diff --git a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py index 1b8067d312c60..f2f0995bb88c0 100644 --- a/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py +++ b/python_modules/libraries/dagster-aws/dagster_aws_tests/pipes_tests/test_pipes.py @@ -10,7 +10,8 @@ import time from contextlib import contextmanager from tempfile import NamedTemporaryFile -from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal +from typing import TYPE_CHECKING, Any, Callable, Iterator, Literal, Tuple +from uuid import uuid4 import boto3 import pytest @@ -32,6 +33,7 @@ from dagster_aws.pipes import ( PipesCloudWatchMessageReader, PipesECSClient, + PipesEMRServerlessClient, PipesGlueClient, PipesLambdaClient, PipesLambdaLogsMessageReader, @@ -48,6 +50,7 @@ if TYPE_CHECKING: from mypy_boto3_ecs import ECSClient + from mypy_boto3_emr_serverless import EMRServerlessClient _PYTHON_EXECUTABLE = shutil.which("python") or "python" @@ -774,3 +777,73 @@ def materialize_asset(env, return_dict): # breakpoint() assert return_dict[0]["tasks"][0]["containers"][0]["exitCode"] == 1 assert return_dict[0]["tasks"][0]["stoppedReason"] == "Dagster process was interrupted" + + +EMR_SERVERLESS_APP_NAME = "Example" + + +@pytest.fixture +def emr_serverless_setup( + moto_server, external_s3_glue_script, s3_client +) -> Tuple["EMRServerlessClient", str]: + client = boto3.client("emr-serverless", region_name="us-east-1", endpoint_url=_MOTO_SERVER_URL) + resp = 
+        type="SPARK",
+        releaseLabel="emr-7.2.0-latest",
+        clientToken=str(uuid4()),
+    )
+    return client, resp["applicationId"]
+
+
+def test_emr_serverless_manual(emr_serverless_setup: Tuple["EMRServerlessClient", str]):
+    client, application_id = emr_serverless_setup
+
+    @asset
+    def my_asset(context: AssetExecutionContext, emr_serverless_client: PipesEMRServerlessClient):
+        message_reader = PipesCloudWatchMessageReader()
+        context_injector = PipesEnvContextInjector()
+
+        with open_pipes_session(
+            context=context,
+            message_reader=message_reader,
+            context_injector=context_injector,
+        ) as session:
+            params = emr_serverless_client._enrich_start_params(  # noqa: SLF001
+                context=context,
+                session=session,
+                params={
+                    "applicationId": application_id,
+                    "executionRoleArn": "arn:aws:iam::123456789012:role/EMRServerlessRole",
+                    "jobDriver": {
+                        "sparkSubmit": {
+                            "entryPoint": "s3://my-bucket/my-script.py",
+                        }
+                    },
+                    "clientToken": str(uuid4()),
+                },
+            )
+
+            assert params["tags"]["dagster/run_id"] == context.run_id  # pyright: ignore[reportTypedDictNotRequiredAccess]
+            assert (
+                "--conf spark.emr-serverless.driverEnv.DAGSTER_PIPES_CONTEXT="
+                in params["jobDriver"]["sparkSubmit"]["sparkSubmitParameters"]  # pyright: ignore[reportTypedDictNotRequiredAccess]
+            )
+            assert (
+                "--conf spark.emr-serverless.driverEnv.DAGSTER_PIPES_MESSAGES="
+                in params["jobDriver"]["sparkSubmit"]["sparkSubmitParameters"]  # pyright: ignore[reportTypedDictNotRequiredAccess]
+            )
+
+            # moto doesn't have start_job_run implemented, so this is as far as we can get with it right now
+            return session.get_results()
+
+    with instance_for_test() as instance:
+        materialize(
+            [my_asset],
+            resources={
+                "emr_serverless_client": PipesEMRServerlessClient(
+                    client=client,
+                )
+            },
+            instance=instance,
+        )
diff --git a/python_modules/libraries/dagster-aws/ruff.toml b/python_modules/libraries/dagster-aws/ruff.toml
index 6fd266058fb41..244f818c1fc43 100644
--- a/python_modules/libraries/dagster-aws/ruff.toml
+++ b/python_modules/libraries/dagster-aws/ruff.toml
@@ -10,4 +10,5 @@ extend-select = [
 banned-module-level-imports = [
     "mypy_boto3_ecs",
     "mypy_boto3_glue",
+    "mypy_boto3_emr_serverless"
 ]
diff --git a/python_modules/libraries/dagster-aws/setup.py b/python_modules/libraries/dagster-aws/setup.py
index 78d03ffba1b9d..3616471fc37bf 100644
--- a/python_modules/libraries/dagster-aws/setup.py
+++ b/python_modules/libraries/dagster-aws/setup.py
@@ -45,11 +45,11 @@ def get_version() -> str:
         "redshift": ["psycopg2-binary"],
         "pyspark": ["dagster-pyspark"],
         "stubs": [
-            "boto3-stubs-lite[ecs,glue]",
+            "boto3-stubs-lite[ecs,glue,emr-serverless]",
         ],
         "test": [
             "botocore!=1.32.1",
-            "moto[s3,server,glue]>=2.2.8,<5.0",
+            "moto[s3,server,glue,emrserverless]>=2.2.8,<5.0",
             "requests-mock",
             "xmltodict==0.12.0",  # pinned until moto>=3.1.9 (https://github.com/spulec/moto/issues/5112)
         ],
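
Note for reviewers: the script submitted via `jobDriver.sparkSubmit.entryPoint` is expected to be instrumented with `dagster-pipes`; the env vars injected by `_enrich_start_params` are what let it bootstrap the Pipes connection. A minimal sketch of such an entry point, assuming the standard `dagster-pipes` API (the script path, log line, and metadata are illustrative):

```python
# my-script.py: the Spark entry point submitted to EMR Serverless
from dagster_pipes import open_dagster_pipes

with open_dagster_pipes() as pipes:
    # messages written to the driver's stdout/stderr land in CloudWatch,
    # where PipesCloudWatchMessageReader picks them up on the Dagster side
    pipes.log.info("Hello from EMR Serverless!")
    pipes.report_asset_materialization(metadata={"row_count": 42})
```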