Skip to content

Commit

Permalink
helpers
Browse files Browse the repository at this point in the history
  • Loading branch information
schrockn committed Dec 24, 2024
1 parent ae19563 commit 1ef7294
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def from_entry_point_discovery(
`dagster_components*`. Only one built-in component library can be loaded at a time.
Defaults to `dagster_components`, the standard set of published component types.
"""
components: Dict[str, Type[Component]] = {}
component_types: Dict[str, Type[Component]] = {}
for entry_point in get_entry_points_from_python_environment(COMPONENTS_ENTRY_POINT_GROUP):
# Skip built-in entry points that are not the specified builtin component library.
if (
Expand All @@ -163,44 +163,44 @@ def from_entry_point_discovery(
f"Invalid entry point {entry_point.name} in group {COMPONENTS_ENTRY_POINT_GROUP}. "
f"Value expected to be a module, got {root_module}."
)
for component in get_registered_components_in_module(root_module):
key = f"{entry_point.name}.{get_component_name(component)}"
components[key] = component
for component_type in get_registered_component_types_in_module(root_module):
key = f"{entry_point.name}.{get_component_type_name(component_type)}"
component_types[key] = component_type

return cls(components)
return cls(component_types)

def __init__(self, components: Dict[str, Type[Component]]):
self._components: Dict[str, Type[Component]] = copy.copy(components)
def __init__(self, component_types: Dict[str, Type[Component]]):
self._component_types: Dict[str, Type[Component]] = copy.copy(component_types)

@staticmethod
def empty() -> "ComponentTypeRegistry":
return ComponentTypeRegistry({})

def register(self, name: str, component: Type[Component]) -> None:
if name in self._components:
def register(self, name: str, component_type: Type[Component]) -> None:
if name in self._component_types:
raise DagsterError(f"There is an existing component registered under {name}")
self._components[name] = component
self._component_types[name] = component_type

def has(self, name: str) -> bool:
return name in self._components
return name in self._component_types

def get(self, name: str) -> Type[Component]:
return self._components[name]
return self._component_types[name]

def keys(self) -> Iterable[str]:
return self._components.keys()
return self._component_types.keys()

def __repr__(self) -> str:
return f"<ComponentRegistry {list(self._components.keys())}>"
return f"<ComponentRegistry {list(self._component_types.keys())}>"


def get_registered_components_in_module(module: ModuleType) -> Iterable[Type[Component]]:
def get_registered_component_types_in_module(module: ModuleType) -> Iterable[Type[Component]]:
from dagster._core.definitions.module_loaders.load_assets_from_modules import (
find_subclasses_in_module,
)

for component in find_subclasses_in_module(module, (Component,)):
if is_registered_component(component):
if is_registered_component_type(component):
yield component


Expand Down Expand Up @@ -294,13 +294,13 @@ def wrapper(actual_cls: Type[Component]) -> Type[Component]:
return cls


def is_registered_component(cls: Type) -> bool:
def is_registered_component_type(cls: Type) -> bool:
return hasattr(cls, COMPONENT_REGISTRY_KEY_ATTR)


def get_component_name(component_type: Type[Component]) -> str:
def get_component_type_name(component_type: Type[Component]) -> str:
check.param_invariant(
is_registered_component(component_type),
is_registered_component_type(component_type),
"component_type",
"Expected a registered component. Use @component to register a component.",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
ComponentLoadContext,
ComponentTypeRegistry,
TemplatedValueResolver,
get_component_name,
is_registered_component,
get_component_type_name,
is_registered_component_type,
)
from dagster_components.core.component_decl_builder import (
ComponentFolder,
Expand Down Expand Up @@ -69,8 +69,8 @@ def component_type_from_yaml_decl(
for _name, obj in inspect.getmembers(module, inspect.isclass):
assert isinstance(obj, Type)
if (
is_registered_component(obj)
and get_component_name(obj) == component_registry_key
is_registered_component_type(obj)
and get_component_type_name(obj) == component_registry_key
):
return obj

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from dagster._core.definitions.definitions_class import Definitions
from dagster_components.core.component import (
ComponentTypeRegistry,
get_registered_components_in_module,
get_registered_component_types_in_module,
)
from dagster_components.core.component_defs_builder import (
build_defs_from_component_path,
get_component_name,
get_component_type_name,
)
from dagster_components.core.deployment import CodeLocationProjectContext

Expand All @@ -27,8 +27,8 @@ def load_test_component_project_context() -> CodeLocationProjectContext:
dc_module = importlib.import_module(package_name)

components = {}
for component in get_registered_components_in_module(dc_module):
key = f"dagster_components.{get_component_name(component)}"
for component in get_registered_component_types_in_module(dc_module):
key = f"dagster_components.{get_component_type_name(component)}"
components[key] = component

return CodeLocationProjectContext(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
from dagster_components import Component, component_type
from dagster_components.core.component import get_component_name, is_registered_component
from dagster_components.core.component import get_component_type_name, is_registered_component_type


def test_registered_component_with_default_name() -> None:
@component_type
class RegisteredComponent(Component): ...

assert is_registered_component(RegisteredComponent)
assert get_component_name(RegisteredComponent) == "registered_component"
assert is_registered_component_type(RegisteredComponent)
assert get_component_type_name(RegisteredComponent) == "registered_component"


def test_registered_component_with_default_name_and_parens() -> None:
@component_type()
class RegisteredComponent(Component): ...

assert is_registered_component(RegisteredComponent)
assert get_component_name(RegisteredComponent) == "registered_component"
assert is_registered_component_type(RegisteredComponent)
assert get_component_type_name(RegisteredComponent) == "registered_component"


def test_registered_component_with_explicit_kwarg_name() -> None:
@component_type(name="explicit_name")
class RegisteredComponent(Component): ...

assert is_registered_component(RegisteredComponent)
assert get_component_name(RegisteredComponent) == "explicit_name"
assert is_registered_component_type(RegisteredComponent)
assert get_component_type_name(RegisteredComponent) == "explicit_name"


def test_unregistered_component() -> None:
class UnregisteredComponent(Component): ...

assert not is_registered_component(UnregisteredComponent)
assert not is_registered_component_type(UnregisteredComponent)

0 comments on commit 1ef7294

Please sign in to comment.