Skip to content

Commit

Permalink
type hint improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
kalaspuff committed Jan 28, 2024
1 parent b35fb2d commit e01d68a
Showing 1 changed file with 38 additions and 25 deletions.
63 changes: 38 additions & 25 deletions tomodachi/transport/aws_sns_sqs.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,8 +876,8 @@ async def subscribe_handler(

async def handler(
payload: Optional[str],
receipt_handle: Optional[str] = None,
queue_url: Optional[str] = None,
receipt_handle: str,
queue_url: str,
message_topic: str = "",
message_attributes: Optional[Dict] = None,
approximate_receive_count: Optional[int] = None,
Expand Down Expand Up @@ -1128,7 +1128,7 @@ async def routine_func(*a: Any, **kw: Any) -> Any:

return return_value

attributes: Dict[str, Union[str, bool]] = {}
attributes: Dict[str, str] = {}

if filter_policy != FILTER_POLICY_DEFAULT:
if filter_policy is None:
Expand Down Expand Up @@ -1674,8 +1674,11 @@ async def _send_raw_message(
return message_id

@classmethod
async def delete_message(cls, receipt_handle: Optional[str], queue_url: Optional[str], context: Dict) -> None:
async def delete_message(cls, receipt_handle: str, queue_url: str, context: Dict) -> None:
if not receipt_handle:
logging.getLogger("tomodachi.awssnssqs").warning(
"Unable to delete message [sqs] from queue without receipt handle", queue_url=queue_url
)
return

if not connector.get_client("tomodachi.sqs"):
Expand Down Expand Up @@ -1877,7 +1880,7 @@ async def _get_queue_url(
try:
async with connector("tomodachi.sqs", service_name="sqs") as client:
response = await client.get_queue_url(QueueName=_queue_name, **optional_request_parameters)
queue_url = cast(Optional[str], response.get("QueueUrl"))
queue_url = response.get("QueueUrl")
except (
botocore.exceptions.NoCredentialsError,
botocore.exceptions.PartialCredentialsError,
Expand Down Expand Up @@ -2065,7 +2068,7 @@ async def create_queue(
queue_name=logging.getLogger("tomodachi.awssnssqs")._context.get("queue_name", queue_name),
)

queue_url = ""
queue_url: Optional[str] = None
try:
async with connector("tomodachi.sqs", service_name="sqs") as client:
response = await client.get_queue_url(QueueName=queue_name)
Expand Down Expand Up @@ -2098,7 +2101,9 @@ async def create_queue(
queue_attrs["MessageRetentionPeriod"] = str(message_retention_period)
try:
async with connector("tomodachi.sqs", service_name="sqs") as client:
response = await client.create_queue(QueueName=queue_name, Attributes=queue_attrs)
response = await client.create_queue(
QueueName=queue_name, Attributes=cast(Mapping[Any, str], queue_attrs)
)
queue_url = response.get("QueueUrl")
except (
botocore.exceptions.NoCredentialsError,
Expand All @@ -2120,13 +2125,15 @@ async def create_queue(

try:
async with connector("tomodachi.sqs", service_name="sqs") as client:
response = await client.get_queue_attributes(QueueUrl=queue_url, AttributeNames=["QueueArn"])
queue_attributes_response = await client.get_queue_attributes(
QueueUrl=queue_url, AttributeNames=["QueueArn"]
)
except botocore.exceptions.ClientError as e:
error_message = str(e)
logger.warning("Unable to get queue attributes [sqs] on AWS ({})".format(error_message))
raise AWSSNSSQSException(error_message, log_level=context.get("log_level")) from e

queue_arn = response.get("Attributes", {}).get("QueueArn")
queue_arn: Optional[str] = queue_attributes_response.get("Attributes", {}).get("QueueArn")
if not queue_arn:
error_message = "Missing ARN in response"
logger.warning("Unable to get queue attributes [sqs] on AWS ({})".format(error_message))
Expand Down Expand Up @@ -2194,7 +2201,7 @@ async def subscribe_wildcard_topic(
queue_url: str,
context: Dict,
fifo: bool,
attributes: Optional[Dict[str, Union[str, bool]]] = None,
attributes: Optional[Dict[str, str]] = None,
visibility_timeout: Optional[int] = None,
redrive_policy: Optional[Dict[str, Union[str, int]]] = None,
) -> Optional[List]:
Expand Down Expand Up @@ -2230,7 +2237,9 @@ async def subscribe_wildcard_topic(
next_token = response.get("NextToken")
topics = response.get("Topics", [])
topic_arn_list = [
t.get("TopicArn") for t in topics if t.get("TopicArn") and compiled_pattern.match(t.get("TopicArn"))
cast(str, t.get("TopicArn"))
for t in topics
if t.get("TopicArn") and compiled_pattern.match(cast(str, t.get("TopicArn")))
]

if topic_arn_list:
Expand All @@ -2257,7 +2266,7 @@ async def subscribe_topics(
queue_url: str,
context: Dict,
queue_policy: Optional[Dict] = None,
attributes: Optional[Dict[str, Union[str, bool]]] = None,
attributes: Optional[Dict[str, str]] = None,
visibility_timeout: Optional[int] = None,
redrive_policy: Optional[Dict[str, Union[str, int]]] = None,
) -> List:
Expand Down Expand Up @@ -2351,19 +2360,21 @@ async def subscribe_topics(
current_queue_attributes = queue_attributes_response.get("Attributes", {})
current_queue_policy = json.loads(current_queue_attributes.get("Policy") or "{}")
current_visibility_timeout_ = current_queue_attributes.get("VisibilityTimeout")
current_visibility_timeout = int(current_visibility_timeout_) if current_visibility_timeout_ else None
if current_queue_attributes:
current_redrive_policy = json.loads(current_queue_attributes.get("RedrivePolicy") or "{}")
current_visibility_timeout = int(current_visibility_timeout_) if current_visibility_timeout_ else None
current_message_retention_period = current_queue_attributes.get("MessageRetentionPeriod")
if current_message_retention_period:
current_message_retention_period_ = current_queue_attributes.get("MessageRetentionPeriod")
if current_message_retention_period_:
try:
current_message_retention_period = int(current_message_retention_period)
current_message_retention_period = int(current_message_retention_period_)
except ValueError:
current_message_retention_period = None
current_kms_master_key_id = current_queue_attributes.get("KmsMasterKeyId")
current_kms_data_key_reuse_period_seconds = current_queue_attributes.get("KmsDataKeyReusePeriodSeconds")
if current_kms_data_key_reuse_period_seconds:
current_kms_data_key_reuse_period_seconds = int(current_kms_data_key_reuse_period_seconds)
current_kms_data_key_reuse_period_seconds_ = current_queue_attributes.get(
"KmsDataKeyReusePeriodSeconds"
)
if current_kms_data_key_reuse_period_seconds_:
current_kms_data_key_reuse_period_seconds = int(current_kms_data_key_reuse_period_seconds_)
except botocore.exceptions.ClientError:
pass

Expand Down Expand Up @@ -2420,7 +2431,9 @@ async def subscribe_topics(

try:
async with connector("tomodachi.sqs", service_name="sqs") as sqs_client:
await sqs_client.set_queue_attributes(QueueUrl=queue_url, Attributes=queue_attributes)
await sqs_client.set_queue_attributes(
QueueUrl=queue_url, Attributes=cast(Mapping[Any, str], queue_attributes)
)
except botocore.exceptions.ClientError as e:
error_message = str(e)
logging.getLogger("tomodachi.awssnssqs").warning(
Expand Down Expand Up @@ -2543,8 +2556,8 @@ async def receive_messages() -> None:
async def _receive_wrapper() -> None:
def callback(
payload: Optional[str],
receipt_handle: Optional[str],
queue_url: Optional[str],
receipt_handle: str,
queue_url: str,
message_topic: str,
message_attributes: Dict,
approximate_receive_count: Optional[int],
Expand Down Expand Up @@ -2681,8 +2694,8 @@ async def _callback() -> None:
continue

for message in messages:
receipt_handle = message.get("ReceiptHandle")
raw_message_body = message.get("Body")
receipt_handle: str = message.get("ReceiptHandle", "")
raw_message_body = message.get("Body", "")
try:
message_body = json.loads(raw_message_body)
topic_arn = message_body.get("TopicArn")
Expand Down Expand Up @@ -2872,7 +2885,7 @@ async def setup_queue(
topic: Optional[str] = None,
queue_name: Optional[str] = None,
competing_consumer: Optional[bool] = None,
attributes: Optional[Dict[str, Union[str, bool]]] = None,
attributes: Optional[Dict[str, str]] = None,
visibility_timeout: Optional[int] = None,
dead_letter_queue_name: Optional[str] = DEAD_LETTER_QUEUE_DEFAULT,
max_receive_count: Optional[int] = MAX_RECEIVE_COUNT_DEFAULT,
Expand Down

0 comments on commit e01d68a

Please sign in to comment.