Skip to content

Disambiguate generics with unknown parametric assignments #492

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ await sleep(retry_token.retry_delay)
""");

writer.pushState(new SignRequestSection());
writer.addStdlibImport("typing", "cast");
if (context.applicationProtocol().isHttpProtocol() && supportsAuth) {
writer.addStdlibImport("re");
writer.addStdlibImport("typing", "Any");
Expand Down Expand Up @@ -571,9 +572,12 @@ await sleep(retry_token.retry_delay)
signature = re.split("Signature=", auth_value)[-1] # type: ignore
context.properties["signature"] = signature.encode('utf-8')

identity_key: PropertyKey[Identity | None] = PropertyKey(
key="identity",
value_type=Identity | None # type: ignore
identity_key = cast(
PropertyKey[Identity | None],
PropertyKey(
key="identity",
value_type=Identity | None # type: ignore
)
)
sp_key: PropertyKey[dict[str, Any]] = PropertyKey(
key="signer_properties",
Expand Down Expand Up @@ -665,12 +669,15 @@ await sleep(retry_token.retry_delay)
# Step 7r: Invoke read_after_deserialization
interceptor.read_after_deserialization(output_context)
except Exception as e:
output_context: OutputContext[Input, Output, $1T, $2T] = OutputContext(
request=context.request,
response=e, # type: ignore
transport_request=context.transport_request,
transport_response=transport_response,
properties=context.properties
output_context = cast(
OutputContext[Input, Output, $1T, $2T],
OutputContext(
request=context.request,
response=e, # type: ignore
transport_request=context.transport_request,
transport_response=transport_response,
properties=context.properties
)
)

return await self._finalize_attempt(interceptor, output_context)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ private void writeProperties() {
if (member.hasTrait(DefaultTrait.class)) {

defaultValue = getDefaultValue(writer, member);
if (target.isDocumentShape() || Set.of("list", "dict").contains(defaultValue)) {
if (target.isDocumentShape() || defaultValue.startsWith("list[") || defaultValue.startsWith("dict[")) {
writer.addStdlibImport("dataclasses", "field");
defaultKey = "default_factory";
requiresField = true;
Expand Down Expand Up @@ -308,8 +308,7 @@ private String getDefaultValue(PythonWriter writer, MemberShape member) {
case BOOLEAN -> defaultNode.expectBooleanNode().getValue() ? "True" : "False";
// These will be given to a default_factory in field. They're inherently empty, so no need to
// worry about any potential values.
case ARRAY -> "list";
case OBJECT -> "dict";
case ARRAY, OBJECT -> symbolProvider.toSymbol(target).getName();
default -> Node.printJson(defaultNode);
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ class EventMessage:
message.
"""

headers: HEADERS_DICT = field(default_factory=dict)
headers: HEADERS_DICT = field(default_factory=dict[str, HEADER_VALUE])
"""The headers present in the event message.

Sized integer values may be indicated for the purpose of serialization
Expand Down
6 changes: 4 additions & 2 deletions packages/smithy-core/src/smithy_core/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ class Schema:

id: ShapeID
shape_type: ShapeType
traits: dict[ShapeID, "Trait | DynamicTrait"] = field(default_factory=dict)
members: dict[str, "Schema"] = field(default_factory=dict)
traits: dict[ShapeID, "Trait | DynamicTrait"] = field(
default_factory=dict[ShapeID, "Trait | DynamicTrait"]
)
members: dict[str, "Schema"] = field(default_factory=dict[str, "Schema"])
member_target: "Schema | None" = None
member_index: int | None = None

Expand Down
11 changes: 7 additions & 4 deletions packages/smithy-core/tests/unit/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

# pyright: reportPrivateUsage=false
from datetime import UTC, datetime
from typing import Any, assert_type
from typing import Any, assert_type, cast

import pytest
from smithy_core.exceptions import ExpectationNotMetError
Expand Down Expand Up @@ -323,9 +323,12 @@ def test_properties_typed_pop() -> None:

def test_union_property() -> None:
properties = TypedProperties()
union: PropertyKey[str | int] = PropertyKey(
key="union",
value_type=str | int, # type: ignore
union = cast(
PropertyKey[str | int],
PropertyKey(
key="union",
value_type=str | int, # type: ignore
),
)

properties[union] = "foo"
Expand Down
9 changes: 5 additions & 4 deletions packages/smithy-http/src/smithy_http/aio/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from copy import deepcopy
from functools import partial
from io import BufferedIOBase, BytesIO
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast

if TYPE_CHECKING:
# Both of these are types that essentially are "castable to bytes/memoryview"
Expand Down Expand Up @@ -119,7 +119,7 @@ def set_stream(self, stream: "crt_http.HttpClientStream") -> None:
if self._stream is not None:
raise SmithyHTTPError("Stream already set on AWSCRTHTTPResponse object")
self._stream = stream
concurrent_future: ConcurrentFuture[int] = stream.completion_future
concurrent_future = cast(ConcurrentFuture[int], stream.completion_future)
self._completion_future = asyncio.wrap_future(concurrent_future)
self._completion_future.add_done_callback(self._on_complete)
self._stream.activate()
Expand Down Expand Up @@ -305,14 +305,15 @@ def _build_new_connection(
if url.port is not None:
port = url.port

connect_future: ConcurrentFuture[crt_http.HttpClientConnection] = (
connect_future = cast(
ConcurrentFuture[crt_http.HttpClientConnection],
crt_http.HttpClientConnection.new(
bootstrap=self._client_bootstrap,
host_name=url.host,
port=port,
socket_options=self._socket_options,
tls_connection_options=tls_connection_options,
)
),
)
return connect_future

Expand Down
38 changes: 28 additions & 10 deletions packages/smithy-http/src/smithy_http/user_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,35 @@ def __str__(self) -> str:

@dataclass(kw_only=True, slots=True)
class UserAgent:
sdk_metadata: list[UserAgentComponent] = field(default_factory=list)
internal_metadata: list[UserAgentComponent] = field(default_factory=list)
ua_metadata: list[UserAgentComponent] = field(default_factory=list)
api_metadata: list[UserAgentComponent] = field(default_factory=list)
os_metadata: list[UserAgentComponent] = field(default_factory=list)
language_metadata: list[UserAgentComponent] = field(default_factory=list)
env_metadata: list[UserAgentComponent] = field(default_factory=list)
config_metadata: list[UserAgentComponent] = field(default_factory=list)
feat_metadata: list[UserAgentComponent] = field(default_factory=list)
sdk_metadata: list[UserAgentComponent] = field(
default_factory=list[UserAgentComponent]
)
internal_metadata: list[UserAgentComponent] = field(
default_factory=list[UserAgentComponent]
)
ua_metadata: list[UserAgentComponent] = field(
default_factory=list[UserAgentComponent]
)
api_metadata: list[UserAgentComponent] = field(
default_factory=list[UserAgentComponent]
)
os_metadata: list[UserAgentComponent] = field(
default_factory=list[UserAgentComponent]
)
language_metadata: list[UserAgentComponent] = field(
default_factory=list[UserAgentComponent]
)
env_metadata: list[UserAgentComponent] = field(
default_factory=list[UserAgentComponent]
)
config_metadata: list[UserAgentComponent] = field(
default_factory=list[UserAgentComponent]
)
feat_metadata: list[UserAgentComponent] = field(
default_factory=list[UserAgentComponent]
)
additional_metadata: list[UserAgentComponent | RawStringUserAgentComponent] = field(
default_factory=list
default_factory=list[UserAgentComponent | RawStringUserAgentComponent]
)

def __str__(self) -> str:
Expand Down
20 changes: 11 additions & 9 deletions packages/smithy-http/tests/unit/test_serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,27 +124,29 @@
@dataclass
class _HTTPMapping(Protocol):
boolean_member: bool | None = None
boolean_list_member: list[bool] = field(default_factory=list)
boolean_list_member: list[bool] = field(default_factory=list[bool])
integer_member: int | None = None
integer_list_member: list[int] = field(default_factory=list)
integer_list_member: list[int] = field(default_factory=list[int])
float_member: float | None = None
float_list_member: list[float] = field(default_factory=list)
float_list_member: list[float] = field(default_factory=list[float])
big_decimal_member: Decimal | None = None
big_decimal_list_member: list[Decimal] = field(default_factory=list)
big_decimal_list_member: list[Decimal] = field(default_factory=list[Decimal])
string_member: str | None = None
string_list_member: list[str] = field(default_factory=list)
string_list_member: list[str] = field(default_factory=list[str])
default_timestamp_member: datetime.datetime | None = None
http_date_timestamp_member: datetime.datetime | None = None
http_date_list_timestamp_member: list[datetime.datetime] = field(
default_factory=list
default_factory=list[datetime.datetime]
)
date_time_timestamp_member: datetime.datetime | None = None
date_time_list_timestamp_member: list[datetime.datetime] = field(
default_factory=list
default_factory=list[datetime.datetime]
)
epoch_timestamp_member: datetime.datetime | None = None
epoch_list_timestamp_member: list[datetime.datetime] = field(default_factory=list)
string_map_member: dict[str, str] = field(default_factory=dict)
epoch_list_timestamp_member: list[datetime.datetime] = field(
default_factory=list[datetime.datetime]
)
string_map_member: dict[str, str] = field(default_factory=dict[str, str])

ID: ClassVar[ShapeID]
SCHEMA: ClassVar[Schema]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ dependencies = []
dev = [
"black>=25.1.0",
"docformatter>=1.7.5",
"pyright>=1.1.396",
"pyright>=1.1.400",
"pytest>=8.3.4",
"pytest-asyncio>=0.25.3",
"pytest-cov>=6.0.0",
Expand Down
Loading