From d2a91daff0346265a3b8fbf9e3b1900ec1e0098d Mon Sep 17 00:00:00 2001
From: abdh
Date: Wed, 30 Oct 2024 14:12:07 -0700
Subject: [PATCH] cache k8s container context

---
 .../dagster_k8s/op_mutating_executor.py       | 101 ++++++++------
 .../unit_tests/test_mutating_executor.py      | 124 ++++++++++++++++--
 2 files changed, 172 insertions(+), 53 deletions(-)

diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/op_mutating_executor.py b/python_modules/libraries/dagster-k8s/dagster_k8s/op_mutating_executor.py
index d68a18fd38f80..f2b61833294ee 100644
--- a/python_modules/libraries/dagster-k8s/dagster_k8s/op_mutating_executor.py
+++ b/python_modules/libraries/dagster-k8s/dagster_k8s/op_mutating_executor.py
@@ -1,8 +1,9 @@
-from typing import Any, Mapping, NamedTuple, Optional, TypeVar, cast
+from typing import Any, Mapping, Optional, Union, cast
 
 from dagster import (
     Field,
     IOManager,
+    JsonMetadataValue,
     StepRunRef,
     TypeCheck,
     _check as check,
@@ -10,6 +11,7 @@
     usable_as_dagster_type,
 )
 from dagster._core.definitions.executor_definition import multiple_process_executor_requirements
+from dagster._core.events import EngineEventData
 from dagster._core.execution.context.system import StepOrchestrationContext
 from dagster._core.execution.plan.external_step import step_run_ref_to_step_context
 from dagster._core.execution.plan.inputs import (
@@ -26,7 +28,7 @@
 
 from dagster_k8s.container_context import K8sContainerContext
 from dagster_k8s.executor import _K8S_EXECUTOR_CONFIG_SCHEMA, K8sStepHandler
-from dagster_k8s.job import UserDefinedDagsterK8sConfig
+from dagster_k8s.job import USER_DEFINED_K8S_CONFIG_KEY, UserDefinedDagsterK8sConfig
 from dagster_k8s.launcher import K8sRunLauncher
 
 _K8S_OP_EXECUTOR_CONFIG_SCHEMA = merge_dicts(
@@ -40,8 +42,6 @@
     },
 )
 
-T = TypeVar("T")
-
 
 @usable_as_dagster_type
 class K8sOpMutatingOutput:
@@ -49,18 +49,18 @@ class K8sOpMutatingOutput:
     value: Any
 
     def __init__(
-        self, k8s_config: UserDefinedDagsterK8sConfig | Mapping[str, Any], value: Any = None
+        self, k8s_config: Union[UserDefinedDagsterK8sConfig, Mapping[str, Any]], value: Any = None
     ):
         if isinstance(k8s_config, dict):
             check.mapping_param(k8s_config, "k8s_config", str)
             k8s_config = UserDefinedDagsterK8sConfig.from_dict(k8s_config)
         else:
             check.inst_param(k8s_config, "k8s_config", UserDefinedDagsterK8sConfig)
-        self.k8s_config = k8s_config
+        self.k8s_config = cast(UserDefinedDagsterK8sConfig, k8s_config)
         self.value = value
 
     @classmethod
-    def type_check(cls, _context, value):
+    def type_check(cls, _context, value: "K8sOpMutatingOutput"):
         _ = _context
         if not isinstance(value, cls):
             return TypeCheck(
@@ -69,55 +69,53 @@ def type_check(cls, _context, value):
         return TypeCheck(success=True)
 
 
-class StepInputCacheKey(NamedTuple):
-    run_id: str
-    step_key: str
-    input_name: str
-
-
 class K8sMutatingStepHandler(K8sStepHandler):
-    """Specialized step handler that configure the next op based on the op metadata of the prior op."""
+    """Specialized step handler that configures the next op based on the op metadata of the prior op."""
 
     op_mutation_enabled: bool
-    step_input_cache: dict[StepInputCacheKey, UserDefinedDagsterK8sConfig]
+    # cache for container contexts
+    container_ctx_cache: dict[str, K8sContainerContext]
+    # consider adding a step output cache. Helps for reused op outputs, but could potentially
+    # store outputs that are never used again. Maybe LRU cache?
+    # step_output_cache: OrderedDict[StepInputCacheKey, UserDefinedDagsterK8sConfig]
 
     def __init__(
         self,
-        image: str | None,
+        image: Optional[str],
        container_context: K8sContainerContext,
         load_incluster_config: bool,
-        kubeconfig_file: str | None,
+        kubeconfig_file: Optional[str],
         k8s_client_batch_api=None,
         op_mutation_enabled: bool = False,
     ):
         self.op_mutation_enabled = op_mutation_enabled
-        self.step_input_cache = {}
+        self.container_ctx_cache = {}
         super().__init__(
             image, container_context, load_incluster_config, kubeconfig_file, k8s_client_batch_api
         )
 
+    @property
+    def name(self) -> str:
+        return "K8sMutatingStepHandler"
+
     def _get_input_metadata(
         self, step_input: StepInput, step_context: StepOrchestrationContext
     ) -> Optional[UserDefinedDagsterK8sConfig]:
-        # dagster type names should be garunteed unique, so this should be safe
+        # dagster type names should be guaranteed unique, so this should be safe
         if step_input.dagster_type_key != K8sOpMutatingOutput.__name__:
-            return {}
-        step_cache_key = StepInputCacheKey(
-            step_context.run_id, step_context.step.key, step_input.name
-        )
-        if step_cache_key in self.step_input_cache:
-            step_context.log.debug(f"cache hit for key {step_cache_key}")
-            return self.step_input_cache[step_cache_key]
+            return None
         source = step_input.source
         if isinstance(source, FromPendingDynamicStepOutput):
             upstream_output_handle = source.step_output_handle
             # coerce FromPendingDynamicStepOutput to FromStepOutputSource
+            assert isinstance(upstream_output_handle.mapping_key, str)
             source = source.resolve(upstream_output_handle.mapping_key)
         if not isinstance(source, FromStepOutput):
             return step_context.log.error(
-                f"unable to consume {step_input.name}. FromStepOuput & FromPendingDynamicStepOutput sources supported, got {type(source)}"
+                f"unable to consume {step_input.name}. FromStepOutput & FromPendingDynamicStepOutput sources supported, got {type(source)}"
             )
         if source.fan_in:
+            # this should never happen, but in case it does, log it
            return step_context.log.error("fan in step input not supported")
         job_def = step_context.job.get_definition()
         # lie, cheat, and steal. Create step execution context ahead of time
@@ -125,7 +123,7 @@ def _get_input_metadata(
             run_config=job_def.run_config or {},
             dagster_run=step_context.dagster_run,
             run_id=step_context.run_id,
-            step_key=step_context.step.key,
+            step_key=source.step_output_handle.step_key,
             retry_mode=step_context.retry_mode,
             recon_job=step_context.reconstructable_job,
             known_state=step_context.execution_plan.known_state,
@@ -160,16 +158,29 @@ def _get_input_metadata(
-        # load full input via input manager. DANGER, if value is excessively large this could have perf impacts on stephandler
+        # load full input via input manager. DANGER, if the value is excessively large this could have perf impacts on the step handler
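+        # note: load_input runs synchronously in the step handler (orchestrator)
+        # process, so the upstream value is deserialized here rather than in the step's pod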
         input_context = source.get_load_context(step_exec_context, input_def, manager_key)
         op_mutating_output: K8sOpMutatingOutput = input_manager.load_input(input_context)
-        step_context.log.debug(f"eagerly recieved input {op_mutating_output}")
-        self.step_input_cache[step_cache_key] = op_mutating_output.k8s_config
-        return self.step_input_cache[step_cache_key]
+        step_context.instance.report_engine_event(
+            f'eagerly using input "{input_def.name}" from upstream op "{source.step_output_handle.step_key}" as config input',
+            step_context.dagster_run,
+            EngineEventData(
+                metadata={
+                    USER_DEFINED_K8S_CONFIG_KEY: JsonMetadataValue(
+                        data=op_mutating_output.k8s_config.to_dict()
+                    )
+                }
+            ),
+            step_key=step_context.step.key,
+            run_id=step_context.run_id,
+        )
+        return op_mutating_output.k8s_config
 
     def _resolve_input_configs(
         self, step_handler_context: StepHandlerContext
     ) -> Optional[K8sContainerContext]:
         """Fetch all the configured k8s metadata for op inputs."""
         step_key = self._get_step_key(step_handler_context)
-        step_context = step_handler_context.get_step_context(step_key)
+        step_context = cast(
+            StepOrchestrationContext, step_handler_context.get_step_context(step_key)
+        )
         container_context = None
         for step_input in step_context.step.step_inputs:
             op_metadata_config = self._get_input_metadata(step_input, step_context)
@@ -185,21 +196,30 @@ def _get_container_context(
         self, step_handler_context: StepHandlerContext
     ) -> K8sContainerContext:
+        # function should be safe to cache since it's idempotent
+        step_key = self._get_step_key(step_handler_context)
+        if step_key in self.container_ctx_cache:
+            return self.container_ctx_cache[step_key]
         context = super()._get_container_context(step_handler_context)
+        self.container_ctx_cache[step_key] = context
         if not self.op_mutation_enabled:
-            return context
+            return self.container_ctx_cache[step_key]
+        # only merge input configs when op mutation is enabled. Fall back to K8sStepHandler behavior otherwise.
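+        # runtime config from upstream outputs is merged over the cached base context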
         merged_input_k8s_configs = self._resolve_input_configs(step_handler_context)
         if merged_input_k8s_configs:
-            return context.merge(merged_input_k8s_configs)
-        return context
+            self.container_ctx_cache[step_key] = self.container_ctx_cache[step_key].merge(
+                merged_input_k8s_configs
+            )
+        return self.container_ctx_cache[step_key]
 
-    def terminate_step(self, step_handler_context):
-        yield from super().terminate_step(step_handler_context)
-        # pop cache to save mem for steps we won't visit again
-        step_key = self._get_step_key(step_handler_context)
-        stale_cache_keys = [k for k in self.step_input_cache if k.step_key == step_key]
-        for k in stale_cache_keys:
-            self.step_input_cache.pop(k)
+    def terminate_step(self, step_handler_context: StepHandlerContext):
+        try:
+            yield from super().terminate_step(step_handler_context)
+        finally:
+            # pop cache to save memory for steps we won't visit again
+            step_key = self._get_step_key(step_handler_context)
+            # entry might not exist if op mutation is disabled
+            self.container_ctx_cache.pop(step_key, None)
 
 
 @executor(
@@ -237,8 +257,7 @@ def k8s_op_mutating_executor(init_context: InitExecutorContext) -> Executor:
         kubeconfig_file = cast(Optional[str], exc_cfg["kubeconfig_file"])
     else:
         kubeconfig_file = run_launcher.kubeconfig_file if run_launcher else None
-    op_mutation_enabled = exc_cfg.get("op_mutation_enabled", False)
-    check.bool_param(op_mutation_enabled, "op_mutation_enabled")
+    op_mutation_enabled = check.bool_elem(exc_cfg, "op_mutation_enabled")
     return StepDelegatingExecutor(
         K8sMutatingStepHandler(
             image=exc_cfg.get("job_image"),  # type: ignore
diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_mutating_executor.py b/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_mutating_executor.py
index 61b2d704d5e68..0b6f13aab4387 100644
--- a/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_mutating_executor.py
+++ b/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_mutating_executor.py
@@ -1,14 +1,15 @@
-from typing import List
+from typing import List, Sequence, cast
 from unittest import mock
 
 from dagster import (
     DagsterInstance,
     DagsterRun,
+    DynamicOut,
+    DynamicOutput,
     Executor,
     InitExecutorContext,
     InMemoryIOManager,
     IOManagerDefinition,
-    JobDefinition,
     OpExecutionContext,
     StepExecutionContext,
     io_manager,
@@ -21,6 +22,8 @@
 from dagster._core.execution.api import create_execution_plan
 from dagster._core.execution.context.system import PlanData, PlanOrchestrationContext
 from dagster._core.execution.context_creation_job import create_context_free_log_manager
+from dagster._core.execution.plan.state import KnownExecutionState
+from dagster._core.execution.plan.step import ExecutionStep
 from dagster._core.execution.retries import RetryMode
 from dagster._core.executor.step_delegating import StepHandlerContext
 from dagster._grpc.types import ExecuteStepArgs
@@ -52,20 +55,51 @@ def producer() -> K8sOpMutatingOutput:
         return MOCK_K8s_OUTPUT
 
     @op
-    def sink(context: OpExecutionContext, producer: K8sOpMutatingOutput):
-        context.log.info(f"received the following input: {producer}")
+    def sink(context: OpExecutionContext, producer: K8sOpMutatingOutput) -> KnownExecutionState:
+        _ = producer
+        return context.get_step_execution_context().get_known_state()
 
     sink(producer())
 
 
+@job
+def dynamic_producer_consumer():
+    @op(out=DynamicOut(K8sOpMutatingOutput))
+    def dyn_producer():
+        for i in [3, 4]:
+            k8s_out = K8sOpMutatingOutput(
+                {
+                    "container_config": {
+                        "resources": {
"requests": {"cpu": f"{i}234m", "memory": f"{i}Gi"}, + "limits": {"cpu": f"{i}765m", "memory": f"{i}Gi"}, + } + } + } + ) + yield DynamicOutput(k8s_out, str(i)) + + @op + def dyn_sink(context: OpExecutionContext, producer: K8sOpMutatingOutput) -> KnownExecutionState: + context.log.info(f"received the following input: {producer}") + # hacky way for me to get known context -> InMemIOManager -> StepOrchestrationContext + # for step handler testing + return context.get_step_execution_context().get_known_state() + + dyn_producer().map(dyn_sink).collect() + + def _fetch_step_handler_context( job_def: ReconstructableJob, dagster_run: DagsterRun, instance: DagsterInstance, executor: Executor, steps: List[str], + known_state=None, ): - execution_plan = create_execution_plan(job_def) + execution_plan = create_execution_plan( + job_def, known_state=known_state, instance_ref=instance.get_ref() + ) log_manager = create_context_free_log_manager(instance, dagster_run) plan_context = PlanOrchestrationContext( @@ -92,14 +126,12 @@ def _fetch_step_handler_context( return StepHandlerContext( instance=instance, plan_context=plan_context, - steps=execution_plan.steps, + steps=cast(Sequence[ExecutionStep], execution_plan.steps), execute_step_args=execute_step_args, ) -def _fetch_mutating_executor( - instance: DagsterInstance, job_def: JobDefinition, executor_config=None -): +def _fetch_mutating_executor(instance, job_def, executor_config=None): process_result = process_config( resolve_to_config_type(_K8S_OP_EXECUTOR_CONFIG_SCHEMA), executor_config or {}, @@ -166,11 +198,76 @@ def test_mutating_step_handler_runtime_override( == MOCK_RUNTIME_RESOURCE_CONF["resources"] ) # check cache state - assert len(handler.step_input_cache) == 1 + assert len(handler.container_ctx_cache) == 1 # no need to stub execution context again as the step_input_cache should now hit. If not below will fail. list(handler.terminate_step(step_handler_ctx)) # ensure clean events are handled - assert not handler.step_input_cache + assert not handler.container_ctx_cache + + +def test_mutating_step_handler_dynamic_runtime_override( + k8s_instance: DagsterInstance, kubeconfig_file: str +): + """Using the `dynamic_producer_consumer` job, validate container context changes with respect to runtime dynamic outputs. + + We do this by executing the job in memory as normal, we then construct a StepOrchestration context by pulling + various state from the in memory job execution and reconstructing the orchestration context by hand. + The constructed orchestration context should be representative of what the step orchestration context and known state + would be when the actual job was ran. 
+ """ + mock_k8s_client_batch_api = mock.MagicMock() + run_id = "de07af8f-d5f4-4a43-b545-132c3310999d" + shared_mem_io_manager = InMemoryIOManager() + io_manager_def = _make_shared_mem_io_manager(shared_mem_io_manager) + result = dynamic_producer_consumer.execute_in_process( + instance=k8s_instance, run_id=run_id, resources={"io_manager": io_manager_def} + ) + assert result.success + # for each mapping output, check the container context is propagated properly + for i in [3, 4]: + dyn_known_state = result.output_for_node("dyn_sink")[str(i)] + recon_job = reconstructable(dynamic_producer_consumer) + executor = _fetch_mutating_executor(k8s_instance, recon_job) + step_handler_ctx = _fetch_step_handler_context( + recon_job, + result.dagster_run, + k8s_instance, + executor, + [f"dyn_sink[{i}]"], + dyn_known_state, + ) + handler = K8sMutatingStepHandler( + image="bizbuz", + container_context=K8sContainerContext( + namespace="foo", + resources={ + "requests": {"cpu": "128m", "memory": "64Mi"}, + "limits": {"cpu": "500m", "memory": "1000Mi"}, + }, + ), + load_incluster_config=False, + kubeconfig_file=kubeconfig_file, + k8s_client_batch_api=mock_k8s_client_batch_api, + op_mutation_enabled=True, + ) + # stub api client + handler._api_client = mock.MagicMock() # noqa: SLF001 + with mock.patch.object( + StepExecutionContext, + "get_io_manager", + mock.MagicMock(return_value=shared_mem_io_manager), + ): + runtime_mutated_context = handler._get_container_context(step_handler_ctx) # noqa: SLF001 + assert runtime_mutated_context.run_k8s_config.container_config.get("resources") == { + "requests": {"cpu": f"{i}234m", "memory": f"{i}Gi"}, + "limits": {"cpu": f"{i}765m", "memory": f"{i}Gi"}, + } + # check cache state + assert len(handler.container_ctx_cache) == 1 + # no need to stub execution context again as the step_input_cache should now hit. If not below will fail. + list(handler.terminate_step(step_handler_ctx)) + # ensure clean events are handled + assert not handler.container_ctx_cache def test_mutating_step_handler_no_runtime_override(k8s_instance: DagsterInstance, kubeconfig_file): @@ -203,4 +300,7 @@ def test_mutating_step_handler_no_runtime_override(k8s_instance: DagsterInstance runtime_mutated_context.run_k8s_config.container_config.get("resources") == initial_resources ) - assert not handler.step_input_cache + handler._api_client = mock.Mock() # noqa: SLF001 + assert len(handler.container_ctx_cache) == 1 + list(handler.terminate_step(step_handler_ctx)) + assert not handler.container_ctx_cache