Skip to content

Commit

Permalink
[dagster-aws] Pipes AWS Glue Dagster run interruption handler (#23354)
Browse files Browse the repository at this point in the history
## Summary & Motivation

I decided it was a good idea to add an automatic Glue job cleanup
handler, as there is a chance that Spark jobs will otherwise consume
a lot of resources unnecessarily.

Had to rewrite the fake Glue clients to implement non-blocking job
execution via `subprocess.Popen`

Strong Inception vibes in this PR

## How I Tested These Changes
- [x] Added a test with subprocess interruption. The test runs Dagster's
`materialize` inside a `multiprocessing.Process`. Inside this process,
the fake Glue client runs the Glue job in a `subprocess.Popen`. Once the
Dagster process receives a termination signal, the `PipesGlueClient`
invokes the fake Glue client to terminate the `subprocess.Popen` "job".
We can register this call in the fake client and test for it.
  • Loading branch information
danielgafni authored Aug 8, 2024
1 parent 490ea61 commit 9fbb75e
Show file tree
Hide file tree
Showing 3 changed files with 302 additions and 69 deletions.
39 changes: 35 additions & 4 deletions python_modules/libraries/dagster-aws/dagster_aws/pipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from dagster import PipesClient
from dagster._annotations import experimental
from dagster._core.definitions.resource_annotation import TreatAsResourceParam
from dagster._core.errors import DagsterExecutionInterruptedError
from dagster._core.execution.context.compute import OpExecutionContext
from dagster._core.pipes.client import (
PipesClientCompletedInvocation,
Expand Down Expand Up @@ -344,17 +345,20 @@ class PipesGlueClient(PipesClient, TreatAsResourceParam):
message_reader (Optional[PipesMessageReader]): A message reader to use to read messages
from the glue job run. Defaults to :py:class:`PipesCloudWatchMessageReader`.
client (Optional[boto3.client]): The boto Glue client used to launch the Glue job
forward_termination (bool): Whether to cancel the Glue job run when the Dagster process receives a termination signal.
"""

def __init__(
    self,
    context_injector: PipesContextInjector,
    message_reader: Optional[PipesMessageReader] = None,
    client: Optional[boto3.client] = None,
    forward_termination: bool = True,
):
    """Store the collaborators used to launch and monitor Glue job runs.

    A default boto3 Glue client and a CloudWatch message reader are created
    when the corresponding arguments are not supplied.
    """
    # Use the caller's client when given, otherwise build a default Glue client.
    self._client = client if client else boto3.client("glue")
    self._context_injector = context_injector
    # Same fallback pattern for the message reader.
    self._message_reader = message_reader if message_reader else PipesCloudWatchMessageReader()
    # Validate the flag's type before storing it.
    self.forward_termination = check.bool_param(forward_termination, "forward_termination")

@classmethod
def _is_dagster_maintained(cls) -> bool:
Expand Down Expand Up @@ -435,6 +439,7 @@ def run(

try:
run_id = self._client.start_job_run(**params)["JobRunId"]

except ClientError as err:
context.log.error(
"Couldn't create job %s. Here's why: %s: %s",
Expand All @@ -448,11 +453,16 @@ def run(
log_group = response["JobRun"]["LogGroupName"]
context.log.info(f"Started AWS Glue job {job_name} run: {run_id}")

response = self._wait_for_job_run_completion(job_name, run_id)
try:
response = self._wait_for_job_run_completion(job_name, run_id)
except DagsterExecutionInterruptedError:
if self.forward_termination:
self._terminate_job_run(context=context, job_name=job_name, run_id=run_id)
raise

if response["JobRun"]["JobRunState"] == "FAILED":
if status := response["JobRun"]["JobRunState"] != "SUCCEEDED":
raise RuntimeError(
f"Glue job {job_name} run {run_id} failed:\n{response['JobRun']['ErrorMessage']}"
f"Glue job {job_name} run {run_id} completed with status {status} :\n{response['JobRun'].get('ErrorMessage')}"
)
else:
context.log.info(f"Glue job {job_name} run {run_id} completed successfully")
Expand All @@ -470,6 +480,27 @@ def run(
def _wait_for_job_run_completion(self, job_name: str, run_id: str) -> Dict[str, Any]:
    """Poll Glue until the job run reaches a terminal state.

    Args:
        job_name: Name of the Glue job.
        run_id: ID of the job run to wait for.

    Returns:
        The last ``get_job_run`` response, whose ``JobRun.JobRunState`` is terminal.
    """
    # https://docs.aws.amazon.com/glue/latest/dg/job-run-statuses.html
    # EXPIRED is included so an expired run cannot keep this loop polling forever.
    terminal_states = {"FAILED", "SUCCEEDED", "STOPPED", "TIMEOUT", "ERROR", "EXPIRED"}
    while True:
        response = self._client.get_job_run(JobName=job_name, RunId=run_id)
        if response["JobRun"]["JobRunState"] in terminal_states:
            return response
        time.sleep(5)

def _terminate_job_run(self, context: OpExecutionContext, job_name: str, run_id: str):
    """Ask Glue to stop the given job run and log the outcome.

    Args:
        context: Execution context whose logger receives the progress messages.
        job_name: Name of the Glue job the run belongs to.
        run_id: ID of the job run to stop.
    """
    context.log.warning(f"[pipes] execution interrupted, stopping Glue job run {run_id}...")
    # Tolerate an empty/None response so cleanup never masks the interruption being handled.
    response = self._client.batch_stop_job_run(JobName=job_name, JobRunIds=[run_id]) or {}
    # BatchStopJobRun reports per-run results under "SuccessfulSubmissions" and "Errors".
    if response.get("SuccessfulSubmissions"):
        context.log.warning(f"Successfully stopped Glue job run {run_id}.")
    else:
        context.log.warning(
            f"Something went wrong during Glue job run termination: {response.get('Errors')}"
        )
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
import subprocess
import sys
import tempfile
import time
from typing import Dict, Literal, Optional
import warnings
from dataclasses import dataclass
from subprocess import PIPE, Popen
from typing import Dict, List, Literal, Optional

import boto3


@dataclass
class SimulatedJobRun:
    """Bookkeeping for one Glue job run that is simulated by a local subprocess."""

    # Handle of the subprocess executing the job script.
    popen: Popen
    # JobRunId returned by the mocked start_job_run call.
    job_run_id: str
    # "LogGroupName" reported by the mocked get_job_run call for this run.
    log_group: str
    # Temporary file holding the script downloaded from S3.
    local_script: tempfile._TemporaryFileWrapper
    # Set to True once batch_stop_job_run has terminated the subprocess.
    stopped: bool = False


class LocalGlueMockClient:
def __init__(
self,
Expand All @@ -21,15 +32,49 @@ def __init__(
to receive any Dagster messages from it.
If pipes_messages_backend is configured to be CloudWatch, it also uploads stderr and stdout logs to CloudWatch
as if this has been done by Glue.
Once the job is submitted, it is being executed in a separate process to mimic Glue behavior.
Once the job status is requested, the process is checked for its status and the result is returned.
"""
self.aws_endpoint_url = aws_endpoint_url
self.s3_client = s3_client
self.glue_client = glue_client
self.pipes_messages_backend = pipes_messages_backend
self.cloudwatch_client = cloudwatch_client

def get_job_run(self, *args, **kwargs):
return self.glue_client.get_job_run(*args, **kwargs)
self.process = None # jobs will be executed in a separate process

self._job_runs: Dict[str, SimulatedJobRun] = {} # mapping of JobRunId to SimulatedJobRun

def get_job_run(self, JobName: str, RunId: str):
    """Return the mocked job-run response with its state taken from the local subprocess.

    The state is STOPPED when the run was stopped explicitly, RUNNING while the
    subprocess is alive, and SUCCEEDED/FAILED (by exit code) once it has exited.
    """
    # Start from the response of the underlying mocked Glue client.
    result = self.glue_client.get_job_run(JobName=JobName, RunId=RunId)
    tracked = self._job_runs[RunId]
    job_run = result["JobRun"]

    # An explicit stop overrides whatever the subprocess is doing.
    if tracked.stopped:
        job_run["JobRunState"] = "STOPPED"
        return result

    process = tracked.popen
    if process.poll() is None:
        # Subprocess has not exited yet.
        job_run["JobRunState"] = "RUNNING"
        return result

    process.wait()
    if process.returncode != 0:
        job_run["JobRunState"] = "FAILED"
        _, captured_stderr = process.communicate()
        job_run["ErrorMessage"] = captured_stderr.decode()
    else:
        job_run["JobRunState"] = "SUCCEEDED"

    # Finished runs publish their logs when CloudWatch is the messages backend.
    if self.pipes_messages_backend == "cloudwatch":
        self._upload_logs_to_cloudwatch(RunId)

    return result

def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], **kwargs):
params = {
Expand All @@ -45,67 +90,97 @@ def start_job_run(self, JobName: str, Arguments: Optional[Dict[str, str]], **kwa
bucket = script_s3_path.split("/")[2]
key = "/".join(script_s3_path.split("/")[3:])

# load the script and execute it locally
with tempfile.NamedTemporaryFile() as f:
self.s3_client.download_file(bucket, key, f.name)

args = []
for key, val in (Arguments or {}).items():
args.append(key)
args.append(val)

result = subprocess.run(
[sys.executable, f.name, *args],
check=False,
env={
"AWS_ENDPOINT_URL": self.aws_endpoint_url,
"TESTING_PIPES_MESSAGES_BACKEND": self.pipes_messages_backend,
},
capture_output=True,
)

# mock the job run with moto
response = self.glue_client.start_job_run(**params)
job_run_id = response["JobRunId"]

job_run_response = self.glue_client.get_job_run(JobName=JobName, RunId=job_run_id)
log_group = job_run_response["JobRun"]["LogGroupName"]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
f = tempfile.NamedTemporaryFile(
delete=False
) # we will close this file later during garbage collection
# load the S3 script to a local file
self.s3_client.download_file(bucket, key, f.name)

# execute the script in a separate process
args = []
for key, val in (Arguments or {}).items():
args.append(key)
args.append(val)
popen = Popen(
[sys.executable, f.name, *args],
env={
"AWS_ENDPOINT_URL": self.aws_endpoint_url,
"TESTING_PIPES_MESSAGES_BACKEND": self.pipes_messages_backend,
},
stdout=PIPE,
stderr=PIPE,
)

# record execution metadata for later use
self._job_runs[job_run_id] = SimulatedJobRun(
popen=popen,
job_run_id=job_run_id,
log_group=self.glue_client.get_job_run(JobName=JobName, RunId=job_run_id)["JobRun"][
"LogGroupName"
],
local_script=f,
)

return response

def batch_stop_job_run(self, JobName: str, JobRunIds: List[str]):
    """Terminate the subprocesses that simulate the given job runs.

    Args:
        JobName: Name of the Glue job the runs belong to.
        JobRunIds: IDs of the job runs to stop.

    Returns:
        A dict with ``SuccessfulSubmissions`` and ``Errors`` lists, the keys used by
        boto3's ``batch_stop_job_run`` response, so callers can inspect the outcome.
    """
    successful = []
    errors = []
    for job_run_id in JobRunIds:
        simulated_job_run = self._job_runs.get(job_run_id)
        if simulated_job_run is None:
            # Unknown run: report it instead of silently ignoring it.
            errors.append(
                {
                    "JobName": JobName,
                    "JobRunId": job_run_id,
                    "ErrorDetail": {
                        "ErrorCode": "EntityNotFoundException",
                        "ErrorMessage": f"Job run {job_run_id} not found",
                    },
                }
            )
            continue

        simulated_job_run.popen.terminate()
        simulated_job_run.stopped = True
        # Only upload logs for the CloudWatch backend, as get_job_run does.
        if self.pipes_messages_backend == "cloudwatch":
            self._upload_logs_to_cloudwatch(job_run_id)
        successful.append({"JobName": JobName, "JobRunId": job_run_id})

    return {"SuccessfulSubmissions": successful, "Errors": errors}

def _upload_logs_to_cloudwatch(self, job_run_id: str):
log_group = self._job_runs[job_run_id].log_group
stdout, stderr = self._job_runs[job_run_id].popen.communicate()

if self.pipes_messages_backend == "cloudwatch":
assert (
self.cloudwatch_client is not None
), "cloudwatch_client has to be provided with cloudwatch messages backend"

self.cloudwatch_client.create_log_group(
logGroupName=f"{log_group}/output",
)

self.cloudwatch_client.create_log_stream(
logGroupName=f"{log_group}/output",
logStreamName=job_run_id,
)

for line in result.stderr.decode().split(
"\n"
): # uploading log lines one by one is good enough for tests
if line:
self.cloudwatch_client.put_log_events(
logGroupName=f"{log_group}/output", # yes, Glue routes stderr to /output
logStreamName=job_run_id,
logEvents=[{"timestamp": int(time.time() * 1000), "message": str(line)}],
)
time.sleep(
0.01
) # make sure the logs will be properly filtered by ms timestamp when accessed next time

# replace run state with actual results
response["JobRun"] = {}

response["JobRun"]["JobRunState"] = "SUCCEEDED" if result.returncode == 0 else "FAILED"

# add error message if failed
if result.returncode != 0:
# this actually has to be just the Python exception, but this is good enough for now
response["JobRun"]["ErrorMessage"] = result.stderr
assert (
self.cloudwatch_client is not None
), "cloudwatch_client has to be provided with cloudwatch messages backend"

return response
try:
self.cloudwatch_client.create_log_group(
logGroupName=f"{log_group}/output",
)
except self.cloudwatch_client.exceptions.ResourceAlreadyExistsException:
pass

try:
self.cloudwatch_client.create_log_stream(
logGroupName=f"{log_group}/output",
logStreamName=job_run_id,
)
except self.cloudwatch_client.exceptions.ResourceAlreadyExistsException:
pass

for out in [stderr, stdout]: # Glue routes both stderr and stdout to /output
for line in out.decode().split(
"\n"
): # uploading log lines one by one is good enough for tests
if line:
self.cloudwatch_client.put_log_events(
logGroupName=f"{log_group}/output",
logStreamName=job_run_id,
logEvents=[
{"timestamp": int(time.time() * 1000), "message": str(line)}
],
)
time.sleep(
0.01
) # make sure the logs will be properly filtered by ms timestamp when accessed next time

def __del__(self):
    """Close and delete the temporary script files created for simulated job runs."""
    import os

    # _job_runs may be missing if __init__ raised before assigning it.
    for job_run in getattr(self, "_job_runs", {}).values():
        job_run.local_script.close()
        # The files are created with delete=False, so closing alone leaves them on disk.
        try:
            os.unlink(job_run.local_script.name)
        except OSError:
            # Already removed, or still locked by a running subprocess; nothing more to do.
            pass
Loading

0 comments on commit 9fbb75e

Please sign in to comment.