Skip to content

Commit

Permalink
[dagster-aws] ignore pyright errors with boto3-stubs
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgafni committed Aug 27, 2024
1 parent 7f389db commit 3e312a4
Show file tree
Hide file tree
Showing 11 changed files with 112 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class PipesECSClient(PipesClient, TreatAsResourceParam):

def __init__(
self,
client: Optional[boto3.client] = None,
client: Optional[boto3.client] = None, # pyright: ignore (reportGeneralTypeIssues)
context_injector: Optional[PipesContextInjector] = None,
message_reader: Optional[PipesMessageReader] = None,
forward_termination: bool = True,
Expand Down Expand Up @@ -90,13 +90,13 @@ def run(
)

log_configurations = {
container["name"]: container.get("logConfiguration")
for container in task_definition_response["taskDefinition"]["containerDefinitions"]
container["name"]: container.get("logConfiguration") # pyright: ignore (reportTypedDictNotRequiredAccess)
for container in task_definition_response["taskDefinition"]["containerDefinitions"] # pyright: ignore (reportTypedDictNotRequiredAccess)
}

all_container_names = {
container["name"]
for container in task_definition_response["taskDefinition"]["containerDefinitions"]
container["name"] # pyright: ignore (reportTypedDictNotRequiredAccess)
for container in task_definition_response["taskDefinition"]["containerDefinitions"] # pyright: ignore (reportTypedDictNotRequiredAccess)
}

container_names_with_overrides = {
Expand Down Expand Up @@ -131,13 +131,13 @@ def run(
}
)

params["overrides"] = (
params["overrides"] = ( # pyright: ignore (reportGeneralTypeIssues)
overrides # assign in case overrides was created here as an empty dict
)

response = self._client.run_task(**params)

tasks: List[str] = [task["taskArn"] for task in response["tasks"]]
tasks: List[str] = [task["taskArn"] for task in response["tasks"]] # pyright: ignore (reportTypedDictNotRequiredAccess)

try:
response = self._wait_for_tasks_completion(tasks=tasks, cluster=cluster)
Expand All @@ -149,10 +149,10 @@ def run(
for container in task["containers"]:
if log_config := log_configurations.get(container["name"]):
if log_config["logDriver"] == "awslogs":
log_group = log_config["options"]["awslogs-group"]
log_group = log_config["options"]["awslogs-group"] # pyright: ignore (reportTypedDictNotRequiredAccess)

# stream name is combined from: prefix, container name, task id
log_stream = f"{log_config['options']['awslogs-stream-prefix']}/{container['name']}/{task_id}"
log_stream = f"{log_config['options']['awslogs-stream-prefix']}/{container['name']}/{task_id}" # pyright: ignore (reportTypedDictNotRequiredAccess)

if isinstance(self._message_reader, PipesCloudWatchMessageReader):
self._message_reader.consume_cloudwatch_logs(
Expand Down Expand Up @@ -201,19 +201,19 @@ def _wait_for_tasks_completion(
params["cluster"] = cluster

waiter.wait(**params)
return self._client.describe_tasks(**params)
return self._client.describe_tasks(**params) # pyright: ignore (reportReturnType)

def _terminate_tasks(
self, context: OpExecutionContext, tasks: List[str], cluster: Optional[str] = None
):
for task in tasks:
try:
self._client.stop_task(
cluster=cluster,
cluster=cluster, # pyright: ignore ()
task=task,
reason="Dagster process was interrupted",
)
except botocore.exceptions.ClientError as e:
except botocore.exceptions.ClientError as e: # pyright: ignore (reportAttributeAccessIssue)
context.log.warning(
f"[pipes] Couldn't stop ECS task {task} in cluster {cluster}:\n{e}"
)
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def __init__(
self,
context_injector: PipesContextInjector,
message_reader: Optional[PipesMessageReader] = None,
client: Optional[boto3.client] = None,
client: Optional[boto3.client] = None, # pyright: ignore (reportGeneralTypeIssues)
forward_termination: bool = True,
):
self._client = client or boto3.client("glue")
Expand Down Expand Up @@ -134,8 +134,8 @@ def run(
context.log.error(
"Couldn't create job %s. Here's why: %s: %s",
job_name,
err.response["Error"]["Code"],
err.response["Error"]["Message"],
err.response["Error"]["Code"], # pyright: ignore (reportTypedDictNotRequiredAccess)
err.response["Error"]["Message"], # pyright: ignore (reportTypedDictNotRequiredAccess)
)
raise

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class PipesLambdaClient(PipesClient, TreatAsResourceParam):

def __init__(
self,
client: Optional[boto3.client] = None,
client: Optional[boto3.client] = None, # pyright: ignore (reportGeneralTypeIssues)
context_injector: Optional[PipesContextInjector] = None,
message_reader: Optional[PipesMessageReader] = None,
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class PipesS3ContextInjector(PipesContextInjector):
"""

def __init__(self, *, bucket: str, client: boto3.client):
def __init__(self, *, bucket: str, client: boto3.client): # pyright: ignore (reportGeneralTypeIssues)
super().__init__()
self.bucket = check.str_param(bucket, "bucket")
self.client = client
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
*,
interval: float = 10,
bucket: str,
client: boto3.client,
client: boto3.client, # pyright: ignore (reportGeneralTypeIssues)
log_readers: Optional[Sequence[PipesLogReader]] = None,
):
super().__init__(
Expand Down Expand Up @@ -115,7 +115,7 @@ class CloudWatchEvent(TypedDict):
class PipesCloudWatchMessageReader(PipesMessageReader):
"""Message reader that consumes AWS CloudWatch logs to read pipes messages."""

def __init__(self, client: Optional[boto3.client] = None):
def __init__(self, client: Optional[boto3.client] = None): # pyright: ignore (reportGeneralTypeIssues)
"""Args:
client (boto3.client): boto3 CloudWatch client.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def example_job():
def _is_dagster_maintained(cls) -> bool:
return True

def get_client(self) -> "botocore.client.SecretsManager":
def get_client(self) -> "botocore.client.SecretsManager": # pyright: ignore (reportAttributeAccessIssue)
return construct_secretsmanager_client(
max_attempts=self.max_attempts,
region_name=self.region_name,
Expand All @@ -75,7 +75,7 @@ def get_client(self) -> "botocore.client.SecretsManager":

@dagster_maintained_resource
@resource(SecretsManagerResource.to_config_schema())
def secretsmanager_resource(context) -> "botocore.client.SecretsManager":
def secretsmanager_resource(context) -> "botocore.client.SecretsManager": # pyright: ignore (reportAttributeAccessIssue)
"""Resource that gives access to AWS SecretsManager.
The underlying SecretsManager session is created by calling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def example_job():
def _is_dagster_maintained(cls) -> bool:
return True

def get_client(self) -> "botocore.client.ssm":
def get_client(self) -> "botocore.client.ssm": # pyright: ignore (reportAttributeAccessIssue)
return construct_ssm_client(
max_attempts=self.max_attempts,
region_name=self.region_name,
Expand All @@ -78,7 +78,7 @@ def get_client(self) -> "botocore.client.ssm":

@dagster_maintained_resource
@resource(config_schema=SSMResource.to_config_schema())
def ssm_resource(context) -> "botocore.client.ssm":
def ssm_resource(context) -> "botocore.client.ssm": # pyright: ignore (reportAttributeAccessIssue)
"""Resource that gives access to AWS Systems Manager Parameter Store.
The underlying Parameter Store session is created by calling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class SimulatedTaskRun:


class LocalECSMockClient:
def __init__(self, ecs_client: boto3.client, cloudwatch_client: boto3.client):
def __init__(self, ecs_client: boto3.client, cloudwatch_client: boto3.client): # pyright: ignore (reportGeneralTypeIssues)
self.ecs_client = ecs_client
self.cloudwatch_client = cloudwatch_client

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ class LocalGlueMockClient:
def __init__(
self,
aws_endpoint_url: str, # usually received from moto
s3_client: boto3.client,
glue_client: boto3.client,
s3_client: boto3.client, # pyright: ignore (reportGeneralTypeIssues)
glue_client: boto3.client, # pyright: ignore (reportGeneralTypeIssues)
pipes_messages_backend: Literal["s3", "cloudwatch"],
cloudwatch_client: Optional[boto3.client] = None,
cloudwatch_client: Optional[boto3.client] = None, # pyright: ignore (reportGeneralTypeIssues)
):
"""This class wraps moto3 clients for S3 and Glue, and provides a way to "run" Glue jobs locally.
This is necessary because moto3 does not actually run anything when you start a Glue job, so we won't be able
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def test_op(file_manager: S3FileManagerResource) -> None:
# placeholder function to test resource initialization
return context.log.info("return from test_solid")

with pytest.raises(botocore.exceptions.ProfileNotFound):
with pytest.raises(botocore.exceptions.ProfileNotFound): # pyright: ignore (reportAttributeAccessIssue)
context = build_op_context(
resources={"file_manager": S3FileManagerResource(**resource_config)},
)
Expand Down
84 changes: 84 additions & 0 deletions scripts/auto_ignore_pyright_errors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#!python

# this script allows automatically inserting # pyright: ignore (rule) comments for pyright errors
# it can be used like this: make pyright | ./scripts/auto_ignore_pyright_errors.py
import re
import sys

# Regular expression to match pyright error lines
error_pattern = re.compile(
r"(?P<file_name>.*?):(?P<line_number>\d+):\d+ - error:.*?(\((?P<rules_single>[^)]+)\))?$"
)

# Additional regex to match the second line of a multi-line error message for rules
rule_line_pattern = re.compile(r"^.*\((?P<rules>[^)]+)\)$")


def main():
# Dictionary to store errors by file and line
errors = {}

# Read from standard input
previous_line = None
for stdin_line in sys.stdin:
line = stdin_line.strip()
if previous_line is None:
match = error_pattern.match(line)
if match:
file_path = match.group("file_name")
line_number = int(match.group("line_number"))
rule = match.group("rules_single")

if file_path not in errors:
errors[file_path] = {}
if line_number not in errors[file_path]:
errors[file_path][line_number] = []

# If rule is found in the first line itself, add it
if rule:
errors[file_path][line_number].append(rule)

# Store the line for potential multi-line handling
previous_line = line
else:
# Reset if not a match (should not happen normally)
previous_line = None
else:
# This line is expected to be the second line of the error message
match = rule_line_pattern.match(line)
if match:
rule = match.group("rules").strip()
# Only append if it's a valid rule
if rule and "pyright:" not in rule:
file_path = error_pattern.match(previous_line).group("file_name")
line_number = int(error_pattern.match(previous_line).group("line_number"))
errors[file_path][line_number].append(rule)

# Reset for the next error message
previous_line = None

# Process each file and add ignore comments
for file_path, lines in errors.items():
try:
with open(file_path, "r") as file:
content = file.readlines()

for line_number, rules in sorted(lines.items(), reverse=True):
# Ensure only unique rules are added
unique_rules = sorted(set(rules))
ignore_comment = f" # pyright: ignore ({', '.join(unique_rules)})\n"
content[line_number - 1] = content[line_number - 1].rstrip() + ignore_comment

with open(file_path, "w") as file:
file.writelines(content)

update_summary = {line: rules for line, rules in lines.items()}

print(f"Updated {file_path} with ignore comments: {update_summary}") # noqa

except Exception as e:
print(f"Error processing {file_path}: {e}") # noqa


if __name__ == "__main__":
main()

0 comments on commit 3e312a4

Please sign in to comment.