Skip to content

Commit

Permalink
[dagster-aws] ignore pyright errors with boto3-stubs
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgafni committed Aug 27, 2024
1 parent 7f389db commit cba2bfc
Show file tree
Hide file tree
Showing 12 changed files with 122 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from botocore.exceptions import ClientError


def get_s3_objects(s3_client: boto3.client, bucket: str, prefix: str) -> List[dict]:
def get_s3_objects(s3_client: boto3.client, bucket: str, prefix: str) -> List[dict]: # pyright: ignore (reportGeneralTypeIssues)
"""Get list of objects in the S3 bucket with the given prefix."""
try:
objects = s3_client.list_objects_v2(Bucket=bucket, Prefix=prefix)
Expand All @@ -25,7 +25,7 @@ def get_s3_objects(s3_client: boto3.client, bucket: str, prefix: str) -> List[di
return []


def delete_s3_prefix(s3_client: boto3.client, bucket: str, prefix: str) -> None:
def delete_s3_prefix(s3_client: boto3.client, bucket: str, prefix: str) -> None: # pyright: ignore (reportGeneralTypeIssues)
"""Delete all objects in the S3 bucket with the given prefix."""
objects_to_delete = get_s3_objects(s3_client, bucket, prefix)
if objects_to_delete:
Expand All @@ -39,7 +39,7 @@ def delete_s3_prefix(s3_client: boto3.client, bucket: str, prefix: str) -> None:
click.echo(f"No existing contents found in s3://{bucket}/{prefix}")


def upload_file(s3_client: boto3.client, file_path: str, bucket: str, object_name: str) -> bool:
def upload_file(s3_client: boto3.client, file_path: str, bucket: str, object_name: str) -> bool: # pyright: ignore (reportGeneralTypeIssues)
"""Upload a file to an S3 bucket."""
try:
s3_client.upload_file(file_path, bucket, object_name)
Expand All @@ -50,7 +50,10 @@ def upload_file(s3_client: boto3.client, file_path: str, bucket: str, object_nam


def update_mwaa_environment(
mwaa_client: boto3.client, environment_name: str, s3_bucket: str, s3_key: str
mwaa_client: boto3.client,
environment_name: str,
s3_bucket: str,
s3_key: str, # pyright: ignore (reportGeneralTypeIssues)
) -> None:
"""Update MWAA environment or provide instructions if it doesn't exist."""
try:
Expand All @@ -62,7 +65,7 @@ def update_mwaa_environment(
)
click.echo(f"MWAA environment {environment_name} updated successfully.")
except ClientError as e:
if e.response["Error"]["Code"] == "ResourceNotFoundException":
if e.response["Error"]["Code"] == "ResourceNotFoundException": # pyright: ignore (reportTypedDictNotRequiredAccess)
click.echo(f"MWAA environment {environment_name} not found.")
click.echo("To create a new environment, use the following information:")
click.echo(f"S3 bucket: {s3_bucket}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class PipesECSClient(PipesClient, TreatAsResourceParam):

def __init__(
self,
client: Optional[boto3.client] = None,
client: Optional[boto3.client] = None, # pyright: ignore (reportGeneralTypeIssues)
context_injector: Optional[PipesContextInjector] = None,
message_reader: Optional[PipesMessageReader] = None,
forward_termination: bool = True,
Expand Down Expand Up @@ -90,13 +90,13 @@ def run(
)

log_configurations = {
container["name"]: container.get("logConfiguration")
for container in task_definition_response["taskDefinition"]["containerDefinitions"]
container["name"]: container.get("logConfiguration") # pyright: ignore (reportTypedDictNotRequiredAccess)
for container in task_definition_response["taskDefinition"]["containerDefinitions"] # pyright: ignore (reportTypedDictNotRequiredAccess)
}

all_container_names = {
container["name"]
for container in task_definition_response["taskDefinition"]["containerDefinitions"]
container["name"] # pyright: ignore (reportTypedDictNotRequiredAccess)
for container in task_definition_response["taskDefinition"]["containerDefinitions"] # pyright: ignore (reportTypedDictNotRequiredAccess)
}

container_names_with_overrides = {
Expand Down Expand Up @@ -131,13 +131,13 @@ def run(
}
)

params["overrides"] = (
params["overrides"] = ( # pyright: ignore (reportGeneralTypeIssues)
overrides # assign in case overrides was created here as an empty dict
)

response = self._client.run_task(**params)

tasks: List[str] = [task["taskArn"] for task in response["tasks"]]
tasks: List[str] = [task["taskArn"] for task in response["tasks"]] # pyright: ignore (reportTypedDictNotRequiredAccess)

try:
response = self._wait_for_tasks_completion(tasks=tasks, cluster=cluster)
Expand All @@ -149,10 +149,10 @@ def run(
for container in task["containers"]:
if log_config := log_configurations.get(container["name"]):
if log_config["logDriver"] == "awslogs":
log_group = log_config["options"]["awslogs-group"]
log_group = log_config["options"]["awslogs-group"] # pyright: ignore (reportTypedDictNotRequiredAccess)

# stream name is combined from: prefix, container name, task id
log_stream = f"{log_config['options']['awslogs-stream-prefix']}/{container['name']}/{task_id}"
log_stream = f"{log_config['options']['awslogs-stream-prefix']}/{container['name']}/{task_id}" # pyright: ignore (reportTypedDictNotRequiredAccess)

if isinstance(self._message_reader, PipesCloudWatchMessageReader):
self._message_reader.consume_cloudwatch_logs(
Expand Down Expand Up @@ -201,19 +201,19 @@ def _wait_for_tasks_completion(
params["cluster"] = cluster

waiter.wait(**params)
return self._client.describe_tasks(**params)
return self._client.describe_tasks(**params) # pyright: ignore (reportReturnType)

def _terminate_tasks(
self, context: OpExecutionContext, tasks: List[str], cluster: Optional[str] = None
):
for task in tasks:
try:
self._client.stop_task(
cluster=cluster,
cluster=cluster, # pyright: ignore ()
task=task,
reason="Dagster process was interrupted",
)
except botocore.exceptions.ClientError as e:
except botocore.exceptions.ClientError as e: # pyright: ignore (reportAttributeAccessIssue)
context.log.warning(
f"[pipes] Couldn't stop ECS task {task} in cluster {cluster}:\n{e}"
)
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
self,
context_injector: PipesContextInjector,
message_reader: Optional[PipesMessageReader] = None,
client: Optional[boto3.client] = None,
client: Optional[boto3.client] = None, # pyright: ignore (reportGeneralTypeIssues)
forward_termination: bool = True,
):
self._client = client or boto3.client("glue")
Expand Down Expand Up @@ -134,8 +134,8 @@ def run(
context.log.error(
"Couldn't create job %s. Here's why: %s: %s",
job_name,
err.response["Error"]["Code"],
err.response["Error"]["Message"],
err.response["Error"]["Code"], # pyright: ignore (reportTypedDictNotRequiredAccess)
err.response["Error"]["Message"], # pyright: ignore (reportTypedDictNotRequiredAccess)
)
raise

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class PipesLambdaClient(PipesClient, TreatAsResourceParam):

def __init__(
self,
client: Optional[boto3.client] = None,
client: Optional[boto3.client] = None, # pyright: ignore (reportGeneralTypeIssues)
context_injector: Optional[PipesContextInjector] = None,
message_reader: Optional[PipesMessageReader] = None,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class PipesS3ContextInjector(PipesContextInjector):
"""

def __init__(self, *, bucket: str, client: boto3.client):
def __init__(self, *, bucket: str, client: boto3.client): # pyright: ignore (reportGeneralTypeIssues)
super().__init__()
self.bucket = check.str_param(bucket, "bucket")
self.client = client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
*,
interval: float = 10,
bucket: str,
client: boto3.client,
client: boto3.client, # pyright: ignore (reportGeneralTypeIssues)
log_readers: Optional[Sequence[PipesLogReader]] = None,
):
super().__init__(
Expand Down Expand Up @@ -115,7 +115,7 @@ class CloudWatchEvent(TypedDict):
class PipesCloudWatchMessageReader(PipesMessageReader):
"""Message reader that consumes AWS CloudWatch logs to read pipes messages."""

def __init__(self, client: Optional[boto3.client] = None):
def __init__(self, client: Optional[boto3.client] = None): # pyright: ignore (reportGeneralTypeIssues)
"""Args:
client (boto3.client): boto3 CloudWatch client.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def example_job():
def _is_dagster_maintained(cls) -> bool:
return True

def get_client(self) -> "botocore.client.SecretsManager":
def get_client(self) -> "botocore.client.SecretsManager": # pyright: ignore (reportAttributeAccessIssue)
return construct_secretsmanager_client(
max_attempts=self.max_attempts,
region_name=self.region_name,
Expand All @@ -75,7 +75,7 @@ def get_client(self) -> "botocore.client.SecretsManager":

@dagster_maintained_resource
@resource(SecretsManagerResource.to_config_schema())
def secretsmanager_resource(context) -> "botocore.client.SecretsManager":
def secretsmanager_resource(context) -> "botocore.client.SecretsManager": # pyright: ignore (reportAttributeAccessIssue)
"""Resource that gives access to AWS SecretsManager.
The underlying SecretsManager session is created by calling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def example_job():
def _is_dagster_maintained(cls) -> bool:
return True

def get_client(self) -> "botocore.client.ssm":
def get_client(self) -> "botocore.client.ssm": # pyright: ignore (reportAttributeAccessIssue)
return construct_ssm_client(
max_attempts=self.max_attempts,
region_name=self.region_name,
Expand All @@ -78,7 +78,7 @@ def get_client(self) -> "botocore.client.ssm":

@dagster_maintained_resource
@resource(config_schema=SSMResource.to_config_schema())
def ssm_resource(context) -> "botocore.client.ssm":
def ssm_resource(context) -> "botocore.client.ssm": # pyright: ignore (reportAttributeAccessIssue)
"""Resource that gives access to AWS Systems Manager Parameter Store.
The underlying Parameter Store session is created by calling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class SimulatedTaskRun:


class LocalECSMockClient:
def __init__(self, ecs_client: boto3.client, cloudwatch_client: boto3.client):
def __init__(self, ecs_client: boto3.client, cloudwatch_client: boto3.client): # pyright: ignore (reportGeneralTypeIssues)
self.ecs_client = ecs_client
self.cloudwatch_client = cloudwatch_client

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ class LocalGlueMockClient:
def __init__(
self,
aws_endpoint_url: str, # usually received from moto
s3_client: boto3.client,
glue_client: boto3.client,
s3_client: boto3.client, # pyright: ignore (reportGeneralTypeIssues)
glue_client: boto3.client, # pyright: ignore (reportGeneralTypeIssues)
pipes_messages_backend: Literal["s3", "cloudwatch"],
cloudwatch_client: Optional[boto3.client] = None,
cloudwatch_client: Optional[boto3.client] = None, # pyright: ignore (reportGeneralTypeIssues)
):
"""This class wraps moto3 clients for S3 and Glue, and provides a way to "run" Glue jobs locally.
This is necessary because moto3 does not actually run anything when you start a Glue job, so we won't be able
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def test_op(file_manager: S3FileManagerResource) -> None:
# placeholder function to test resource initialization
return context.log.info("return from test_solid")

with pytest.raises(botocore.exceptions.ProfileNotFound):
with pytest.raises(botocore.exceptions.ProfileNotFound): # pyright: ignore (reportAttributeAccessIssue)
context = build_op_context(
resources={"file_manager": S3FileManagerResource(**resource_config)},
)
Expand Down
86 changes: 86 additions & 0 deletions scripts/auto_ignore_pyright_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
#!python

# this script allows automatically inserting # pyright: ignore (rule) comments for pyright errors # pyright: ignore (reportUnnecessaryTypeIgnoreComment)
# it can be used like this: make pyright | ./scripts/auto_ignore_pyright_errors.py
import re
import sys

# Regular expression to match pyright error lines
error_pattern = re.compile(
r"(?P<file_name>.*?):(?P<line_number>\d+):\d+ - error:.*?(\((?P<rules_single>[^)]+)\))?$"
)

# Additional regex to match the second line of a multi-line error message for rules
rule_line_pattern = re.compile(r"^.*\((?P<rules>[^)]+)\)$")


def main():
# Dictionary to store errors by file and line
errors = {}

# Read from standard input
previous_line = None
for stdin_line in sys.stdin:
line = stdin_line.strip()
if previous_line is None:
match = error_pattern.match(line)
if match:
file_path = match.group("file_name")
line_number = int(match.group("line_number"))
rule = match.group("rules_single")

if file_path not in errors:
errors[file_path] = {}
if line_number not in errors[file_path]:
errors[file_path][line_number] = []

# If rule is found in the first line itself, add it
if rule:
errors[file_path][line_number].append(rule)

# Store the line for potential multi-line handling
previous_line = line
else:
# Reset if not a match (should not happen normally)
previous_line = None
else:
# This line is expected to be the second line of the error message
match = rule_line_pattern.match(line)
if match:
rule = match.group("rules").strip()
# Only append if it's a valid rule
if rule and "pyright:" not in rule:
file_path = error_pattern.match(previous_line).group("file_name")
line_number = int(error_pattern.match(previous_line).group("line_number"))
errors[file_path][line_number].append(rule)

# Reset for the next error message
previous_line = None

# Process each file and add ignore comments
for file_path, lines in errors.items():
try:
with open(file_path, "r") as file:
content = file.readlines()

for line_number, rules in sorted(lines.items(), reverse=True):
# Ensure only unique rules are added
unique_rules = sorted(set(rules))
ignore_comment = f" # pyright: ignore ({', '.join(unique_rules)})\n"
content[line_number - 1] = content[line_number - 1].rstrip() + ignore_comment

with open(file_path, "w") as file:
file.writelines(content)

update_summary = {line: rules for line, rules in lines.items()}

if update_summary:
for line, rules in update_summary.items():
print(f"{file_path}:{line} - ignored {rules}") # noqa

except Exception as e:
print(f"Error processing {file_path}: {e}") # noqa


if __name__ == "__main__":
main()

0 comments on commit cba2bfc

Please sign in to comment.