Skip to content

Commit

Permalink
[record] snap/node.py (#26487)
Browse files Browse the repository at this point in the history
## How I Tested These Changes

bk
  • Loading branch information
alangenfeld authored Dec 16, 2024
1 parent 7d48a4d commit 51bdd79
Showing 1 changed file with 70 additions and 195 deletions.
265 changes: 70 additions & 195 deletions python_modules/dagster/dagster/_core/snap/node.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import cached_property
from typing import Mapping, NamedTuple, Optional, Sequence, Union
from typing import Mapping, Optional, Sequence, Union

import dagster._check as check
from dagster._config import ConfigFieldSnap, snap_from_field
Expand All @@ -21,6 +21,7 @@
DependencyStructureSnapshot,
build_dep_structure_snapshot_from_graph_def,
)
from dagster._record import IHaveNew, record, record_custom
from dagster._serdes import whitelist_for_serdes
from dagster._utils.warnings import suppress_dagster_warnings

Expand All @@ -30,29 +31,25 @@
field_serializers={"metadata": MetadataFieldSerializer},
skip_when_empty_fields={"metadata"},
)
class InputDefSnap(
NamedTuple(
"_InputDefSnap",
[
("name", str),
("dagster_type_key", str),
("description", Optional[str]),
("metadata", Mapping[str, MetadataValue]),
],
)
):
@record_custom
class InputDefSnap(IHaveNew):
name: str
dagster_type_key: str
description: Optional[str]
metadata: Mapping[str, MetadataValue]

def __new__(
cls,
name: str,
dagster_type_key: str,
description: Optional[str],
metadata: Optional[Mapping[str, MetadataValue]] = None,
):
return super(InputDefSnap, cls).__new__(
return super().__new__(
cls,
name=check.str_param(name, "name"),
dagster_type_key=check.str_param(dagster_type_key, "dagster_type_key"),
description=check.opt_str_param(description, "description"),
name=name,
dagster_type_key=dagster_type_key,
description=description,
metadata=normalize_metadata(
check.opt_mapping_param(metadata, "metadata", key_type=str), allow_invalid=True
),
Expand All @@ -64,19 +61,15 @@ def __new__(
field_serializers={"metadata": MetadataFieldSerializer},
skip_when_empty_fields={"metadata"},
)
class OutputDefSnap(
NamedTuple(
"_OutputDefSnap",
[
("name", str),
("dagster_type_key", str),
("description", Optional[str]),
("is_required", bool),
("metadata", Mapping[str, MetadataValue]),
("is_dynamic", bool),
],
)
):
@record_custom
class OutputDefSnap(IHaveNew):
name: str
dagster_type_key: str
description: Optional[str]
is_required: bool
metadata: Mapping[str, MetadataValue]
is_dynamic: bool

def __new__(
cls,
name: str,
Expand All @@ -86,42 +79,25 @@ def __new__(
metadata: Optional[Mapping[str, MetadataValue]] = None,
is_dynamic: bool = False,
):
return super(OutputDefSnap, cls).__new__(
return super().__new__(
cls,
name=check.str_param(name, "name"),
dagster_type_key=check.str_param(dagster_type_key, "dagster_type_key"),
description=check.opt_str_param(description, "description"),
is_required=check.bool_param(is_required, "is_required"),
name=name,
dagster_type_key=dagster_type_key,
description=description,
is_required=is_required,
metadata=normalize_metadata(
check.opt_mapping_param(metadata, "metadata", key_type=str), allow_invalid=True
),
is_dynamic=check.bool_param(is_dynamic, "is_dynamic"),
is_dynamic=is_dynamic,
)


@whitelist_for_serdes(storage_field_names={"mapped_node_name": "mapped_solid_name"})
class OutputMappingSnap(
NamedTuple(
"_OutputMappingSnap",
[
("mapped_node_name", str),
("mapped_output_name", str),
("external_output_name", str),
],
)
):
def __new__(
cls,
mapped_node_name: str,
mapped_output_name: str,
external_output_name: str,
):
return super(OutputMappingSnap, cls).__new__(
cls,
mapped_node_name=check.str_param(mapped_node_name, "mapped_node_name"),
mapped_output_name=check.str_param(mapped_output_name, "mapped_output_name"),
external_output_name=check.str_param(external_output_name, "external_output_name"),
)
@record
class OutputMappingSnap:
mapped_node_name: str
mapped_output_name: str
external_output_name: str


def build_output_mapping_snap(output_mapping: OutputMapping) -> OutputMappingSnap:
Expand All @@ -133,23 +109,11 @@ def build_output_mapping_snap(output_mapping: OutputMapping) -> OutputMappingSna


@whitelist_for_serdes(storage_field_names={"mapped_node_name": "mapped_solid_name"})
class InputMappingSnap(
NamedTuple(
"_InputMappingSnap",
[
("mapped_node_name", str),
("mapped_input_name", str),
("external_input_name", str),
],
)
):
def __new__(cls, mapped_node_name: str, mapped_input_name: str, external_input_name: str):
return super(InputMappingSnap, cls).__new__(
cls,
mapped_node_name=check.str_param(mapped_node_name, "mapped_node_name"),
mapped_input_name=check.str_param(mapped_input_name, "mapped_input_name"),
external_input_name=check.str_param(external_input_name, "external_input_name"),
)
@record
class InputMappingSnap:
mapped_node_name: str
mapped_input_name: str
external_input_name: str


def build_input_mapping_snap(input_mapping: InputMapping) -> InputMappingSnap:
Expand Down Expand Up @@ -183,56 +147,17 @@ def build_output_def_snap(output_def: OutputDefinition) -> OutputDefSnap:


@whitelist_for_serdes(storage_name="CompositeSolidDefSnap")
class GraphDefSnap(
NamedTuple(
"_GraphDefSnap",
[
("name", str),
("input_def_snaps", Sequence[InputDefSnap]),
("output_def_snaps", Sequence[OutputDefSnap]),
("description", Optional[str]),
("tags", Mapping[str, object]),
("config_field_snap", Optional[ConfigFieldSnap]),
("dep_structure_snapshot", DependencyStructureSnapshot),
("input_mapping_snaps", Sequence[InputMappingSnap]),
("output_mapping_snaps", Sequence[OutputMappingSnap]),
],
)
):
def __new__(
cls,
name: str,
input_def_snaps: Sequence[InputDefSnap],
output_def_snaps: Sequence[OutputDefSnap],
description: Optional[str],
tags: Mapping[str, str],
config_field_snap: Optional[ConfigFieldSnap],
dep_structure_snapshot: DependencyStructureSnapshot,
input_mapping_snaps: Sequence[InputMappingSnap],
output_mapping_snaps: Sequence[OutputMappingSnap],
):
return super(GraphDefSnap, cls).__new__(
cls,
dep_structure_snapshot=check.inst_param(
dep_structure_snapshot, "dep_structure_snapshot", DependencyStructureSnapshot
),
input_mapping_snaps=check.sequence_param(
input_mapping_snaps, "input_mapping_snaps", of_type=InputMappingSnap
),
output_mapping_snaps=check.sequence_param(
output_mapping_snaps, "output_mapping_snaps", of_type=OutputMappingSnap
),
name=check.str_param(name, "name"),
input_def_snaps=check.sequence_param(input_def_snaps, "input_def_snaps", InputDefSnap),
output_def_snaps=check.sequence_param(
output_def_snaps, "output_def_snaps", OutputDefSnap
),
description=check.opt_str_param(description, "description"),
tags=check.mapping_param(tags, "tags"),
config_field_snap=check.opt_inst_param(
config_field_snap, "config_field_snap", ConfigFieldSnap
),
)
@record
class GraphDefSnap:
name: str
input_def_snaps: Sequence[InputDefSnap]
output_def_snaps: Sequence[OutputDefSnap]
description: Optional[str]
tags: Mapping[str, str]
config_field_snap: Optional[ConfigFieldSnap]
dep_structure_snapshot: DependencyStructureSnapshot
input_mapping_snaps: Sequence[InputMappingSnap]
output_mapping_snaps: Sequence[OutputMappingSnap]

@cached_property
def input_def_map(self) -> Mapping[str, InputDefSnap]:
Expand All @@ -250,46 +175,15 @@ def get_output_snap(self, name: str) -> OutputDefSnap:


@whitelist_for_serdes(storage_name="SolidDefSnap")
class OpDefSnap(
NamedTuple(
"_OpDefSnap",
[
("name", str),
("input_def_snaps", Sequence[InputDefSnap]),
("output_def_snaps", Sequence[OutputDefSnap]),
("description", Optional[str]),
("tags", Mapping[str, object]),
("required_resource_keys", Sequence[str]),
("config_field_snap", Optional[ConfigFieldSnap]),
],
)
):
def __new__(
cls,
name: str,
input_def_snaps: Sequence[InputDefSnap],
output_def_snaps: Sequence[OutputDefSnap],
description: Optional[str],
tags: Mapping[str, str],
required_resource_keys: Sequence[str],
config_field_snap: Optional[ConfigFieldSnap],
):
return super(OpDefSnap, cls).__new__(
cls,
required_resource_keys=check.sequence_param(
required_resource_keys, "required_resource_keys", str
),
name=check.str_param(name, "name"),
input_def_snaps=check.sequence_param(input_def_snaps, "input_def_snaps", InputDefSnap),
output_def_snaps=check.sequence_param(
output_def_snaps, "output_def_snaps", OutputDefSnap
),
description=check.opt_str_param(description, "description"),
tags=check.mapping_param(tags, "tags"),
config_field_snap=check.opt_inst_param(
config_field_snap, "config_field_snap", ConfigFieldSnap
),
)
@record
class OpDefSnap:
name: str
input_def_snaps: Sequence[InputDefSnap]
output_def_snaps: Sequence[OutputDefSnap]
description: Optional[str]
tags: Mapping[str, str]
required_resource_keys: Sequence[str]
config_field_snap: Optional[ConfigFieldSnap]

@cached_property
def input_def_map(self) -> Mapping[str, InputDefSnap]:
Expand All @@ -313,35 +207,10 @@ def get_output_snap(self, name: str) -> OutputDefSnap:
"graph_def_snaps": "composite_solid_def_snaps",
},
)
class NodeDefsSnapshot(
NamedTuple(
"_NodeDefsSnapshot",
[
("op_def_snaps", Sequence[OpDefSnap]),
("graph_def_snaps", Sequence[GraphDefSnap]),
],
)
):
def __new__(
cls,
op_def_snaps: Sequence[OpDefSnap],
graph_def_snaps: Sequence[GraphDefSnap],
):
return super(NodeDefsSnapshot, cls).__new__(
cls,
op_def_snaps=sorted(
check.sequence_param(op_def_snaps, "op_def_snaps", of_type=OpDefSnap),
key=lambda op_def: op_def.name,
),
graph_def_snaps=sorted(
check.sequence_param(
graph_def_snaps,
"graph_def_snaps",
of_type=GraphDefSnap,
),
key=lambda graph_def: graph_def.name,
),
)
@record
class NodeDefsSnapshot(IHaveNew):
op_def_snaps: Sequence[OpDefSnap]
graph_def_snaps: Sequence[GraphDefSnap]


@suppress_dagster_warnings
Expand All @@ -358,8 +227,14 @@ def build_node_defs_snapshot(job_def: JobDefinition) -> NodeDefsSnapshot:
check.failed(f"Unexpected NodeDefinition type {node_def}")

return NodeDefsSnapshot(
op_def_snaps=op_def_snaps,
graph_def_snaps=graph_def_snaps,
op_def_snaps=sorted(
op_def_snaps,
key=lambda op_def: op_def.name,
),
graph_def_snaps=sorted(
graph_def_snaps,
key=lambda graph_def: graph_def.name,
),
)


Expand Down

0 comments on commit 51bdd79

Please sign in to comment.