
updates 2024-11-10 - polars credentials issue with aws
CHRISCARLON committed Nov 10, 2024
1 parent 1b65c67 commit ff6545b
Showing 1 changed file with 46 additions and 15 deletions.
analytics_platform_dagster/utils/io_manager_helper/io_manager.py (46 additions, 15 deletions)
@@ -12,33 +12,34 @@
from typing import List, Dict, Union, Any, Optional
from dagster import IOManager, OutputContext, InputContext
from botocore.exceptions import ClientError
+from botocore.credentials import ReadOnlyCredentials


# Errors for IO Managers
class InvalidDataTypeError(Exception):
    def __init__(self, message="Input must be a pandas DataFrame"):
        super().__init__(message)


class DeltaLakeWriteError(Exception):
    def __init__(self, message="Error writing to Delta Lake"):
        super().__init__(message)


class DeltaLakeReadError(Exception):
    def __init__(self, message="Error reading from Delta Lake"):
        super().__init__(message)


class S3BucketError(Exception):
    def __init__(self, message="Bucket can't be empty"):
        super().__init__(message)


class S3Error(Exception):
    def __init__(self, message="S3 Error") -> None:
        super().__init__(message)

+class CredentialsError(Exception):
+    """Raised when there is an error obtaining AWS credentials"""
+    pass


# IO Managers
class AwsWranglerDeltaLakeIOManager(IOManager):
@@ -419,9 +420,8 @@ def load_partition(self, context: InputContext, batch_id: int) -> pa.Table:
class PolarsDeltaLakeIOManager(IOManager):
    """
    IO manager to handle reading and writing delta lake tables to S3 using Polars.
-    Supports AWS credentials from environment variables.
+    Supports AWS credentials from the Fargate task IAM role.
    """

    def __init__(
        self,
        bucket_name: str,
@@ -433,45 +433,77 @@ def __init__(

self.bucket_name = bucket_name
self.region = region

# Default storage options
# Initialize with basic settings
self.storage_options = {
"AWS_REGION": region,
"AWS_S3_ALLOW_UNSAFE_RENAME": "true",
"AWS_ACCESS_KEY_ID": os.getenv("AWS_ACCESS_KEY_ID"),
"AWS_SECRET_ACCESS_KEY": os.getenv("AWS_SECRET_ACCESS_KEY"),
# "AWS_SESSION_TOKEN": os.getenv("AWS_SESSION_TOKEN")
}

# Get credentials from task role
self._update_credentials()

# Override defaults with provided storage options
if storage_options:
self.storage_options.update(storage_options)

+    def _update_credentials(self) -> None:
+        """
+        Update storage options with credentials from the Fargate task role.
+        Uses boto3's default credential chain, which automatically resolves
+        task-role credentials on Fargate.
+        """
+        try:
+            session = boto3.Session()
+            credentials = session.get_credentials()
+            if credentials is None:
+                raise CredentialsError("Failed to obtain AWS credentials")
+
+            # Freeze the credentials so they aren't refreshed mid-operation
+            frozen_credentials: ReadOnlyCredentials = credentials.get_frozen_credentials()
+
+            # Update storage options with the current credentials
+            self.storage_options.update({
+                "AWS_ACCESS_KEY_ID": frozen_credentials.access_key,
+                "AWS_SECRET_ACCESS_KEY": frozen_credentials.secret_key,
+            })
+
+            # Only add a session token if one exists
+            if frozen_credentials.token:
+                self.storage_options["AWS_SESSION_TOKEN"] = frozen_credentials.token
+
+        except Exception as e:
+            raise CredentialsError(f"Error getting AWS credentials: {str(e)}")

    def handle_output(self, context: OutputContext, obj: pl.DataFrame) -> None:
        """Write Polars DataFrame to Delta Lake table in S3"""
        if not isinstance(obj, pl.DataFrame):
            raise InvalidDataTypeError()

+        # Refresh credentials before the write operation
+        self._update_credentials()

        table_name = context.asset_key.path[-1]
        table_path = f"s3://{self.bucket_name}/{table_name}/"

-        write_option = context.definition_metadata["mode"]
+        # Get write mode from metadata, default to overwrite
+        write_option = context.definition_metadata.get("mode", "overwrite")

        try:
            obj.write_delta(
                table_path,
-                mode=write_option,  # mode (str, optional) – append (default), overwrite, ignore, error
+                mode="overwrite",
+                overwrite_schema=True,
                storage_options=self.storage_options
            )
            context.log.info(f"Successfully wrote data to Delta Lake table at {table_path}")

        except Exception as e:
            raise DeltaLakeWriteError(f"Failed to write to Delta Lake: {str(e)}")
    def load_input(self, context: InputContext) -> pl.DataFrame:
        """Read Delta Lake table from S3 into Polars DataFrame"""
+        # Refresh credentials before the read operation
+        self._update_credentials()

        table_name = context.asset_key.path[-1]
        table_path = f"s3://{self.bucket_name}/{table_name}/"
@@ -481,6 +513,5 @@ def load_input(self, context: InputContext) -> pl.DataFrame:
                storage_options=self.storage_options
            )
            return df
-
        except Exception as e:
            raise DeltaLakeReadError(f"Failed to read from Delta Lake: {str(e)}")
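
For reference, here is a minimal standalone sketch (not part of this commit) of the credential flow that _update_credentials relies on: boto3's default chain resolves whatever identity is available (environment variables locally, the task IAM role on Fargate), freezing pins the keys for the duration of a call, and the result is handed to delta-rs through Polars as storage options. The region, bucket, and table path below are placeholders, and a resolvable identity is assumed.

import boto3
import polars as pl

session = boto3.Session()
credentials = session.get_credentials()  # env vars, shared config, or the ECS/Fargate task role
frozen = credentials.get_frozen_credentials()  # immutable snapshot; keys won't rotate mid-call

storage_options = {
    "AWS_REGION": "eu-west-2",  # placeholder region
    "AWS_ACCESS_KEY_ID": frozen.access_key,
    "AWS_SECRET_ACCESS_KEY": frozen.secret_key,
}
if frozen.token:  # temporary role credentials always carry a session token
    storage_options["AWS_SESSION_TOKEN"] = frozen.token

df = pl.DataFrame({"id": [1, 2, 3]})
df.write_delta("s3://placeholder-bucket/example_table/", mode="overwrite", storage_options=storage_options)
print(pl.read_delta("s3://placeholder-bucket/example_table/", storage_options=storage_options))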

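A sketch of how the manager might be wired into a Dagster code location. The resource key, bucket, and asset below are illustrative rather than taken from the repo; the "mode" metadata entry is what handle_output reads through context.definition_metadata.

import polars as pl
from dagster import Definitions, asset, io_manager

from analytics_platform_dagster.utils.io_manager_helper.io_manager import (
    PolarsDeltaLakeIOManager,
)

@io_manager
def polars_delta_io_manager(_):
    # Illustrative bucket and region; real values would come from config
    return PolarsDeltaLakeIOManager(bucket_name="placeholder-bucket", region="eu-west-2")

@asset(io_manager_key="polars_delta_io_manager", metadata={"mode": "overwrite"})
def example_table() -> pl.DataFrame:
    # The returned frame is written to s3://placeholder-bucket/example_table/
    return pl.DataFrame({"id": [1, 2, 3], "value": ["a", "b", "c"]})

defs = Definitions(
    assets=[example_table],
    resources={"polars_delta_io_manager": polars_delta_io_manager},
)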
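Finally, when chasing credential problems like the one in this commit's title, it can help to confirm which identity the default chain actually resolved. GetCallerIdentity requires no IAM permissions, so it works as a probe even under a minimal task role:

import boto3

# Prints the ARN the default credential chain resolved to; on Fargate this
# should be the assumed task-role ARN rather than a user or instance profile
print(boto3.client("sts").get_caller_identity()["Arn"])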