diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py b/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py
index b6144b65296e9..347e392cc8c1a 100644
--- a/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py
+++ b/python_modules/libraries/dagster-k8s/dagster_k8s/executor.py
@@ -1,7 +1,8 @@
-from typing import Iterator, List, Optional, cast
+from typing import Iterator, List, Literal, Optional, cast
 
 import kubernetes.config
 from dagster import (
+    Config,
     Field,
     IntSource,
     Noneable,
@@ -12,6 +13,8 @@
 from dagster._core.definitions.executor_definition import multiple_process_executor_requirements
 from dagster._core.definitions.metadata import MetadataValue
 from dagster._core.events import DagsterEvent, EngineEventData
+from dagster._core.execution.context.system import StepExecutionContext
+from dagster._core.execution.plan.inputs import FromInputManager, FromStepOutput, StepInput
 from dagster._core.execution.retries import RetryMode, get_retries_config
 from dagster._core.execution.tags import get_tag_concurrency_limits_config
 from dagster._core.executor.base import Executor
@@ -29,6 +32,7 @@
 from .client import DagsterKubernetesClient
 from .container_context import K8sContainerContext
 from .job import (
+    USER_DEFINED_K8S_CONFIG_KEY,
     USER_DEFINED_K8S_CONFIG_SCHEMA,
     DagsterK8sJobConfig,
     UserDefinedDagsterK8sConfig,
@@ -341,3 +345,118 @@ def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[DagsterEvent]:
         )
 
         self._api_client.delete_job(job_name=job_name, namespace=container_context.namespace)
+
+
+K8S_OP_STRATEGY = Literal["all", "first", "select"]
+
+
+class OpInputStrategy(Config):
+    """Configures how k8s metadata found on multiple op inputs is resolved.
+
+    strategy: the strategy with which we resolve multiple inputs carrying k8s metadata.
+    input_keys: when strategy is 'select', the op inputs whose metadata is combined to form
+        the current op's config.
+    """
+
+    strategy: K8S_OP_STRATEGY = "first"
+    input_keys: Optional[List[str]] = None
+
+
+K8S_OP_EXECUTOR_CONFIG_SCHEMA = merge_dicts(
+    _K8S_EXECUTOR_CONFIG_SCHEMA,
+    {
+        "input_strategy": Field(
+            Noneable(
+                {
+                    "strategy": Field(
+                        str,
+                        default_value="first",
+                        is_required=False,
+                        description="one of 'all', 'first', or 'select'",
+                    ),
+                    "input_keys": Field(Noneable([str]), is_required=False),
+                }
+            ),
+            default_value=None,
+            is_required=False,
+            description="how upstream output metadata is consumed to configure the current op's step",
+        )
+    },
+)
+
+
+class K8sOpStepHandler(K8sStepHandler):
+    """Specialized step handler that configures each op based on the k8s metadata of its upstream ops."""
+
+    input_strategy: Optional[OpInputStrategy]
+
+    def __init__(
+        self,
+        image: Optional[str],
+        container_context: K8sContainerContext,
+        load_incluster_config: bool,
+        kubeconfig_file: Optional[str],
+        k8s_client_batch_api=None,
+        input_strategy: Optional[OpInputStrategy] = None,
+    ):
+        self.input_strategy = input_strategy
+        super().__init__(
+            image, container_context, load_incluster_config, kubeconfig_file, k8s_client_batch_api
+        )
+
+    def _get_input_metadata(self, op_input: StepInput, step_context: StepExecutionContext) -> dict:
+        """Look up dagster-k8s/config metadata on the input definition or the upstream output definition."""
+        input_def = step_context.op_def.input_def_named(op_input.name)
+        source = op_input.source
+        if isinstance(source, FromInputManager):
+            if input_def.metadata:
+                return input_def.metadata.get(USER_DEFINED_K8S_CONFIG_KEY, {})
+            return {}
+        if isinstance(source, FromStepOutput):
+            if source.fan_in:
+                step_context.log.info("fan-in step input metadata is not supported")
+                return {}
+            upstream_output_handle = source.step_output_handle
+            output_name = upstream_output_handle.output_name
+            upstream_step = step_context.execution_plan.get_step_output(upstream_output_handle)
+            job_def = step_context.job_def
+            upstream_node = job_def.get_node(upstream_step.node_handle)
+            if not upstream_node.has_output(output_name):
+                step_context.log.error(
+                    f"corresponding upstream output {output_name} for input {op_input.name} not found"
+                )
+                return {}
+            output_def = upstream_node.output_def_named(output_name)
+            if output_def.metadata:
+                return output_def.metadata.get(USER_DEFINED_K8S_CONFIG_KEY, {})
+        return {}
+
+    def _resolve_input_configs(
+        self, step_handler_context: StepHandlerContext
+    ) -> Optional[K8sContainerContext]:
+        """Fetch all the k8s metadata configured on the step's inputs and combine it per the strategy."""
+        step_key = self._get_step_key(step_handler_context)
+        step_context = step_handler_context.get_step_context(step_key)
+        container_context = None
+        for input_name, step_input in step_context.step.step_input_dict.items():
+            if self.input_strategy.strategy == "select" and input_name not in (
+                self.input_strategy.input_keys or []
+            ):
+                continue
+            op_metadata_config = self._get_input_metadata(step_input, step_context)
+            if not op_metadata_config:
+                continue
+            # the metadata dict is in dagster-k8s/config form, so wrap it as run_k8s_config
+            k8s_context = K8sContainerContext(
+                run_k8s_config=UserDefinedDagsterK8sConfig.from_dict(op_metadata_config)
+            )
+            if self.input_strategy.strategy == "first":
+                step_context.log.info(f"using config metadata from {input_name}")
+                step_context.log.debug(f"configured metadata {op_metadata_config}")
+                return k8s_context
+            # merge returns a new context; accumulate rather than discard the result
+            container_context = (
+                k8s_context if container_context is None else container_context.merge(k8s_context)
+            )
+        return container_context
+
+    def _get_container_context(
+        self, step_handler_context: StepHandlerContext
+    ) -> K8sContainerContext:
+        context = super()._get_container_context(step_handler_context)
+        if self.input_strategy is None:
+            return context
+        input_context = self._resolve_input_configs(step_handler_context)
+        if input_context is None:
+            return context
+        return context.merge(input_context)
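For context, a minimal sketch of how an op author might attach k8s config to an output definition so that the definition-metadata path in K8sOpStepHandler._get_input_metadata above would pick it up for the downstream step. The executor wiring for K8sOpStepHandler is not shown in this diff, so the op and job names here are hypothetical and the example is illustrative only.

# Sketch: k8s config on an output *definition*, read by the handler above.
from dagster import Out, job, op

K8S_CONFIG_KEY = "dagster-k8s/config"  # same string as USER_DEFINED_K8S_CONFIG_KEY

@op(
    out=Out(
        metadata={
            K8S_CONFIG_KEY: {
                "container_config": {
                    "resources": {
                        "requests": {"cpu": "500m", "memory": "1Gi"},
                        "limits": {"cpu": "1", "memory": "2Gi"},
                    }
                }
            }
        }
    )
)
def produce_dataset():
    return "s3://bucket/dataset"

@op
def train_model(dataset: str):
    # with input_strategy "first", this step would launch with the container_config
    # declared on produce_dataset's output metadata
    ...

@job
def dataset_pipeline():
    train_model(produce_dataset())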
diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/job.py b/python_modules/libraries/dagster-k8s/dagster_k8s/job.py
index 75acd32465ef5..98518549181bd 100644
--- a/python_modules/libraries/dagster-k8s/dagster_k8s/job.py
+++ b/python_modules/libraries/dagster-k8s/dagster_k8s/job.py
@@ -195,7 +195,7 @@ def get_k8s_resource_requirements(tags: Mapping[str, str]):
     return result.value
 
 
-def get_user_defined_k8s_config(tags: Mapping[str, str]):
+def get_user_defined_k8s_config(tags: Mapping[str, str]) -> UserDefinedDagsterK8sConfig:
     check.mapping_param(tags, "tags", key_type=str, value_type=str)
 
     if not any(key in tags for key in [K8S_RESOURCE_REQUIREMENTS_KEY, USER_DEFINED_K8S_CONFIG_KEY]):
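For reference, the function annotated above parses the long-standing tag-based form of the same per-op k8s config, e.g. an op tagged as follows (resource values are illustrative):

from dagster import op

@op(
    tags={
        "dagster-k8s/config": {
            "container_config": {
                "resources": {
                    "requests": {"cpu": "250m", "memory": "64Mi"},
                    "limits": {"cpu": "500m", "memory": "2560Mi"},
                }
            }
        }
    }
)
def resource_hungry_op():
    ...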
diff --git a/python_modules/libraries/dagster-k8s/dagster_k8s/k8s_op_executor.py b/python_modules/libraries/dagster-k8s/dagster_k8s/k8s_op_executor.py
new file mode 100644
index 0000000000000..acc3326a41522
--- /dev/null
+++ b/python_modules/libraries/dagster-k8s/dagster_k8s/k8s_op_executor.py
@@ -0,0 +1,156 @@
+from enum import Enum
+from typing import Dict, Optional, Sequence
+
+import dagster._check as check
+from dagster import (
+    Array,
+    DagsterEventType,
+    Enum as DagsterEnum,
+    Field,
+    String,
+)
+from dagster._core.execution.context.system import StepExecutionContext
+from dagster._core.execution.plan.inputs import FromStepOutput, StepInput
+from dagster._core.execution.plan.outputs import StepOutputData
+from dagster._core.executor.step_delegating import StepHandlerContext
+from dagster._core.storage.event_log import EventLogRecord, SqlEventLogStorage
+from dagster._utils.merger import merge_dicts
+
+from dagster_k8s.container_context import K8sContainerContext
+from dagster_k8s.executor import K8sStepHandler
+from dagster_k8s.job import (
+    USER_DEFINED_K8S_CONFIG_KEY,
+    USER_DEFINED_K8S_CONFIG_SCHEMA,
+    UserDefinedDagsterK8sConfig,
+)
+
+
+class InputStrategy(Enum):
+    none = "none"
+    all = "all"
+    first = "first"
+    select = "select"
+
+    def is_select(self) -> bool:
+        return self.value == InputStrategy.select.value
+
+
+StringList = Array(String)
+
+USER_DEFINED_K8S_OP_CONFIG_SCHEMA = merge_dicts(
+    USER_DEFINED_K8S_CONFIG_SCHEMA.fields,
+    {
+        "input_strategy": Field(
+            DagsterEnum.from_python_enum(InputStrategy),
+            default_value=InputStrategy.none.value,
+            is_required=False,
+        ),
+        "from_inputs": Field(StringList, is_required=False),
+    },
+)
+
+
+class K8sOpStepHandler(K8sStepHandler):
+    """Specialized step handler that configures each op based on the k8s metadata emitted by its upstream ops."""
+
+    _step_to_container_context: Dict[str, K8sContainerContext]
+    input_strategy: InputStrategy
+
+    def __init__(
+        self,
+        image: Optional[str],
+        container_context: K8sContainerContext,
+        load_incluster_config: bool,
+        kubeconfig_file: Optional[str],
+        k8s_client_batch_api=None,
+        input_strategy: InputStrategy = InputStrategy.none,
+        from_inputs: Optional[Sequence[str]] = None,
+    ):
+        check.inst_param(input_strategy, "input_strategy", InputStrategy)
+        self.input_strategy = input_strategy
+        # only consulted when input_strategy is "select"
+        self.from_inputs = from_inputs
+        self._step_to_container_context = {}
+        super().__init__(
+            image, container_context, load_incluster_config, kubeconfig_file, k8s_client_batch_api
+        )
+
+    def _get_corresponding_output_config(
+        self, step_input: StepInput, step_context: StepExecutionContext
+    ) -> Optional[K8sContainerContext]:
+        """Look up the dagster-k8s/config metadata emitted with the upstream output feeding this input."""
+        source = step_input.source
+        if not isinstance(source, FromStepOutput):
+            return None
+        if source.fan_in:
+            return None
+        event_log = step_context.instance.event_log_storage
+        if not isinstance(event_log, SqlEventLogStorage):
+            return None
+
+        def valid_record(record: EventLogRecord) -> bool:
+            log_entry = record.event_log_entry
+            # match against the *upstream* step that produced this input
+            if log_entry.step_key != source.step_output_handle.step_key:
+                return False
+            if not log_entry.is_dagster_event:
+                return False
+            dagster_event = log_entry.get_dagster_event()
+            step_output_data = dagster_event.event_specific_data
+            if not isinstance(step_output_data, StepOutputData):
+                return False
+            if step_output_data.mapping_key:
+                # we don't accept dynamic outputs
+                return False
+            return step_output_data.step_output_handle == source.step_output_handle
+
+        run_id = step_context.dagster_run.run_id
+        LIMIT = 100
+        event_log_conn = event_log.get_records_for_run(
+            run_id, of_type=DagsterEventType.STEP_OUTPUT, limit=LIMIT
+        )
+        records = [r for r in event_log_conn.records if valid_record(r)]
+        # we expect at most one record corresponding to a step output; keep paging until found
+        while not records and event_log_conn.has_more:
+            event_log_conn = event_log.get_records_for_run(
+                run_id,
+                cursor=event_log_conn.cursor,
+                of_type=DagsterEventType.STEP_OUTPUT,
+                limit=LIMIT,
+            )
+            records += [r for r in event_log_conn.records if valid_record(r)]
+        if not records:
+            return None
+        check.invariant(len(records) == 1, "unexpected number of records from filter query")
+        output_record = records[0]
+        log_entry = output_record.event_log_entry
+        dagster_step_output_event = log_entry.get_dagster_event()
+        step_output_data: StepOutputData = dagster_step_output_event.event_specific_data
+        metadata = step_output_data.metadata
+        if not metadata or USER_DEFINED_K8S_CONFIG_KEY not in metadata:
+            return None
+        config_dict = metadata[USER_DEFINED_K8S_CONFIG_KEY].value
+        if not isinstance(config_dict, dict):
+            step_context.log.warning(
+                f"configured input {step_input.name} has metadata of unexpected type {type(config_dict)}"
+            )
+            return None
+        output_config = UserDefinedDagsterK8sConfig.from_dict(config_dict)
+        return K8sContainerContext(run_k8s_config=output_config)
+
+    def _get_combined_input_container_context(
+        self, step_handler_context: StepHandlerContext
+    ) -> Optional[K8sContainerContext]:
+        """Combine the k8s config attached to the outputs feeding this step, per the input strategy."""
+        step_key = self._get_step_key(step_handler_context)
+        step_context = step_handler_context.get_step_context(step_key)
+        combined: Optional[K8sContainerContext] = None
+        for input_name, step_input in step_context.step.step_input_dict.items():
+            if self.input_strategy.is_select() and input_name not in (self.from_inputs or []):
+                continue
+            input_context = self._get_corresponding_output_config(step_input, step_context)
+            if input_context is None:
+                continue
+            if self.input_strategy == InputStrategy.first:
+                return input_context
+            # merge returns a new context rather than mutating in place
+            combined = input_context if combined is None else combined.merge(input_context)
+        return combined
+
+    def _get_container_context(
+        self, step_handler_context: StepHandlerContext
+    ) -> K8sContainerContext:
+        step_key = self._get_step_key(step_handler_context)
+        if step_key in self._step_to_container_context:
+            return self._step_to_container_context[step_key]
+        container_context = super()._get_container_context(step_handler_context)
+        if self.input_strategy != InputStrategy.none:
+            # iterate over all inputs and fold their corresponding output metadata into the launch config
+            input_context = self._get_combined_input_container_context(step_handler_context)
+            if input_context is not None:
+                container_context = container_context.merge(input_context)
+        self._step_to_container_context[step_key] = container_context
+        return container_context
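A minimal sketch of the runtime path that the event-log-based K8sOpStepHandler in k8s_op_executor.py targets: an upstream op emits dagster-k8s/config as output metadata, and the downstream step (with an input_strategy other than "none") would be launched with that config folded into its container context. Op names and thresholds are hypothetical; the executor wiring is not part of this diff.

# Sketch: k8s config emitted as *runtime* output metadata, discovered via STEP_OUTPUT records.
from dagster import Output, op

@op
def partition_data(context):
    rows = 50_000_000  # pretend we just measured the dataset
    memory = "8Gi" if rows > 10_000_000 else "2Gi"
    return Output(
        value="s3://bucket/partitions",
        metadata={
            "dagster-k8s/config": {
                "container_config": {"resources": {"limits": {"memory": memory}}}
            }
        },
    )

@op
def aggregate(partitions: str):
    # launched with the memory limit chosen at runtime by partition_data
    ...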