Draft: stephandler draft #1

Open
wants to merge 2 commits into master
121 changes: 120 additions & 1 deletion python_modules/libraries/dagster-k8s/dagster_k8s/executor.py
@@ -1,7 +1,8 @@
from typing import Iterator, List, Optional, cast
from typing import Iterator, List, Literal, Optional, cast

import kubernetes.config
from dagster import (
Config,
Field,
IntSource,
Noneable,
@@ -12,6 +13,8 @@
from dagster._core.definitions.executor_definition import multiple_process_executor_requirements
from dagster._core.definitions.metadata import MetadataValue
from dagster._core.events import DagsterEvent, EngineEventData
from dagster._core.execution.context.system import StepExecutionContext
from dagster._core.execution.plan.inputs import FromInputManager, FromStepOutput, StepInput
from dagster._core.execution.retries import RetryMode, get_retries_config
from dagster._core.execution.tags import get_tag_concurrency_limits_config
from dagster._core.executor.base import Executor
@@ -29,6 +32,7 @@
from .client import DagsterKubernetesClient
from .container_context import K8sContainerContext
from .job import (
USER_DEFINED_K8S_CONFIG_KEY,
USER_DEFINED_K8S_CONFIG_SCHEMA,
DagsterK8sJobConfig,
UserDefinedDagsterK8sConfig,
@@ -341,3 +345,118 @@ def terminate_step(self, step_handler_context: StepHandlerContext) -> Iterator[D
)

self._api_client.delete_job(job_name=job_name, namespace=container_context.namespace)


K8S_OP_STRATEGY = Literal["all", "first", "select"]


class OpInputStrategy(Config):
# the strategy used to resolve k8s metadata when multiple op inputs provide it
strategy: K8S_OP_STRATEGY = "first"
# when strategy is "select", only these op inputs are combined to form the current op's config
input_keys: Optional[List[str]] = None


K8S_OP_EXECUTOR_CONFIG_SCHEMA = merge_dicts(
_K8S_EXECUTOR_CONFIG_SCHEMA,
{
"input_strategy": Field(
Noneable(
{
"strategy": Field(str, default_value="first", is_required=False),
"input_keys": Field(Noneable([str]), default_value=None, is_required=False),
}
),
default_value=None,
is_required=False,
description="how to consume upstream output metadata for the current op's inputs",
)
},
)


class K8sOpStepHandler(K8sStepHandler):
"""Specialized step handler that configures the next op based on the op metadata of the prior op."""

input_strategy: Optional[OpInputStrategy]

def __init__(
self,
image: str | None,
container_context: K8sContainerContext,
load_incluster_config: bool,
kubeconfig_file: str | None,
k8s_client_batch_api=None,
input_strategy: Optional[OpInputStrategy] = None,
):
self.input_strategy = input_strategy
super().__init__(
image, container_context, load_incluster_config, kubeconfig_file, k8s_client_batch_api
)

def _get_input_metadata(self, op_input: StepInput, step_context: StepExecutionContext) -> dict:
input_def = step_context.op_def.input_def_named(op_input.name)
source = op_input.source
if isinstance(source, FromInputManager):
if input_def.metadata:
return input_def.metadata.get(USER_DEFINED_K8S_CONFIG_KEY, {})
return {}
if isinstance(source, FromStepOutput):
if source.fan_in:
step_context.log.info("fan-in step input metadata is not supported")
return {}
upstream_output_handle = source.step_output_handle
output_name = upstream_output_handle.output_name
upstream_step = step_context.execution_plan.get_step_output(upstream_output_handle)
job_def = step_context.job_def
upstream_node = job_def.get_node(upstream_step.node_handle)
if not upstream_node.has_output(output_name):
step_context.log.error(
f"upstream output {output_name} corresponding to input {op_input.name} not found"
)
return {}
output_def = upstream_node.output_def_named(output_name)
if output_def.metadata:
return output_def.metadata.get(USER_DEFINED_K8S_CONFIG_KEY, {})
return {}

def _resolve_input_configs(
self, step_handler_context: StepHandlerContext
) -> Optional[K8sContainerContext]:
"""Fetch all the configured k8s metadata for op inputs."""
step_key = self._get_step_key(step_handler_context)
step_context = step_handler_context.get_step_context(step_key)
container_context = None
for input_name, step_input in step_context.step.step_input_dict.items():
if (
self.input_strategy.strategy == "select"
and input_name not in (self.input_strategy.input_keys or [])
):
continue
op_metadata_config = self._get_input_metadata(step_input, step_context)
if not op_metadata_config:
continue
k8s_context = K8sContainerContext.create_from_config(op_metadata_config)
if self.input_strategy.strategy == "first":
step_context.log.info(f"using config metadata from {input_name}")
step_context.log.debug(f"configure metadata {op_metadata_config}")
return k8s_context
if container_context is None:
container_context = k8s_context
else:
container_context = container_context.merge(k8s_context)
return container_context

def _get_container_context(
self, step_handler_context: StepHandlerContext
) -> K8sContainerContext:
context = super()._get_container_context(step_handler_context)
if not self.input_strategy:
return context
input_context = self._resolve_input_configs(step_handler_context)
if input_context is None:
return context
return context.merge(input_context)
2 changes: 1 addition & 1 deletion python_modules/libraries/dagster-k8s/dagster_k8s/job.py
@@ -195,7 +195,7 @@ def get_k8s_resource_requirements(tags: Mapping[str, str]):
return result.value


def get_user_defined_k8s_config(tags: Mapping[str, str]):
def get_user_defined_k8s_config(tags: Mapping[str, str]) -> UserDefinedDagsterK8sConfig:
check.mapping_param(tags, "tags", key_type=str, value_type=str)

if not any(key in tags for key in [K8S_RESOURCE_REQUIREMENTS_KEY, USER_DEFINED_K8S_CONFIG_KEY]):
156 changes: 156 additions & 0 deletions python_modules/libraries/dagster-k8s/dagster_k8s/k8s_op_executor.py
@@ -0,0 +1,156 @@
from enum import Enum
from typing import Optional

import dagster._check as check
from dagster import (
Array,
DagsterEventType,
Enum as DagsterEnum,
Field,
String,
)
from dagster._core.execution.context.system import StepExecutionContext
from dagster._core.execution.plan.inputs import FromStepOutput, StepInput
from dagster._core.execution.plan.outputs import StepOutputData
from dagster._core.executor.step_delegating import StepHandlerContext
from dagster._core.storage.event_log import EventLogRecord, SqlEventLogStorage
from dagster._utils.merger import merge_dicts

from dagster_k8s.container_context import K8sContainerContext
from dagster_k8s.executor import K8sStepHandler
from dagster_k8s.job import (
USER_DEFINED_K8S_CONFIG_KEY,
USER_DEFINED_K8S_CONFIG_SCHEMA,
UserDefinedDagsterK8sConfig,
)


class InputStrategy(Enum):
none = "none"
all = "all"
first = "first"
select = "select"

def is_select(self) -> bool:
return self.value == InputStrategy.select.value


StringList = Array(String)

USER_DEFINED_K8S_OP_CONFIG_SCHEMA = merge_dicts(
USER_DEFINED_K8S_CONFIG_SCHEMA.fields,
{
"input_strategy": Field(
DagsterEnum.from_python_enum(InputStrategy),
default_value=InputStrategy.none,
is_required=False,
),
"from_inputs": Field(StringList, is_required=False),
},
)


class K8sOpStepHandler(K8sStepHandler):
"""Specialized step handler that configures the next op based on the op metadata of the prior op."""

_step_to_container_context: dict[str, K8sContainerContext]
input_strategy: InputStrategy

def __init__(
self,
image: str | None,
container_context: K8sContainerContext,
load_incluster_config: bool,
kubeconfig_file: str | None,
k8s_client_batch_api=None,
input_strategy: InputStrategy = InputStrategy.none,
):
check.inst_param(input_strategy, "input_strategy", InputStrategy)
self.input_strategy = input_strategy
self._step_to_container_context = {}
super().__init__(
image, container_context, load_incluster_config, kubeconfig_file, k8s_client_batch_api
)

def _get_corresponding_output_config(
self, step_input: StepInput, step_context: StepExecutionContext
) -> Optional[K8sContainerContext]:
source = step_input.source
if not isinstance(source, FromStepOutput):
return None
if source.fan_in:
return None
event_log = step_context.instance.event_log_storage
if not isinstance(event_log, SqlEventLogStorage):
return None

def valid_record(record: EventLogRecord) -> bool:
log_entry = record.event_log_entry
if log_entry.step_key != source.step_output_handle.step_key:
return False
if not log_entry.is_dagster_event:
return False
dagster_event = log_entry.get_dagster_event()
step_output_data = dagster_event.event_specific_data
if not isinstance(step_output_data, StepOutputData):
return False
if step_output_data.mapping_key:
# we don't accept dynamic outputs
return False
return step_output_data.step_output_handle == source.step_output_handle

run_id = step_context.dagster_run.run_id
LIMIT = 100
event_log_conn = event_log.get_records_for_run(
run_id, of_type=DagsterEventType.STEP_OUTPUT, limit=LIMIT
)
records = [r for r in event_log_conn.records if valid_record(r)]
# we expect to only have 1 record corresponding to a step output
while not records and event_log_conn.has_more:
event_log_conn = event_log.get_records_for_run(
run_id, event_log_conn.cursor, DagsterEventType.STEP_OUTPUT, LIMIT
)
records += [r for r in event_log_conn.records if valid_record(r)]
if not records:
return None
assert len(records) == 1, "expected exactly one step output record for the upstream output handle"
output_record = records[0]
log_entry = output_record.event_log_entry
dagster_step_output_event = log_entry.get_dagster_event()
step_output_data: StepOutputData = dagster_step_output_event.event_specific_data
metadata = step_output_data.metadata
if not metadata or USER_DEFINED_K8S_CONFIG_KEY not in metadata:
return None
config_dict = metadata[USER_DEFINED_K8S_CONFIG_KEY].value
if not isinstance(config_dict, dict):
step_context.log.warning(
f"configured input {step_input.name} has metadata of unexpected type {type(config_dict)}"
)
return None
output_config = UserDefinedDagsterK8sConfig.from_dict(config_dict)
return K8sContainerContext(run_k8s_config=output_config)

def _get_combined_input_container_context(
self, step_handler_context: StepHandlerContext
) -> Optional[K8sContainerContext]:
"""Merge the k8s config attached to the upstream outputs feeding this step's inputs."""
step_key = self._get_step_key(step_handler_context)
step_context = step_handler_context.get_step_context(step_key)
combined: Optional[K8sContainerContext] = None
for step_input in step_context.step.step_input_dict.values():
input_context = self._get_corresponding_output_config(step_input, step_context)
if input_context is None:
continue
combined = input_context if combined is None else combined.merge(input_context)
return combined

def _get_container_context(
self, step_handler_context: StepHandlerContext
) -> K8sContainerContext:
step_key = self._get_step_key(step_handler_context)
if step_key in self._step_to_container_context:
return self._step_to_container_context[step_key]
container_context = super()._get_container_context(step_handler_context)
self._step_to_container_context[step_key] = container_context
if self.input_strategy == InputStrategy.none:
return container_context
# iterate over all inputs and fetch their corresponding output metadata as launch configs
input_context = self._get_combined_input_container_context(step_handler_context)
if input_context is not None:
container_context = container_context.merge(input_context)
self._step_to_container_context[step_key] = container_context
return container_context
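A companion sketch of the runtime path this second variant targets: the upstream op attaches dagster-k8s/config metadata to the Output it yields, which lands on the StepOutputData of the STEP_OUTPUT event that _get_corresponding_output_config queries from the event log. This assumes a SQL-backed event log storage; the op name and resource values are illustrative only.

from dagster import Output, op
from dagster_k8s.job import USER_DEFINED_K8S_CONFIG_KEY  # "dagster-k8s/config"


@op
def plan_heavy_step():
    # Metadata attached at runtime travels with the STEP_OUTPUT event, so a
    # downstream step can be sized based on what this op actually produced.
    yield Output(
        "large-partition",
        metadata={
            USER_DEFINED_K8S_CONFIG_KEY: {
                "container_config": {
                    "resources": {"requests": {"cpu": "2", "memory": "4Gi"}}
                }
            }
        },
    )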