Skip to content

Commit

Permalink
output metadata based step mutation
Browse files Browse the repository at this point in the history
  • Loading branch information
abhinavDhulipala committed Nov 9, 2024
1 parent 017187b commit 0648263
Show file tree
Hide file tree
Showing 2 changed files with 449 additions and 215 deletions.
279 changes: 172 additions & 107 deletions python_modules/libraries/dagster-k8s/dagster_k8s/op_mutating_executor.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,35 @@
from typing import Any, Mapping, Optional, Union, cast
import logging
from typing import Dict, Generic, Optional, TypeVar, Union, cast

import dagster._seven as seven
from dagster import (
DagsterEventType,
Field,
IOManager,
JsonMetadataValue,
StepRunRef,
TypeCheck,
_check as check,
executor,
usable_as_dagster_type,
)
from dagster._core.definitions.executor_definition import multiple_process_executor_requirements
from dagster._core.events import EngineEventData
from dagster._core.events.log import EventLogEntry
from dagster._core.execution.context.system import StepOrchestrationContext
from dagster._core.execution.plan.external_step import step_run_ref_to_step_context
from dagster._core.execution.plan.handle import ResolvedFromDynamicStepHandle, StepHandle
from dagster._core.execution.plan.inputs import (
FromPendingDynamicStepOutput,
FromStepOutput,
StepInput,
)
from dagster._core.execution.plan.outputs import StepOutputData
from dagster._core.execution.retries import RetryMode
from dagster._core.executor.base import Executor
from dagster._core.executor.init import InitExecutorContext
from dagster._core.executor.step_delegating import StepDelegatingExecutor, StepHandlerContext
from dagster._core.storage.input_manager import InputManager
from dagster._core.storage.event_log import EventLogRecord, SqlEventLogStorage
from dagster._core.storage.event_log.base import EventLogConnection, EventLogCursor
from dagster._core.storage.event_log.schema import SqlEventLogStorageTable
from dagster._core.storage.sqlalchemy_compat import db_select
from dagster._serdes import deserialize_value
from dagster._serdes.errors import DeserializationError
from dagster._utils.merger import merge_dicts

from dagster_k8s.container_context import K8sContainerContext
Expand All @@ -43,41 +49,105 @@
)


USER_DEFINED_INPUT_K8S_OP_MUTATION_KEY = "dagster-k8s/mutation-enabled"

T = TypeVar("T")


@usable_as_dagster_type
class K8sOpMutatingWrapper(Generic[T]):
    """Generic marker wrapper for op values.

    An input whose dagster type key equals this class name opts the consuming
    step into k8s-config mutation driven by upstream output metadata (see the
    dagster_type_key comparison in the step handler's input-metadata lookup).
    """

    # the wrapped user value passed between ops; the wrapper adds no behavior
    value: T

    def __init__(self, value: T):
        self.value = value

    @classmethod
    def name(cls) -> str:
        """Return the dagster type key for this wrapper (the class name)."""
        return cls.__name__


def get_output_records_for_run_step(
    sql_event_log: SqlEventLogStorage,
    run_id: str,
    step_key: str,
    cursor: Optional[str] = None,
    limit: int = 0,
    ascending: bool = True,
) -> EventLogConnection:
    """Fetch STEP_OUTPUT event records for a single step of a single run.

    Minor modification of ``SqlEventLogStorage.get_records_for_run`` where we
    query only STEP_OUTPUT events for a particular run & step key.

    Args:
        sql_event_log: SQL-backed event log storage to query.
        run_id: id of the run whose events are fetched.
        step_key: step whose STEP_OUTPUT events are fetched.
        cursor: optional serialized ``EventLogCursor`` (offset- or id-based).
        limit: maximum number of records to return; 0 means no limit.
        ascending: order by ascending storage id when True, descending otherwise.

    Returns:
        ``EventLogConnection`` holding the parsed records, a cursor for the
        next page, and a ``has_more`` flag.
    """
    check.inst_param(sql_event_log, "sql_event_log", SqlEventLogStorage)
    check.str_param(run_id, "run_id")
    check.str_param(step_key, "step_key")
    check.int_param(limit, "limit")
    # 0 (no limit) is accepted, so the message must say "greater than or equal"
    check.invariant(limit >= 0, "provided limit must be greater than or equal to 0")
    check.opt_str_param(cursor, "cursor")

    query = (
        db_select([SqlEventLogStorageTable.c.id, SqlEventLogStorageTable.c.event])
        .where(
            (SqlEventLogStorageTable.c.run_id == run_id)
            & (SqlEventLogStorageTable.c.step_key == step_key)
            & (SqlEventLogStorageTable.c.dagster_event_type == DagsterEventType.STEP_OUTPUT.value)
        )
        .order_by(
            SqlEventLogStorageTable.c.id.asc() if ascending else SqlEventLogStorageTable.c.id.desc()
        )
    )

    # adjust 0 based index cursor to SQL offset; id cursors filter on storage id
    if cursor is not None:
        cursor_obj = EventLogCursor.parse(cursor)
        if cursor_obj.is_offset_cursor():
            query = query.offset(cursor_obj.offset())
        elif cursor_obj.is_id_cursor():
            if ascending:
                query = query.where(SqlEventLogStorageTable.c.id > cursor_obj.storage_id())
            else:
                query = query.where(SqlEventLogStorageTable.c.id < cursor_obj.storage_id())
    if limit:
        query = query.limit(limit)

    with sql_event_log.run_connection(run_id) as conn:
        results = conn.execute(query).fetchall()

    records = []
    last_record_id = None
    try:
        for record_id, json_str in results:
            records.append(
                EventLogRecord(
                    storage_id=record_id,
                    event_log_entry=deserialize_value(json_str, EventLogEntry),
                )
            )
            last_record_id = record_id
    except (seven.JSONDecodeError, DeserializationError) as err:
        # best-effort: keep the records parsed so far, drop the malformed one
        logging.warning("failed to parse event log %s due to %s", record_id, err)

    if last_record_id is not None:
        next_cursor = EventLogCursor.from_storage_id(last_record_id).to_string()
    elif cursor:
        # record fetch returned no new logs, return the same cursor
        next_cursor = cursor
    else:
        # rely on the fact that all storage ids will be positive integers
        next_cursor = EventLogCursor.from_storage_id(-1).to_string()

    return EventLogConnection(
        records=records,
        cursor=next_cursor,
        has_more=bool(limit and len(results) == limit),
    )


class K8sMutatingStepHandler(K8sStepHandler):
"""Specialized step handler that configure the next op based on the op metadata of the prior op."""

op_mutation_enabled: bool
# cache for container contexts
container_ctx_cache: dict[str, K8sContainerContext]
# consider adding a step output cache. Helps for reused op outputs, but could potentially
# store outputs that are never used again. Maybe LRU cache?
# step_output_cache: OrderedDict[StepInputCacheKey, UserDefinedDagsterK8sConfig]
container_ctx_cache: Dict[Union[StepHandle, ResolvedFromDynamicStepHandle], K8sContainerContext]

def __init__(
self,
Expand All @@ -101,9 +171,25 @@ def name(self) -> str:
def _get_input_metadata(
self, step_input: StepInput, step_context: StepOrchestrationContext
) -> Optional[UserDefinedDagsterK8sConfig]:
# dagster type names should be guaranteed unique, so this should be safe
if step_input.dagster_type_key != K8sOpMutatingOutput.__name__:
# dagster type names should be guaranteed unique, so this should be safe.
job_def = step_context.job.get_definition()
input_def = job_def.get_op(step_context.step.node_handle).input_def_named(step_input.name)
input_mutation_enabled = False
input_metadata = input_def.metadata
if USER_DEFINED_INPUT_K8S_OP_MUTATION_KEY in input_metadata:
input_mutation_enabled = check.bool_elem(
input_metadata,
USER_DEFINED_INPUT_K8S_OP_MUTATION_KEY,
f"{USER_DEFINED_INPUT_K8S_OP_MUTATION_KEY} must be a bool",
)
input_mutation_enabled |= step_input.dagster_type_key == K8sOpMutatingWrapper.name()
if not input_mutation_enabled:
return None
event_log = step_context.instance.event_log_storage
if not isinstance(event_log, SqlEventLogStorage):
return step_context.log.error(
f"can only use executor with log type {type(SqlEventLogStorage)}"
)
source = step_input.source
if isinstance(source, FromPendingDynamicStepOutput):
upstream_output_handle = source.step_output_handle
Expand All @@ -115,111 +201,90 @@ def _get_input_metadata(
f"unable to consume {step_input.name}. FromStepOuput & FromPendingDynamicStepOutput sources supported, got {type(source)}"
)
if source.fan_in:
# this should never happen, but in case it does, log it
return step_context.log.error("fan in step input not supported")
job_def = step_context.job.get_definition()
# lie, cheat, and steal. Create step execution context ahead of time
step_run_ref = StepRunRef(
run_config=job_def.run_config or {},
dagster_run=step_context.dagster_run,
run_id=step_context.run_id,
step_key=source.step_output_handle.step_key,
retry_mode=step_context.retry_mode,
recon_job=step_context.reconstructable_job,
known_state=step_context.execution_plan.known_state,
)
step_exec_context = step_run_ref_to_step_context(step_run_ref, step_context.instance)
input_def = job_def.get_op(step_context.step.node_handle).input_def_named(step_input.name)
output_handle = source.step_output_handle
# get requisite input manager
if input_def.input_manager_key is not None:
manager_key = input_def.input_manager_key
input_manager = getattr(step_exec_context.resources, manager_key)
check.invariant(
isinstance(input_manager, InputManager),
f'Input "{input_def.name}" for step "{step_context.step.key}" is depending on '
f'the manager "{manager_key}" to load it, but it is not an InputManager. '
"Please ensure that the resource returned for resource key "
f'"{manager_key}" is an InputManager.',
upstream_record: Optional[StepOutputData] = None
has_more = True

while upstream_record is None and has_more:
record_conn = get_output_records_for_run_step(
event_log, step_context.run_id, output_handle.step_key
)
else:
manager_key = step_context.execution_plan.get_manager_key(
source.step_output_handle, job_def
has_more = record_conn.has_more
step_output_events = (
e.get_dagster_event().step_output_data
for e in map(lambda r: r.event_log_entry, record_conn.records)
if e.is_dagster_event and e.get_dagster_event().is_successful_output
)
input_manager = step_exec_context.get_io_manager(output_handle)
check.invariant(
isinstance(input_manager, IOManager),
f'Input "{input_def.name}" for step "{step_context.step.key}" is depending on '
f'the manager of upstream output "{output_handle.output_name}" from step '
f'"{output_handle.step_key}" to load it, but that manager is not an IOManager. '
"Please ensure that the resource returned for resource key "
f'"{manager_key}" is an IOManager.',
upstream_record = next(
(
output
for output in step_output_events
if output.step_output_handle == output_handle
),
None,
)
# load full input via input manager. DANGER, if value is excessively large this could have perf impacts on stephandler
input_context = source.get_load_context(step_exec_context, input_def, manager_key)
op_mutating_output: K8sOpMutatingOutput = input_manager.load_input(input_context)
step_context.instance.report_engine_event(
f'eagerly using input "{input_def.name}" from upstream op "{source.step_output_handle.step_key}" as config input',
step_context.dagster_run,
EngineEventData(
metadata={
USER_DEFINED_K8S_CONFIG_KEY: JsonMetadataValue(
data=op_mutating_output.k8s_config.to_dict()
)
}
),
step_key=step_context.step.key,
run_id=step_context.run_id,
if upstream_record is None:
return step_context.log.error(
f"unable to find output event for input {step_input.name}"
)
# should be guaranteed, so make ruff happy (TCH002)
check.inst(upstream_record, StepOutputData)
if USER_DEFINED_K8S_CONFIG_KEY not in upstream_record.metadata:
return step_context.log.warning(
f"upstream output {output_handle} has no metadata key {USER_DEFINED_K8S_CONFIG_KEY}"
)
k8s_config_md = upstream_record.metadata[USER_DEFINED_K8S_CONFIG_KEY]
if not isinstance(k8s_config_md, JsonMetadataValue):
return step_context.log.error(
f"user defined config metatdata need to be of type {type(JsonMetadataValue)}, got {type(k8s_config_md)}"
)
source_output_msg = f'using output metadata from output "{output_handle.output_name}"'
if output_handle.mapping_key:
source_output_msg += f' with mapping key "{output_handle.mapping_key}"'
step_context.log.info(
f'{source_output_msg} from step "{output_handle.step_key}" to mutate op k8s config'
)
return op_mutating_output.k8s_config
return UserDefinedDagsterK8sConfig.from_dict(k8s_config_md.data)

def _merge_input_configs(self, step_handler_context: StepHandlerContext) -> K8sContainerContext:
    """Fetch all the configured k8s metadata for op inputs.

    Starts from the base container context of the parent handler and folds
    each input's upstream-provided k8s config on top via ``merge``.
    """
    step_key = self._get_step_key(step_handler_context)
    step_context = cast(
        StepOrchestrationContext, step_handler_context.get_step_context(step_key)
    )
    k8s_context = super()._get_container_context(step_handler_context)
    for step_input in step_context.step.step_inputs:
        op_metadata_config = self._get_input_metadata(step_input, step_context)
        if not op_metadata_config:
            # input did not opt in / produced no usable config metadata
            continue
        k8s_context = k8s_context.merge(K8sContainerContext(run_k8s_config=op_metadata_config))

    return k8s_context

def _get_container_context(
    self, step_handler_context: StepHandlerContext
) -> K8sContainerContext:
    """Resolve the container context for a step, with per-step-handle caching.

    Falls back to the plain K8sStepHandler behavior (no cache, no input
    metadata merging) when op mutation is disabled.
    """
    step_key = self._get_step_key(step_handler_context)
    step_context = step_handler_context.get_step_context(step_key)
    if not self.op_mutation_enabled:
        # only use the cache when op mutation is enabled; fall back to
        # K8sStepHandler behavior otherwise
        step_context.log.warning("using op mutating executor with op mutation disabled")
        return super()._get_container_context(step_handler_context)
    step_handle = step_context.step.handle
    if step_handle in self.container_ctx_cache:
        return self.container_ctx_cache[step_handle]
    # safe to cache: the merge of base context and input configs is idempotent
    self.container_ctx_cache[step_handle] = self._merge_input_configs(step_handler_context)
    return self.container_ctx_cache[step_handle]

def terminate_step(self, step_handler_context: StepHandlerContext):
    """Terminate the step, then evict its cached container context."""
    try:
        yield from super().terminate_step(step_handler_context)
    finally:
        # pop cache to save mem for steps we won't visit again
        step_key = self._get_step_key(step_handler_context)
        step_context = step_handler_context.get_step_context(step_key)
        # entry might not exist if op mutation is disabled
        self.container_ctx_cache.pop(step_context.step.handle, None)


@executor(
Expand Down
Loading

0 comments on commit 0648263

Please sign in to comment.