Skip to content

Commit 6a85f06

Browse files
Disambiguate default_factory
Pyright's strict mode is becoming increasingly strict with regards to generics. You used to be able to effectively narrow them by declaring the type like so: string_list: list[str] = list() But that's no longer accepted. Now the actual value must itself resolve to something that has resolved parameterization. This unfortunately means that actual casting is sometimes necessary. It's not a huge hit because it's a do-nothing function, but it's still a function call that I'd rather not be doing.
1 parent fe535e9 commit 6a85f06

File tree

8 files changed

+74
-42
lines changed

8 files changed

+74
-42
lines changed

codegen/core/src/main/java/software/amazon/smithy/python/codegen/ClientGenerator.java

+16-9
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,7 @@ await sleep(retry_token.retry_delay)
542542
""");
543543

544544
writer.pushState(new SignRequestSection());
545+
writer.addStdlibImport("typing", "cast");
545546
if (context.applicationProtocol().isHttpProtocol() && supportsAuth) {
546547
writer.addStdlibImport("re");
547548
writer.addStdlibImport("typing", "Any");
@@ -571,9 +572,12 @@ await sleep(retry_token.retry_delay)
571572
signature = re.split("Signature=", auth_value)[-1] # type: ignore
572573
context.properties["signature"] = signature.encode('utf-8')
573574
574-
identity_key: PropertyKey[Identity | None] = PropertyKey(
575-
key="identity",
576-
value_type=Identity | None # type: ignore
575+
identity_key = cast(
576+
PropertyKey[Identity | None],
577+
PropertyKey(
578+
key="identity",
579+
value_type=Identity | None # type: ignore
580+
)
577581
)
578582
sp_key: PropertyKey[dict[str, Any]] = PropertyKey(
579583
key="signer_properties",
@@ -665,12 +669,15 @@ await sleep(retry_token.retry_delay)
665669
# Step 7r: Invoke read_after_deserialization
666670
interceptor.read_after_deserialization(output_context)
667671
except Exception as e:
668-
output_context: OutputContext[Input, Output, $1T, $2T] = OutputContext(
669-
request=context.request,
670-
response=e, # type: ignore
671-
transport_request=context.transport_request,
672-
transport_response=transport_response,
673-
properties=context.properties
672+
output_context = cast(
673+
OutputContext[Input, Output, $1T, $2T],
674+
OutputContext(
675+
request=context.request,
676+
response=e, # type: ignore
677+
transport_request=context.transport_request,
678+
transport_response=transport_response,
679+
properties=context.properties
680+
)
674681
)
675682
676683
return await self._finalize_attempt(interceptor, output_context)

codegen/core/src/main/java/software/amazon/smithy/python/codegen/generators/StructureGenerator.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ private void writeProperties() {
215215
if (member.hasTrait(DefaultTrait.class)) {
216216

217217
defaultValue = getDefaultValue(writer, member);
218-
if (target.isDocumentShape() || Set.of("list", "dict").contains(defaultValue)) {
218+
if (target.isDocumentShape() || defaultValue.startsWith("list[") || defaultValue.startsWith("dict[")) {
219219
writer.addStdlibImport("dataclasses", "field");
220220
defaultKey = "default_factory";
221221
requiresField = true;
@@ -308,8 +308,7 @@ private String getDefaultValue(PythonWriter writer, MemberShape member) {
308308
case BOOLEAN -> defaultNode.expectBooleanNode().getValue() ? "True" : "False";
309309
// These will be given to a default_factory in field. They're inherently empty, so no need to
310310
// worry about any potential values.
311-
case ARRAY -> "list";
312-
case OBJECT -> "dict";
311+
case ARRAY, OBJECT -> symbolProvider.toSymbol(target).getName();
313312
default -> Node.printJson(defaultNode);
314313
};
315314
}

packages/smithy-aws-event-stream/src/smithy_aws_event_stream/events.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ class EventMessage:
187187
message.
188188
"""
189189

190-
headers: HEADERS_DICT = field(default_factory=dict)
190+
headers: HEADERS_DICT = field(default_factory=dict[str, HEADER_VALUE])
191191
"""The headers present in the event message.
192192
193193
Sized integer values may be indicated for the purpose of serialization

packages/smithy-core/src/smithy_core/schemas.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,10 @@ class Schema:
2020

2121
id: ShapeID
2222
shape_type: ShapeType
23-
traits: dict[ShapeID, "Trait | DynamicTrait"] = field(default_factory=dict)
24-
members: dict[str, "Schema"] = field(default_factory=dict)
23+
traits: dict[ShapeID, "Trait | DynamicTrait"] = field(
24+
default_factory=dict[ShapeID, "Trait | DynamicTrait"]
25+
)
26+
members: dict[str, "Schema"] = field(default_factory=dict[str, "Schema"])
2527
member_target: "Schema | None" = None
2628
member_index: int | None = None
2729

packages/smithy-core/tests/unit/test_types.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
# pyright: reportPrivateUsage=false
55
from datetime import UTC, datetime
6-
from typing import Any, assert_type
6+
from typing import Any, assert_type, cast
77

88
import pytest
99
from smithy_core.exceptions import ExpectationNotMetError
@@ -323,9 +323,12 @@ def test_properties_typed_pop() -> None:
323323

324324
def test_union_property() -> None:
325325
properties = TypedProperties()
326-
union: PropertyKey[str | int] = PropertyKey(
327-
key="union",
328-
value_type=str | int, # type: ignore
326+
union = cast(
327+
PropertyKey[str | int],
328+
PropertyKey(
329+
key="union",
330+
value_type=str | int, # type: ignore
331+
),
329332
)
330333

331334
properties[union] = "foo"

packages/smithy-http/src/smithy_http/aio/crt.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from copy import deepcopy
1111
from functools import partial
1212
from io import BufferedIOBase, BytesIO
13-
from typing import TYPE_CHECKING, Any
13+
from typing import TYPE_CHECKING, Any, cast
1414

1515
if TYPE_CHECKING:
1616
# Both of these are types that essentially are "castable to bytes/memoryview"
@@ -119,7 +119,7 @@ def set_stream(self, stream: "crt_http.HttpClientStream") -> None:
119119
if self._stream is not None:
120120
raise SmithyHTTPError("Stream already set on AWSCRTHTTPResponse object")
121121
self._stream = stream
122-
concurrent_future: ConcurrentFuture[int] = stream.completion_future
122+
concurrent_future = cast(ConcurrentFuture[int], stream.completion_future)
123123
self._completion_future = asyncio.wrap_future(concurrent_future)
124124
self._completion_future.add_done_callback(self._on_complete)
125125
self._stream.activate()
@@ -305,14 +305,15 @@ def _build_new_connection(
305305
if url.port is not None:
306306
port = url.port
307307

308-
connect_future: ConcurrentFuture[crt_http.HttpClientConnection] = (
308+
connect_future = cast(
309+
ConcurrentFuture[crt_http.HttpClientConnection],
309310
crt_http.HttpClientConnection.new(
310311
bootstrap=self._client_bootstrap,
311312
host_name=url.host,
312313
port=port,
313314
socket_options=self._socket_options,
314315
tls_connection_options=tls_connection_options,
315-
)
316+
),
316317
)
317318
return connect_future
318319

packages/smithy-http/src/smithy_http/user_agent.py

+28-10
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,35 @@ def __str__(self) -> str:
5050

5151
@dataclass(kw_only=True, slots=True)
5252
class UserAgent:
53-
sdk_metadata: list[UserAgentComponent] = field(default_factory=list)
54-
internal_metadata: list[UserAgentComponent] = field(default_factory=list)
55-
ua_metadata: list[UserAgentComponent] = field(default_factory=list)
56-
api_metadata: list[UserAgentComponent] = field(default_factory=list)
57-
os_metadata: list[UserAgentComponent] = field(default_factory=list)
58-
language_metadata: list[UserAgentComponent] = field(default_factory=list)
59-
env_metadata: list[UserAgentComponent] = field(default_factory=list)
60-
config_metadata: list[UserAgentComponent] = field(default_factory=list)
61-
feat_metadata: list[UserAgentComponent] = field(default_factory=list)
53+
sdk_metadata: list[UserAgentComponent] = field(
54+
default_factory=list[UserAgentComponent]
55+
)
56+
internal_metadata: list[UserAgentComponent] = field(
57+
default_factory=list[UserAgentComponent]
58+
)
59+
ua_metadata: list[UserAgentComponent] = field(
60+
default_factory=list[UserAgentComponent]
61+
)
62+
api_metadata: list[UserAgentComponent] = field(
63+
default_factory=list[UserAgentComponent]
64+
)
65+
os_metadata: list[UserAgentComponent] = field(
66+
default_factory=list[UserAgentComponent]
67+
)
68+
language_metadata: list[UserAgentComponent] = field(
69+
default_factory=list[UserAgentComponent]
70+
)
71+
env_metadata: list[UserAgentComponent] = field(
72+
default_factory=list[UserAgentComponent]
73+
)
74+
config_metadata: list[UserAgentComponent] = field(
75+
default_factory=list[UserAgentComponent]
76+
)
77+
feat_metadata: list[UserAgentComponent] = field(
78+
default_factory=list[UserAgentComponent]
79+
)
6280
additional_metadata: list[UserAgentComponent | RawStringUserAgentComponent] = field(
63-
default_factory=list
81+
default_factory=list[UserAgentComponent | RawStringUserAgentComponent]
6482
)
6583

6684
def __str__(self) -> str:

packages/smithy-http/tests/unit/test_serializers.py

+11-9
Original file line numberDiff line numberDiff line change
@@ -124,27 +124,29 @@
124124
@dataclass
125125
class _HTTPMapping(Protocol):
126126
boolean_member: bool | None = None
127-
boolean_list_member: list[bool] = field(default_factory=list)
127+
boolean_list_member: list[bool] = field(default_factory=list[bool])
128128
integer_member: int | None = None
129-
integer_list_member: list[int] = field(default_factory=list)
129+
integer_list_member: list[int] = field(default_factory=list[int])
130130
float_member: float | None = None
131-
float_list_member: list[float] = field(default_factory=list)
131+
float_list_member: list[float] = field(default_factory=list[float])
132132
big_decimal_member: Decimal | None = None
133-
big_decimal_list_member: list[Decimal] = field(default_factory=list)
133+
big_decimal_list_member: list[Decimal] = field(default_factory=list[Decimal])
134134
string_member: str | None = None
135-
string_list_member: list[str] = field(default_factory=list)
135+
string_list_member: list[str] = field(default_factory=list[str])
136136
default_timestamp_member: datetime.datetime | None = None
137137
http_date_timestamp_member: datetime.datetime | None = None
138138
http_date_list_timestamp_member: list[datetime.datetime] = field(
139-
default_factory=list
139+
default_factory=list[datetime.datetime]
140140
)
141141
date_time_timestamp_member: datetime.datetime | None = None
142142
date_time_list_timestamp_member: list[datetime.datetime] = field(
143-
default_factory=list
143+
default_factory=list[datetime.datetime]
144144
)
145145
epoch_timestamp_member: datetime.datetime | None = None
146-
epoch_list_timestamp_member: list[datetime.datetime] = field(default_factory=list)
147-
string_map_member: dict[str, str] = field(default_factory=dict)
146+
epoch_list_timestamp_member: list[datetime.datetime] = field(
147+
default_factory=list[datetime.datetime]
148+
)
149+
string_map_member: dict[str, str] = field(default_factory=dict[str, str])
148150

149151
ID: ClassVar[ShapeID]
150152
SCHEMA: ClassVar[Schema]

0 commit comments

Comments
 (0)