diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/op_mutating_executor.py b/python_modules/libraries/dagster-k8s/dagster_k8s/op_mutating_executor.py index db01f1881087e..d68a18fd38f80 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/op_mutating_executor.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/op_mutating_executor.py @@ -70,6 +70,7 @@ def type_check(cls, _context, value): class StepInputCacheKey(NamedTuple): + run_id: str step_key: str input_name: str @@ -101,7 +102,9 @@ def _get_input_metadata( # dagster type names should be garunteed unique, so this should be safe if step_input.dagster_type_key != K8sOpMutatingOutput.__name__: return {} - step_cache_key = StepInputCacheKey(step_context.step.key, step_input.name) + step_cache_key = StepInputCacheKey( + step_context.run_id, step_context.step.key, step_input.name + ) if step_cache_key in self.step_input_cache: step_context.log.debug(f"cache hit for key {step_cache_key}") return self.step_input_cache[step_cache_key] diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_mutating_executor.py b/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_mutating_executor.py index fd0cae990686a..61b2d704d5e68 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_mutating_executor.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_mutating_executor.py @@ -6,15 +6,18 @@ DagsterRun, Executor, InitExecutorContext, - IOManager, + InMemoryIOManager, + IOManagerDefinition, JobDefinition, OpExecutionContext, StepExecutionContext, + io_manager, job, op, reconstructable, ) from dagster._config import process_config, resolve_to_config_type +from dagster._core.definitions import ReconstructableJob from dagster._core.execution.api import create_execution_plan from dagster._core.execution.context.system import PlanData, PlanOrchestrationContext from dagster._core.execution.context_creation_job import create_context_free_log_manager @@ -42,13 +45,6 @@ MOCK_K8s_OUTPUT = K8sOpMutatingOutput({"container_config": MOCK_RUNTIME_RESOURCE_CONF}) -class MockIOManager(IOManager): - def load_input(self, context): - return MOCK_K8s_OUTPUT - - def handle_output(self, context, obj): ... - - @job def simple_producer_consumer(): @op @@ -63,7 +59,7 @@ def sink(context: OpExecutionContext, producer: K8sOpMutatingOutput): def _fetch_step_handler_context( - job_def: JobDefinition, + job_def: ReconstructableJob, dagster_run: DagsterRun, instance: DagsterInstance, executor: Executor, @@ -120,12 +116,25 @@ def _fetch_mutating_executor( ) +def _make_shared_mem_io_manager(inmem_io_manager: InMemoryIOManager) -> IOManagerDefinition: + @io_manager + def shared_mem_io_manager(_) -> InMemoryIOManager: + return inmem_io_manager + + return shared_mem_io_manager + + def test_mutating_step_handler_runtime_override( - k8s_instance: DagsterInstance, tmp_path, kubeconfig_file + k8s_instance: DagsterInstance, kubeconfig_file: str ): + """Using the `simple_producer_consumer` job, ensure that a simple output can be detected, eagerly loaded, and consumed as container context.""" mock_k8s_client_batch_api = mock.MagicMock() - # io_manager = mem_io_manager - result = simple_producer_consumer.execute_in_process(instance=k8s_instance) + run_id = "de07af8f-d5f4-4a43-b545-132c3310999d" + shared_mem_io_manager = InMemoryIOManager() + io_manager_def = _make_shared_mem_io_manager(shared_mem_io_manager) + result = simple_producer_consumer.execute_in_process( + instance=k8s_instance, run_id=run_id, resources={"io_manager": io_manager_def} + ) assert result.success recon_job = reconstructable(simple_producer_consumer) executor = _fetch_mutating_executor(k8s_instance, recon_job) @@ -146,21 +155,27 @@ def test_mutating_step_handler_runtime_override( k8s_client_batch_api=mock_k8s_client_batch_api, op_mutation_enabled=True, ) + # stub api client + handler._api_client = mock.MagicMock() # noqa: SLF001 with mock.patch.object( - StepExecutionContext, "get_io_manager", mock.MagicMock(return_value=MockIOManager()) + StepExecutionContext, "get_io_manager", mock.MagicMock(return_value=shared_mem_io_manager) ): runtime_mutated_context = handler._get_container_context(step_handler_ctx) # noqa: SLF001 assert ( runtime_mutated_context.run_k8s_config.container_config.get("resources") == MOCK_RUNTIME_RESOURCE_CONF["resources"] ) + # check cache state + assert len(handler.step_input_cache) == 1 + # no need to stub execution context again as the step_input_cache should now hit. If not below will fail. + list(handler.terminate_step(step_handler_ctx)) + # ensure clean events are handled + assert not handler.step_input_cache -def test_mutating_step_handler_no_runtime_override( - k8s_instance: DagsterInstance, tmp_path, kubeconfig_file -): +def test_mutating_step_handler_no_runtime_override(k8s_instance: DagsterInstance, kubeconfig_file): + """Ensure that when disabled, we fallback to the behavior of the K8sStepHandler.""" mock_k8s_client_batch_api = mock.MagicMock() - # io_manager = mem_io_manager result = simple_producer_consumer.execute_in_process(instance=k8s_instance) assert result.success recon_job = reconstructable(simple_producer_consumer) @@ -168,25 +183,24 @@ def test_mutating_step_handler_no_runtime_override( step_handler_ctx = _fetch_step_handler_context( recon_job, result.dagster_run, k8s_instance, executor, ["sink"] ) + initial_resources = { + "requests": {"cpu": "128m", "memory": "64Mi"}, + "limits": {"cpu": "500m", "memory": "1000Mi"}, + } handler = K8sMutatingStepHandler( image="bizbuz", container_context=K8sContainerContext( namespace="foo", - resources={ - "requests": {"cpu": "128m", "memory": "64Mi"}, - "limits": {"cpu": "500m", "memory": "1000Mi"}, - }, + resources=initial_resources, ), load_incluster_config=False, kubeconfig_file=kubeconfig_file, k8s_client_batch_api=mock_k8s_client_batch_api, op_mutation_enabled=False, ) - with mock.patch.object( - StepExecutionContext, "get_io_manager", mock.MagicMock(return_value=MockIOManager()) - ): - runtime_mutated_context = handler._get_container_context(step_handler_ctx) # noqa: SLF001 - assert runtime_mutated_context.run_k8s_config.container_config.get("resources") == { - "requests": {"cpu": "128m", "memory": "64Mi"}, - "limits": {"cpu": "500m", "memory": "1000Mi"}, - } + runtime_mutated_context = handler._get_container_context(step_handler_ctx) # noqa: SLF001 + assert ( + runtime_mutated_context.run_k8s_config.container_config.get("resources") + == initial_resources + ) + assert not handler.step_input_cache