diff --git a/python_modules/dagster/dagster/_core/snap/node.py b/python_modules/dagster/dagster/_core/snap/node.py index e05f6d972411a..df8fbc83034ee 100644 --- a/python_modules/dagster/dagster/_core/snap/node.py +++ b/python_modules/dagster/dagster/_core/snap/node.py @@ -1,3 +1,4 @@ +from functools import cached_property from typing import Mapping, NamedTuple, Optional, Sequence, Union import dagster._check as check @@ -233,6 +234,14 @@ def __new__( ), ) + @cached_property + def input_def_map(self) -> Mapping[str, InputDefSnap]: + return {input_def.name: input_def for input_def in self.input_def_snaps} + + @cached_property + def output_def_map(self) -> Mapping[str, OutputDefSnap]: + return {output_def.name: output_def for output_def in self.output_def_snaps} + def get_input_snap(self, name: str) -> InputDefSnap: return _get_input_snap(self, name) @@ -282,6 +291,14 @@ def __new__( ), ) + @cached_property + def input_def_map(self) -> Mapping[str, InputDefSnap]: + return {input_def.name: input_def for input_def in self.input_def_snaps} + + @cached_property + def output_def_map(self) -> Mapping[str, OutputDefSnap]: + return {output_def.name: output_def for output_def in self.output_def_snaps} + def get_input_snap(self, name: str) -> InputDefSnap: return _get_input_snap(self, name) @@ -387,9 +404,9 @@ def build_op_def_snap(op_def: OpDefinition) -> OpDefSnap: # shared impl for GraphDefSnap and OpDefSnap def _get_input_snap(node_def: Union[GraphDefSnap, OpDefSnap], name: str) -> InputDefSnap: check.str_param(name, "name") - for inp in node_def.input_def_snaps: - if inp.name == name: - return inp + inp = node_def.input_def_map.get(name) + if inp: + return inp check.failed(f"Could not find input {name} in op def {node_def.name}") @@ -397,8 +414,8 @@ def _get_input_snap(node_def: Union[GraphDefSnap, OpDefSnap], name: str) -> Inpu # shared impl for GraphDefSnap and OpDefSnap def _get_output_snap(node_def: Union[GraphDefSnap, OpDefSnap], name: str) -> OutputDefSnap: check.str_param(name, "name") - for out in node_def.output_def_snaps: - if out.name == name: - return out + inp = node_def.output_def_map.get(name) + if inp: + return inp check.failed(f"Could not find output {name} in node def {node_def.name}")