Skip to content

Commit

Permalink
cache k8s container context
Browse files Browse the repository at this point in the history
  • Loading branch information
abdh committed Oct 30, 2024
1 parent b5d04dd commit d2a91da
Show file tree
Hide file tree
Showing 2 changed files with 172 additions and 53 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from typing import Any, Mapping, NamedTuple, Optional, TypeVar, cast
from typing import Any, Mapping, Optional, Union, cast

from dagster import (
Field,
IOManager,
JsonMetadataValue,
StepRunRef,
TypeCheck,
_check as check,
executor,
usable_as_dagster_type,
)
from dagster._core.definitions.executor_definition import multiple_process_executor_requirements
from dagster._core.events import EngineEventData
from dagster._core.execution.context.system import StepOrchestrationContext
from dagster._core.execution.plan.external_step import step_run_ref_to_step_context
from dagster._core.execution.plan.inputs import (
Expand All @@ -26,7 +28,7 @@

from dagster_k8s.container_context import K8sContainerContext
from dagster_k8s.executor import _K8S_EXECUTOR_CONFIG_SCHEMA, K8sStepHandler
from dagster_k8s.job import UserDefinedDagsterK8sConfig
from dagster_k8s.job import USER_DEFINED_K8S_CONFIG_KEY, UserDefinedDagsterK8sConfig
from dagster_k8s.launcher import K8sRunLauncher

_K8S_OP_EXECUTOR_CONFIG_SCHEMA = merge_dicts(
Expand All @@ -40,27 +42,25 @@
},
)

T = TypeVar("T")


@usable_as_dagster_type
class K8sOpMutatingOutput:
k8s_config: UserDefinedDagsterK8sConfig
value: Any

def __init__(
self, k8s_config: UserDefinedDagsterK8sConfig | Mapping[str, Any], value: Any = None
self, k8s_config: Union[UserDefinedDagsterK8sConfig, Mapping[str, Any]], value: Any = None
):
if isinstance(k8s_config, dict):
check.mapping_param(k8s_config, "k8s_config", str)
k8s_config = UserDefinedDagsterK8sConfig.from_dict(k8s_config)
else:
check.inst_param(k8s_config, "k8s_config", UserDefinedDagsterK8sConfig)
self.k8s_config = k8s_config
self.k8s_config = cast(UserDefinedDagsterK8sConfig, k8s_config)
self.value = value

@classmethod
def type_check(cls, _context, value):
def type_check(cls, _context, value: "K8sOpMutatingOutput"):
_ = _context
if not isinstance(value, cls):
return TypeCheck(
Expand All @@ -69,63 +69,61 @@ def type_check(cls, _context, value):
return TypeCheck(success=True)


class StepInputCacheKey(NamedTuple):
run_id: str
step_key: str
input_name: str


class K8sMutatingStepHandler(K8sStepHandler):
"""Specialized step handler that configure the next op based on the op metadata of the prior op."""

op_mutation_enabled: bool
step_input_cache: dict[StepInputCacheKey, UserDefinedDagsterK8sConfig]
# cache for container contexts
container_ctx_cache: dict[str, K8sContainerContext]
# consider adding a step output cache. Helps for reused op outputs, but could potentially
# store outputs that are never used again. Maybe LRU cache?
# step_output_cache: OrderedDict[StepInputCacheKey, UserDefinedDagsterK8sConfig]

def __init__(
self,
image: str | None,
image: Optional[str],
container_context: K8sContainerContext,
load_incluster_config: bool,
kubeconfig_file: str | None,
kubeconfig_file: Optional[str],
k8s_client_batch_api=None,
op_mutation_enabled: bool = False,
):
self.op_mutation_enabled = op_mutation_enabled
self.step_input_cache = {}
self.container_ctx_cache = {}
super().__init__(
image, container_context, load_incluster_config, kubeconfig_file, k8s_client_batch_api
)

@property
def name(self) -> str:
return "K8sMutatingStepHandler"

def _get_input_metadata(
self, step_input: StepInput, step_context: StepOrchestrationContext
) -> Optional[UserDefinedDagsterK8sConfig]:
# dagster type names should be garunteed unique, so this should be safe
if step_input.dagster_type_key != K8sOpMutatingOutput.__name__:
return {}
step_cache_key = StepInputCacheKey(
step_context.run_id, step_context.step.key, step_input.name
)
if step_cache_key in self.step_input_cache:
step_context.log.debug(f"cache hit for key {step_cache_key}")
return self.step_input_cache[step_cache_key]
return None
source = step_input.source
if isinstance(source, FromPendingDynamicStepOutput):
upstream_output_handle = source.step_output_handle
# coerce FromPendingDynamicStepOutput to FromStepOutputSource
assert isinstance(upstream_output_handle.mapping_key, str)
source = source.resolve(upstream_output_handle.mapping_key)
if not isinstance(source, FromStepOutput):
return step_context.log.error(
f"unable to consume {step_input.name}. FromStepOuput & FromPendingDynamicStepOutput sources supported, got {type(source)}"
)
if source.fan_in:
# this should never happen, but incase it does, log it
return step_context.log.error("fan in step input not supported")
job_def = step_context.job.get_definition()
# lie, cheat, and steal. Create step execution context ahead of time
step_run_ref = StepRunRef(
run_config=job_def.run_config or {},
dagster_run=step_context.dagster_run,
run_id=step_context.run_id,
step_key=step_context.step.key,
step_key=source.step_output_handle.step_key,
retry_mode=step_context.retry_mode,
recon_job=step_context.reconstructable_job,
known_state=step_context.execution_plan.known_state,
Expand Down Expand Up @@ -160,16 +158,29 @@ def _get_input_metadata(
# load full input via input manager. DANGER, if value is excessively large this could have perf impacts on stephandler
input_context = source.get_load_context(step_exec_context, input_def, manager_key)
op_mutating_output: K8sOpMutatingOutput = input_manager.load_input(input_context)
step_context.log.debug(f"eagerly recieved input {op_mutating_output}")
self.step_input_cache[step_cache_key] = op_mutating_output.k8s_config
return self.step_input_cache[step_cache_key]
step_context.instance.report_engine_event(
f'eagerly using input "{input_def.name}" from upstream op "{source.step_output_handle.step_key}" as config input',
step_context.dagster_run,
EngineEventData(
metadata={
USER_DEFINED_K8S_CONFIG_KEY: JsonMetadataValue(
data=op_mutating_output.k8s_config.to_dict()
)
}
),
step_key=step_context.step.key,
run_id=step_context.run_id,
)
return op_mutating_output.k8s_config

def _resolve_input_configs(
self, step_handler_context: StepHandlerContext
) -> Optional[K8sContainerContext]:
"""Fetch all the configured k8s metadata for op inputs."""
step_key = self._get_step_key(step_handler_context)
step_context = step_handler_context.get_step_context(step_key)
step_context = cast(
StepOrchestrationContext, step_handler_context.get_step_context(step_key)
)
container_context = None
for step_input in step_context.step.step_inputs:
op_metadata_config = self._get_input_metadata(step_input, step_context)
Expand All @@ -185,21 +196,30 @@ def _resolve_input_configs(
def _get_container_context(
self, step_handler_context: StepHandlerContext
) -> K8sContainerContext:
# function should be safe to cache since it's idempotent
step_key = self._get_step_key(step_handler_context)
if step_key in self.container_ctx_cache:
return self.container_ctx_cache[step_key]
context = super()._get_container_context(step_handler_context)
self.container_ctx_cache[step_key] = context
if not self.op_mutation_enabled:
return context
return self.container_ctx_cache[step_key]
# only use cache when op mutation is enabled. Fallback to K8sStephandler otherwise.
merged_input_k8s_configs = self._resolve_input_configs(step_handler_context)
if merged_input_k8s_configs:
return context.merge(merged_input_k8s_configs)
return context
self.container_ctx_cache[step_key] = self.container_ctx_cache[step_key].merge(
merged_input_k8s_configs
)
return self.container_ctx_cache[step_key]

def terminate_step(self, step_handler_context):
yield from super().terminate_step(step_handler_context)
# pop cache to save mem for steps we won't visit again
step_key = self._get_step_key(step_handler_context)
stale_cache_keys = [k for k in self.step_input_cache if k.step_key == step_key]
for k in stale_cache_keys:
self.step_input_cache.pop(k)
def terminate_step(self, step_handler_context: StepHandlerContext):
try:
yield from super().terminate_step(step_handler_context)
finally:
# pop cache to save mem for steps we won't visit again
step_key = self._get_step_key(step_handler_context)
# entry might not exist if op mutation is disabled
self.container_ctx_cache.pop(step_key, None)


@executor(
Expand Down Expand Up @@ -237,8 +257,7 @@ def k8s_op_mutating_executor(init_context: InitExecutorContext) -> Executor:
kubeconfig_file = cast(Optional[str], exc_cfg["kubeconfig_file"])
else:
kubeconfig_file = run_launcher.kubeconfig_file if run_launcher else None
op_mutation_enabled = exc_cfg.get("op_mutation_enabled", False)
check.bool_param(op_mutation_enabled, "op_mutation_enabled")
op_mutation_enabled = check.bool_elem(exc_cfg, "op_mutation_enabled")
return StepDelegatingExecutor(
K8sMutatingStepHandler(
image=exc_cfg.get("job_image"), # type: ignore
Expand Down
Loading

0 comments on commit d2a91da

Please sign in to comment.