diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py b/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py index daffa21caadb1..82757328af5da 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py @@ -237,9 +237,10 @@ def _get_container_context( user_defined_k8s_config = get_user_defined_k8s_config( step_handler_context.step_tags[step_key] ) - + step_context = step_handler_context.get_step_context(step_key) + op_name = step_context.step.op_name per_op_override = UserDefinedDagsterK8sConfig.from_dict( - self._per_step_k8s_config.get(step_key, {}) + self._per_step_k8s_config.get(op_name, {}) ) return context.merge(K8sContainerContext(run_k8s_config=user_defined_k8s_config)).merge( diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_executor.py b/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_executor.py index e5dead96d63fe..d6ea6d8b740fc 100644 --- a/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_executor.py +++ b/python_modules/libraries/dagster-k8s/dagster_k8s_tests/unit_tests/test_executor.py @@ -1,13 +1,23 @@ import json +from typing import Optional from unittest import mock import pytest -from dagster import job, op, repository +from dagster import ( + DagsterInstance, + DynamicOut, + DynamicOutput, + OpExecutionContext, + job, + op, + repository, +) from dagster._config import process_config, resolve_to_config_type from dagster._core.definitions.reconstruct import reconstructable from dagster._core.execution.api import create_execution_plan from dagster._core.execution.context.system import PlanData, PlanOrchestrationContext from dagster._core.execution.context_creation_job import create_context_free_log_manager +from dagster._core.execution.plan.state import KnownExecutionState from dagster._core.execution.retries import RetryMode from dagster._core.executor.init import InitExecutorContext from dagster._core.executor.step_delegating.step_handler.base import StepHandlerContext @@ -115,6 +125,26 @@ def foo(): foo() +@job +def dynamic_producer_consumer_job(): + @op(out=DynamicOut(int)) + def dyn_producer(): + for i in [3, 4]: + yield DynamicOutput( + i, + str(i), + ) + + @op + def dyn_sink(context: OpExecutionContext, producer: int) -> KnownExecutionState: + context.log.info(f"got input {producer}") + # hacky way for to get known context -> InMemIOManager -> StepOrchestrationContext + # for step handler testing + return context.get_step_execution_context().get_known_state() + + dyn_producer().map(dyn_sink).collect() + + @repository def bar_repo(): return [bar] @@ -211,8 +241,15 @@ def _get_executor(instance, job_def, executor_config=None): ) -def _step_handler_context(job_def, dagster_run, instance, executor): - execution_plan = create_execution_plan(job_def) +def _step_handler_context( + job_def, + dagster_run, + instance, + executor, + step: str = "foo", + known_state: Optional[KnownExecutionState] = None, +): + execution_plan = create_execution_plan(job_def, known_state=known_state) log_manager = create_context_free_log_manager(instance, dagster_run) plan_context = PlanOrchestrationContext( @@ -232,7 +269,8 @@ def _step_handler_context(job_def, dagster_run, instance, executor): execute_step_args = ExecuteStepArgs( reconstructable(bar).get_python_origin(), dagster_run.run_id, - ["foo"], + # note that k8s_job_executor can only execute one step at a time. + [step], print_serialized_events=False, ) @@ -734,3 +772,49 @@ def test_per_step_k8s_config(k8s_run_launcher_instance, python_origin_with_conta assert raw_k8s_config.container_config["resources"] == FOURTH_RESOURCES_TAGS assert raw_k8s_config.container_config["working_dir"] == "MY_WORKING_DIR" assert raw_k8s_config.container_config["volume_mounts"] == OTHER_VOLUME_MOUNTS_TAGS + + +def test_per_step_k8s_config_dynamic_job(k8s_run_launcher_instance: DagsterInstance): + run_id = "de07af8f-d5f4-4a43-b545-132c3310999d" + result = dynamic_producer_consumer_job.execute_in_process( + instance=k8s_run_launcher_instance, + run_id=run_id, + ) + assert result.success + recon_job = reconstructable(dynamic_producer_consumer_job) + executor = _get_executor( + k8s_run_launcher_instance, + recon_job, + { + "step_k8s_config": { # injected into every step + "container_config": { + "working_dir": "MY_WORKING_DIR", # set on every step + "resources": THIRD_RESOURCES_TAGS, # overridden by the per_step level, so ignored + } + }, + "per_step_k8s_config": { + "dyn_sink": { + "container_config": { + "resources": FOURTH_RESOURCES_TAGS, + } + } + }, + }, + ) + dynamic_step = "3" + dyn_known_state = result.output_for_node("dyn_sink")[dynamic_step] + step_handler_context = _step_handler_context( + recon_job, + result.dagster_run, + k8s_run_launcher_instance, + executor, + step=f"dyn_sink[{dynamic_step}]", + known_state=dyn_known_state, + ) + container_context = executor._step_handler._get_container_context( # noqa: SLF001 # pyright: ignore[reportAttributeAccessIssue] + step_handler_context + ) + raw_k8s_config = container_context.run_k8s_config + + assert raw_k8s_config.container_config["resources"] == FOURTH_RESOURCES_TAGS + assert raw_k8s_config.container_config["working_dir"] == "MY_WORKING_DIR"