small fixes
danielgafni committed Jul 31, 2024
1 parent 7926060 commit 747dce0
Showing 1 changed file with 38 additions and 34 deletions:
python_modules/libraries/dagster-aws/dagster_aws/pipes.py
@@ -3,13 +3,22 @@
 import os
 import random
 import string
-import sys
 import time
 from contextlib import contextmanager
-from threading import Thread
-from typing import TYPE_CHECKING, Any, Dict, Iterator, Literal, Mapping, Optional, Sequence, List, Generator
-from typing import TypedDict
-import signal
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generator,
+    Iterator,
+    List,
+    Literal,
+    Mapping,
+    Optional,
+    Sequence,
+    TypedDict,
+)
+
 import boto3
 import dagster._check as check
 from botocore.exceptions import ClientError
@@ -171,9 +180,8 @@ class PipesCloudWatchMessageReader(PipesMessageReader):
     """Message reader that consumes AWS CloudWatch logs to read pipes messages."""
 
     def __init__(self, client: Optional[boto3.client] = None):
-        """
-        Args:
-            client (boto3.client): boto3 CloudWatch client.
+        """Args:
+            client (boto3.client): boto3 CloudWatch client.
         """
         self.client = client or boto3.client("logs")
 
@@ -190,17 +198,18 @@ def read_messages(
         self._handler = None
 
     def consume_cloudwatch_logs(
-        self, log_group: str, log_stream: str, start_time: Optional[int] = None, end_time: Optional[int] = None,
+        self,
+        log_group: str,
+        log_stream: str,
+        start_time: Optional[int] = None,
+        end_time: Optional[int] = None,
     ) -> None:
         handler = check.not_none(
             self._handler, "Can only consume logs within context manager scope."
         )
 
         for events_batch in self._get_all_cloudwatch_events(
-            log_group=log_group,
-            log_stream=log_stream,
-            start_time=start_time,
-            end_time=end_time
+            log_group=log_group, log_stream=log_stream, start_time=start_time, end_time=end_time
         ):
             for event in events_batch:
                 for log_line in event["message"].splitlines():
@@ -210,16 +219,14 @@ def no_messages_debug_text(self) -> str:
         return "Attempted to read messages by extracting them from the tail of CloudWatch logs directly."
 
     def _get_all_cloudwatch_events(
-        self,
-        log_group: str,
-        log_stream: str,
-        start_time: Optional[int] = None,
-        end_time: Optional[int] = None
-    ) -> Generator[List[CloudWatchEvent], None, None]:
-        """
-        Returns batches of CloudWatch events until the stream is complete or end_time.
-        """
-        params = {
+        self,
+        log_group: str,
+        log_stream: str,
+        start_time: Optional[int] = None,
+        end_time: Optional[int] = None,
+    ) -> Generator[List[CloudWatchEvent], None, None]:
+        """Returns batches of CloudWatch events until the stream is complete or end_time."""
+        params: Dict[str, Any] = {
             "logGroupName": log_group,
             "logStreamName": log_stream,
         }
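
The CloudWatchEvent type referenced in the return annotation is defined elsewhere in the file and is not visible in this diff. Judging from the shape of the get_log_events response, it is presumably a TypedDict along these lines (field set assumed from the boto3 API, not copied from the source):

from typing import TypedDict

class CloudWatchEvent(TypedDict):
    # One entry of the "events" list returned by CloudWatch Logs get_log_events.
    timestamp: int      # event time, in milliseconds since the Unix epoch
    message: str        # raw log payload; may contain embedded newlines
    ingestionTime: int  # when CloudWatch ingested the event, in milliseconds
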
@@ -229,18 +236,14 @@ def _get_all_cloudwatch_events(
         if end_time is not None:
             params["endTime"] = end_time
 
-        response = self.client.get_log_events(
-            **params
-        )
+        response = self.client.get_log_events(**params)
 
         while events := response.get("events"):
             yield events
 
             params["nextToken"] = response["nextForwardToken"]
 
-            response = self.client.get_log_events(
-                **params
-            )
+            response = self.client.get_log_events(**params)
 
 
 class PipesLambdaEventContextInjector(PipesEnvContextInjector):
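
For readers unfamiliar with the CloudWatch Logs API used above: get_log_events is paginated through nextForwardToken, and the loop in _get_all_cloudwatch_events stops as soon as a page comes back with no events. A minimal standalone sketch of the same pattern, assuming placeholder log group and stream names that are not part of this commit:

import boto3

def iter_cloudwatch_batches(log_group, log_stream, start_time=None, end_time=None):
    # Yield batches of events from one log stream, following nextForwardToken
    # until an empty page signals the end of the stream.
    client = boto3.client("logs")
    params = {"logGroupName": log_group, "logStreamName": log_stream}
    if start_time is not None:
        params["startTime"] = start_time
    if end_time is not None:
        params["endTime"] = end_time

    response = client.get_log_events(**params)
    while events := response.get("events"):
        yield events
        params["nextToken"] = response["nextForwardToken"]
        response = client.get_log_events(**params)

# Placeholder names, for illustration only:
for batch in iter_cloudwatch_batches("/aws-glue/jobs/output", "example-run-id"):
    for event in batch:
        print(event["message"])
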
@@ -442,7 +445,6 @@ def run(
         try:
             response = self._client.start_job_run(**params)
 
-
         except ClientError as err:
             context.log.error(
                 "Couldn't create job %s. Here's why: %s: %s",
@@ -454,8 +456,9 @@
 
         run_id = response["JobRunId"]
 
-        log_group = self._client.get_job_run(JobName=job_name, RunId=run_id)["JobRun"]["LogGroupName"]
-        self._register_interruption_handler(context, job_name, run_id)
+        log_group = self._client.get_job_run(JobName=job_name, RunId=run_id)["JobRun"][
+            "LogGroupName"
+        ]
         context.log.info(f"Started AWS Glue job {job_name} run: {run_id}")
 
         response = self._wait_for_job_run_completion(job_name, run_id)
@@ -469,8 +472,9 @@
 
         if isinstance(self._message_reader, PipesGlueLogsMessageReader):
             # TODO: receive messages from a background thread in real-time
-            self._message_reader.consume_cloudwatch_logs(f"{log_group}/output", run_id,
-                                                          start_time=int(start_timestamp))
+            self._message_reader.consume_cloudwatch_logs(
+                f"{log_group}/output", run_id, start_time=int(start_timestamp)
+            )
 
         return PipesClientCompletedInvocation(session)
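
Taken together, the last hunks sketch the flow of PipesGlueClient.run: start the job run, look up its CloudWatch log group via get_job_run, wait for the run to reach a terminal state, then read the driver output stream (named after the run id) under "<log_group>/output". A rough plain-boto3 sketch of that flow, with a hypothetical job name and without the Pipes context injection and message handling the real client performs:

import time

import boto3

glue = boto3.client("glue")
logs = boto3.client("logs")

run_id = glue.start_job_run(JobName="my-glue-job")["JobRunId"]
log_group = glue.get_job_run(JobName="my-glue-job", RunId=run_id)["JobRun"]["LogGroupName"]

# Poll until the run reaches a terminal state.
while True:
    state = glue.get_job_run(JobName="my-glue-job", RunId=run_id)["JobRun"]["JobRunState"]
    if state in ("SUCCEEDED", "FAILED", "STOPPED", "TIMEOUT", "ERROR"):
        break
    time.sleep(5)

# Glue writes the driver output to a log stream named after the run id under
# "<log_group>/output", which is what consume_cloudwatch_logs reads above.
for event in logs.get_log_events(logGroupName=f"{log_group}/output", logStreamName=run_id)["events"]:
    print(event["message"])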

