Skip to content

Commit

Permalink
[lazy-defs] [RFC] global DefinitionsLoadContext (#24566)
Browse files Browse the repository at this point in the history
## Summary & Motivation

Variant of reconstruction metadata API where, instead of definitions
being "lazily" defined in a function, they continue to be defined using
a standard `Definitions` instantiation. A `DefinitionsLoadContext`
instance is made available via `DefinitionsLoadContext.get()`; this
instance is set prior to loading a repository.

The upshot is that users do not have to change their entry point code.
Integrations can invoke `DefinitionsLoadContext.get()` to access the
context without the user having to pass it in.

## How I Tested These Changes

New unit tests.

## Changelog

NOCHANGELOG

- [ ] `NEW` _(added new feature or capability)_
- [ ] `BUGFIX` _(fixed a bug)_
- [ ] `DOCS` _(added or updated documentation)_
  • Loading branch information
smackesey authored Sep 19, 2024
1 parent a48b3e6 commit 78eaca0
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from enum import Enum
from functools import cached_property
from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Mapping, Optional, Union

from typing_extensions import TypeAlias

Expand Down Expand Up @@ -45,6 +45,8 @@ class DefinitionsLoadContext:
User construction of this object is not supported.
"""

_instance: ClassVar[Optional["DefinitionsLoadContext"]] = None

def __init__(
self,
load_type: DefinitionsLoadType,
Expand All @@ -53,6 +55,20 @@ def __init__(
self._load_type = load_type
self._repository_load_data = repository_load_data

@classmethod
def get(cls) -> "DefinitionsLoadContext":
    """Get the current global DefinitionsLoadContext.

    Returns:
        DefinitionsLoadContext: The instance most recently stored via ``set``.

    Raises:
        DagsterInvariantViolationError: If no context has been set yet.
    """
    # Compare against None explicitly instead of relying on truthiness, and read
    # through `cls` so the lookup mirrors the attribute that `set` writes.
    instance = cls._instance
    if instance is None:
        raise DagsterInvariantViolationError(
            "Attempted to access the global DefinitionsLoadContext before it has been set."
        )
    return instance

@classmethod
def set(cls, instance: "DefinitionsLoadContext") -> None:
    """Set the current global DefinitionsLoadContext.

    Args:
        instance (DefinitionsLoadContext): The context that subsequent calls to ``get``
            should return.
    """
    cls._instance = instance

@property
def load_type(self) -> DefinitionsLoadType:
"""DefinitionsLoadType: Classifier for scenario in which Definitions are being loaded."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,10 @@ def repository_def_from_target_def(
)
from dagster._core.definitions.source_asset import SourceAsset

DefinitionsLoadContext.set(
DefinitionsLoadContext(load_type=load_type, repository_load_data=repository_load_data)
)

# DefinitionsLoader will always return Definitions
if isinstance(target, DefinitionsLoader):
context = (
Expand Down Expand Up @@ -773,6 +777,11 @@ def repository_def_from_pointer(
load_type: "DefinitionsLoadType",
repository_load_data: Optional["RepositoryLoadData"] = None,
) -> "RepositoryDefinition":
from dagster._core.definitions.definitions_loader import DefinitionsLoadContext

DefinitionsLoadContext.set(
DefinitionsLoadContext(load_type=load_type, repository_load_data=repository_load_data)
)
target = def_from_pointer(pointer)
repo_def = repository_def_from_target_def(target, load_type, repository_load_data)
if not repo_def:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from dagster._core.definitions.asset_spec import AssetSpec
from dagster._core.definitions.decorators.asset_decorator import asset
from dagster._core.definitions.definitions_class import Definitions
from dagster._core.definitions.definitions_loader import DefinitionsLoadContext, DefinitionsLoadType
from dagster._core.definitions.external_asset import external_assets_from_specs

from dagster_tests.definitions_tests.test_definitions_loader import fetch_foo_integration_asset_info

# Prefix of the reconstruction-metadata key built in `_get_foo_integration_defs`.
FOO_INTEGRATION_SOURCE_KEY = "foo_integration"

# Workspace identifier passed to `_get_foo_integration_defs` when building `defs`.
WORKSPACE_ID = "my_workspace"


# This function would be provided by integration lib dagster-foo
def _get_foo_integration_defs(workspace_id: str) -> Definitions:
    """Build the Definitions for one foo-integration workspace.

    The global ``DefinitionsLoadContext`` is consulted first: when the load type is
    ``RECONSTRUCTION`` and the context already holds metadata under this workspace's
    key, that payload is reused. Otherwise the payload is obtained from
    ``fetch_foo_integration_asset_info``. In both cases the payload is attached to the
    returned Definitions as reconstruction metadata under the same key.

    Args:
        workspace_id (str): Identifier of the workspace to build definitions for.

    Returns:
        Definitions: External assets built from the payload, carrying the payload as
            reconstruction metadata.
    """
    load_context = DefinitionsLoadContext.get()
    metadata_key = f"{FOO_INTEGRATION_SOURCE_KEY}/{workspace_id}"

    # `reconstruction_metadata` is only consulted once the load type has been confirmed.
    has_cached_payload = (
        load_context.load_type == DefinitionsLoadType.RECONSTRUCTION
        and metadata_key in load_context.reconstruction_metadata
    )
    if has_cached_payload:
        payload = load_context.reconstruction_metadata[metadata_key]
    else:
        payload = fetch_foo_integration_asset_info(workspace_id)

    external_assets = external_assets_from_specs([AssetSpec(entry["id"]) for entry in payload])
    integration_defs = Definitions(assets=external_assets)
    return integration_defs.with_reconstruction_metadata({metadata_key: payload})


# Placeholder asset with an empty body; it is merged into `defs` next to the
# definitions returned by `_get_foo_integration_defs`.
@asset
def regular_asset(): ...


# Combined definitions for this module: the result of `_get_foo_integration_defs`
# for WORKSPACE_ID merged with a `Definitions` holding `regular_asset`.
defs = Definitions.merge(
_get_foo_integration_defs(WORKSPACE_ID),
Definitions(assets=[regular_asset]),
)
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dagster._core.definitions.definitions_class import Definitions
from dagster._core.definitions.definitions_loader import DefinitionsLoadContext, DefinitionsLoadType
from dagster._core.definitions.external_asset import external_assets_from_specs
from dagster._core.definitions.metadata.metadata_value import MetadataValue
from dagster._core.definitions.reconstruct import (
ReconstructableJob,
ReconstructableRepository,
Expand All @@ -25,13 +26,14 @@
from dagster._core.errors import DagsterInvalidInvocationError
from dagster._core.execution.api import execute_job
from dagster._core.instance_for_test import instance_for_test
from dagster._utils import file_relative_path

FOO_INTEGRATION_SOURCE_KEY = "foo_integration"

WORKSPACE_ID = "my_workspace"


def _fetch_foo_integration_asset_info(workspace_id: str):
def fetch_foo_integration_asset_info(workspace_id: str):
if workspace_id == WORKSPACE_ID:
return [{"id": "alpha"}, {"id": "beta"}]
else:
Expand All @@ -47,7 +49,7 @@ def _get_foo_integration_defs(context: DefinitionsLoadContext, workspace_id: str
):
payload = context.reconstruction_metadata[cache_key]
else:
payload = _fetch_foo_integration_asset_info(workspace_id)
payload = fetch_foo_integration_asset_info(workspace_id)
asset_specs = [AssetSpec(item["id"]) for item in payload]
assets = external_assets_from_specs(asset_specs)
return Definitions(
Expand Down Expand Up @@ -92,20 +94,45 @@ def test_reconstruction_metadata():
repo_load_data = RepositoryLoadData(
cacheable_asset_data={},
reconstruction_metadata={
f"{FOO_INTEGRATION_SOURCE_KEY}/{WORKSPACE_ID}": _fetch_foo_integration_asset_info(
WORKSPACE_ID
f"{FOO_INTEGRATION_SOURCE_KEY}/{WORKSPACE_ID}": MetadataValue.code_location_reconstruction(
fetch_foo_integration_asset_info(WORKSPACE_ID)
)
},
)

# Ensure we don't call the expensive fetch function when we have the data cached
with patch(
"dagster_tests.definitions_tests.test_definitions_loader._fetch_foo_integration_asset_info"
"dagster_tests.definitions_tests.test_definitions_loader.fetch_foo_integration_asset_info"
) as mock_fetch:
inner_repo.reconstruct_repository_definition(repository_load_data=repo_load_data)
mock_fetch.assert_not_called()


def test_reconstruction_metadata_with_global_context():
    """Cached reconstruction metadata must prevent a re-fetch when defs use the global context."""
    recon_repo = ReconstructableRepository.for_file(
        file_relative_path(__file__, "metadata_defs_global_context.py"), "defs"
    )
    assert isinstance(recon_repo.get_definition(), RepositoryDefinition)

    # Compute the payload with the real fetch function before it is patched below.
    metadata_key = f"{FOO_INTEGRATION_SOURCE_KEY}/{WORKSPACE_ID}"
    cached_payload = MetadataValue.code_location_reconstruction(
        fetch_foo_integration_asset_info(WORKSPACE_ID)
    )
    load_data = RepositoryLoadData(
        cacheable_asset_data={},
        reconstruction_metadata={metadata_key: cached_payload},
    )
    cached_recon_repo = recon_repo.with_repository_load_data(load_data)

    # Ensure we don't call the expensive fetch function when we have the data cached
    with patch(
        "dagster_tests.definitions_tests.test_definitions_loader.fetch_foo_integration_asset_info"
    ) as fetch_mock:
        cached_recon_repo.get_definition()
        fetch_mock.assert_not_called()


def test_invoke_definitions_loader_with_context():
@definitions
def defs(context: DefinitionsLoadContext) -> Definitions:
Expand Down

0 comments on commit 78eaca0

Please sign in to comment.