Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support functionalities to enhance task traceability with metadata for dependency search. #450

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions examples/logging.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
[loggers]
keys=root,luigi,luigi-interface,gokart

[handlers]
keys=stderrHandler

[formatters]
keys=simpleFormatter

[logger_root]
level=INFO
handlers=stderrHandler

[logger_gokart]
level=INFO
handlers=stderrHandler
qualname=gokart
propagate=0

[logger_luigi]
level=INFO
handlers=stderrHandler
qualname=luigi
propagate=0

[logger_luigi-interface]
level=INFO
handlers=stderrHandler
qualname=luigi-interface
propagate=0

[handler_stderrHandler]
class=StreamHandler
formatter=simpleFormatter
args=(sys.stderr,)

[formatter_simpleFormatter]
format=level=%(levelname)s time=%(asctime)s name=%(name)s file=%(filename)s line=%(lineno)d message=%(message)s
datefmt=%Y/%m/%d %H:%M:%S
class=logging.Formatter
6 changes: 6 additions & 0 deletions examples/param.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[TaskOnKart]
workspace_directory=./resource
local_temporary_directory=./resource/tmp

[core]
logging_conf_file=logging.ini
72 changes: 51 additions & 21 deletions gokart/gcs_obj_metadata_client.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
from __future__ import annotations

import copy
import json
import re
from collections.abc import Iterable
from logging import getLogger
from typing import Any
from urllib.parse import urlsplit

from googleapiclient.model import makepatch

from gokart.gcs_config import GCSConfig
from gokart.required_task_output import RequiredTaskOutput
from gokart.utils import FlattenableItems

logger = getLogger(__name__)

Expand All @@ -21,7 +25,7 @@ class GCSObjectMetadataClient:

@staticmethod
def _is_log_related_path(path: str) -> bool:
return re.match(r'^log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+', path) is not None
return re.match(r'^gs://.+?/log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+', path) is not None

# This is the copied method of luigi.gcs._path_to_bucket_and_key(path).
@staticmethod
Expand All @@ -32,7 +36,12 @@ def _path_to_bucket_and_key(path: str) -> tuple[str, str]:
return netloc, path_without_initial_slash

@staticmethod
def add_task_state_labels(path: str, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
def add_task_state_labels(
path: str,
task_params: dict[str, str] | None = None,
custom_labels: dict[str, Any] | None = None,
required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
) -> None:
if GCSObjectMetadataClient._is_log_related_path(path):
return
# In gokart/object_storage.get_time_stamp, could find same call.
Expand All @@ -42,20 +51,18 @@ def add_task_state_labels(path: str, task_params: dict[str, str] | None = None,
if _response is None:
logger.error(f'failed to get object from GCS bucket {bucket} and object {obj}.')
return

response: dict[str, Any] = dict(_response)
original_metadata: dict[Any, Any] = {}
if 'metadata' in response.keys():
_metadata = response.get('metadata')
if _metadata is not None:
original_metadata = dict(_metadata)

patched_metadata = GCSObjectMetadataClient._get_patched_obj_metadata(
copy.deepcopy(original_metadata),
task_params,
custom_labels,
required_task_outputs if required_task_outputs else None,
)

if original_metadata != patched_metadata:
# If we use update api, existing object metadata are removed, so should use patch api.
# See the official document descriptions.
Expand All @@ -71,7 +78,6 @@ def add_task_state_labels(path: str, task_params: dict[str, str] | None = None,
)
.execute()
)

if update_response is None:
logger.error(f'failed to patch object {obj} in bucket {bucket} and object {obj}.')

Expand All @@ -84,13 +90,13 @@ def _get_patched_obj_metadata(
metadata: Any,
task_params: dict[str, str] | None = None,
custom_labels: dict[str, Any] | None = None,
required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
) -> dict | Any:
# If metadata from response when getting bucket and object information is not dictionary,
# something wrong might be happened, so return original metadata, no patched.
if not isinstance(metadata, dict):
logger.warning(f'metadata is not a dict: {metadata}, something wrong was happened when getting response when get bucket and object information.')
return metadata

if not task_params and not custom_labels:
return metadata
# Maximum size of metadata for each object is 8 KiB.
Expand All @@ -101,23 +107,49 @@ def _get_patched_obj_metadata(
# However, users who utilize custom_labels are no longer expected to search using the labels generated from task parameters.
# Instead, users are expected to search using the labels they provided.
# Therefore, in the event of a key conflict, the value registered by the user-provided labels will take precedence.
_merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_task_params_labels, normalized_custom_labels)
normalized_labels = (
[normalized_custom_labels, normalized_task_params_labels]
if not required_task_outputs
else [
normalized_custom_labels,
normalized_task_params_labels,
{'__required_task_outputs': json.dumps(GCSObjectMetadataClient._get_serialized_string(required_task_outputs))},
]
)
_merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_labels)
return GCSObjectMetadataClient._adjust_gcs_metadata_limit_size(dict(metadata) | _merged_labels)

@staticmethod
def _get_serialized_string(required_task_outputs: FlattenableItems[RequiredTaskOutput]) -> FlattenableItems[str]:
def _iterable_flatten(nested_list: Iterable) -> list[str]:
flattened_list: list[str] = []
for item in nested_list:
if isinstance(item, Iterable):
flattened_list.extend(_iterable_flatten(item))
else:
flattened_list.append(item)
return flattened_list

if isinstance(required_task_outputs, dict):
return {k: GCSObjectMetadataClient._get_serialized_string(v) for k, v in required_task_outputs.items()}
if isinstance(required_task_outputs, Iterable):
return _iterable_flatten([GCSObjectMetadataClient._get_serialized_string(ro) for ro in required_task_outputs])
return [required_task_outputs.serialize()]

@staticmethod
def _merge_custom_labels_and_task_params_labels(
normalized_task_params: dict[str, str],
normalized_custom_labels: dict[str, Any],
normalized_labels_list: list[dict[str, Any]],
) -> dict[str, str]:
merged_labels = copy.deepcopy(normalized_custom_labels)
for label_name, label_value in normalized_task_params.items():
if len(label_value) == 0:
logger.warning(f'value of label_name={label_name} is empty. So skip to add as a metadata.')
continue
if label_name in merged_labels.keys():
logger.warning(f'label_name={label_name} is already seen. So skip to add as a metadata.')
continue
merged_labels[label_name] = label_value
merged_labels: dict[str, str] = {}
for normalized_label in normalized_labels_list[:]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
for normalized_label in normalized_labels_list[:]:
for normalized_label in normalized_labels_list:

for label_name, label_value in normalized_label.items():
if len(label_value) == 0:
logger.warning(f'value of label_name={label_name} is empty. So skip to add as a metadata.')
continue
if label_name in merged_labels.keys():
logger.warning(f'label_name={label_name} is already seen. So skip to add as a metadata.')
continue
merged_labels[label_name] = label_value
return merged_labels

# Google Cloud Storage(GCS) has a limitation of metadata size, 8 KiB.
Expand All @@ -132,10 +164,8 @@ def _get_label_size(label_name: str, label_value: str) -> int:
8 * 1024,
sum(_get_label_size(label_name, label_value) for label_name, label_value in labels.items()),
)

if current_total_metadata_size <= max_gcs_metadata_size:
return labels

for label_name, label_value in reversed(labels.items()):
size = _get_label_size(label_name, label_value)
del labels[label_name]
Expand Down
10 changes: 9 additions & 1 deletion gokart/in_memory/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from typing import Any

from gokart.in_memory.repository import InMemoryCacheRepository
from gokart.required_task_output import RequiredTaskOutput
from gokart.target import TargetOnKart, TaskLockParams
from gokart.utils import FlattenableItems

_repository = InMemoryCacheRepository()

Expand All @@ -26,7 +28,13 @@ def _get_task_lock_params(self) -> TaskLockParams:
def _load(self) -> Any:
    """Return the value cached under this target's data key in the shared in-memory repository."""
    return _repository.get_value(self._data_key)

def _dump(self, obj: Any, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None:
def _dump(
    self,
    obj: Any,
    task_params: dict[str, str] | None = None,
    custom_labels: dict[str, Any] | None = None,
    required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
) -> None:
    """Store *obj* under this target's data key in the shared in-memory repository.

    task_params, custom_labels and required_task_outputs are accepted to
    satisfy the TargetOnKart._dump interface but are ignored by the
    in-memory backend.
    """
    return _repository.set_value(self._data_key, obj)

def _remove(self) -> None:
Expand Down
10 changes: 10 additions & 0 deletions gokart/required_task_output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from dataclasses import dataclass


@dataclass
class RequiredTaskOutput:
    """A (task name, output path) pair recorded as object metadata for dependency search."""

    task_name: str
    output_path: str

    def serialize(self) -> dict[str, str]:
        """Render this record as its metadata-label payload."""
        payload = {
            '__gokart_task_name': self.task_name,
            '__gokart_output_path': self.output_path,
        }
        return payload
50 changes: 42 additions & 8 deletions gokart/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from gokart.file_processor import FileProcessor, make_file_processor
from gokart.gcs_obj_metadata_client import GCSObjectMetadataClient
from gokart.object_storage import ObjectStorage
from gokart.required_task_output import RequiredTaskOutput
from gokart.utils import FlattenableItems
from gokart.zip_client_util import make_zip_client

logger = getLogger(__name__)
Expand All @@ -30,13 +32,23 @@ def exists(self) -> bool:
def load(self) -> Any:
    """Load this target's stored object; the read is routed through wrap_load_with_lock with this target's task-lock params."""
    return wrap_load_with_lock(func=self._load, task_lock_params=self._get_task_lock_params())()

def dump(
    self,
    obj,
    lock_at_dump: bool = True,
    task_params: dict[str, str] | None = None,
    custom_labels: dict[str, Any] | None = None,
    required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
) -> None:
    """Persist *obj* to this target.

    When lock_at_dump is True, the write goes through wrap_dump_with_lock
    (with an existence check) using this target's task-lock params;
    otherwise _dump is called directly. task_params, custom_labels and
    required_task_outputs are forwarded to _dump so storage backends can
    attach them as object metadata (the GCS-backed target does).
    """
    if lock_at_dump:
        wrap_dump_with_lock(func=self._dump, task_lock_params=self._get_task_lock_params(), exist_check=self.exists)(
            obj=obj,
            task_params=task_params,
            custom_labels=custom_labels,
            required_task_outputs=required_task_outputs,
        )
    else:
        self._dump(obj=obj, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs)

def remove(self) -> None:
if self.exists():
Expand All @@ -61,7 +73,13 @@ def _load(self) -> Any:
pass

@abstractmethod
def _dump(
    self,
    obj,
    task_params: dict[str, str] | None = None,
    custom_labels: dict[str, Any] | None = None,
    required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
) -> None:
    """Write *obj* to the backing store; implemented by each concrete target.

    Implementations may use task_params, custom_labels and
    required_task_outputs to attach metadata to the stored object
    (the GCS-backed file target does) or ignore them entirely
    (the in-memory target does).
    """
    pass

@abstractmethod
Expand Down Expand Up @@ -98,11 +116,19 @@ def _load(self) -> Any:
with self._target.open('r') as f:
return self._processor.load(f)

def _dump(
    self,
    obj,
    task_params: dict[str, str] | None = None,
    custom_labels: dict[str, Any] | None = None,
    required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
) -> None:
    """Serialize *obj* through the file processor and, for GCS paths, attach task-state metadata labels."""
    with self._target.open('w') as f:
        self._processor.dump(obj, f)
    # Object-metadata labels are only supported for Google Cloud Storage paths.
    if self.path().startswith('gs://'):
        GCSObjectMetadataClient.add_task_state_labels(
            path=self.path(), task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs
        )

def _remove(self) -> None:
self._target.remove()
Expand Down Expand Up @@ -142,10 +168,18 @@ def _load(self) -> Any:
self._remove_temporary_directory()
return model

def _dump(
    self,
    obj,
    task_params: dict[str, str] | None = None,
    custom_labels: dict[str, Any] | None = None,
    required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
) -> None:
    """Save a model together with its load function as one zip archive.

    The model is written via _save_function into a temporary directory, the
    load function is dumped alongside it (forwarding task_params,
    custom_labels and required_task_outputs so the inner target can attach
    metadata), and the directory is archived before being cleaned up.
    """
    self._make_temporary_directory()
    self._save_function(obj, self._model_path())
    make_target(self._load_function_path()).dump(
        self._load_function, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs
    )
    self._zip_client.make_archive()
    self._remove_temporary_directory()

Expand Down
15 changes: 15 additions & 0 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,21 @@
import inspect
import os
import random
import sys
import types
from collections.abc import Generator, Iterable
from importlib import import_module
from logging import getLogger
from typing import Any, Callable, Generic, TypeVar, overload

from gokart.required_task_output import RequiredTaskOutput
from gokart.utils import map_flattenable_items

# NOTE(review): both branches of this version guard are no-ops — it looks like
# leftover scaffolding for a Python-3.13-conditional import; confirm intent and
# either fill in the branches or delete the guard.
if sys.version_info < (3, 13):
    pass
else:
    pass

import luigi
import pandas as pd
from luigi.parameter import ParameterVisibility
Expand Down Expand Up @@ -337,11 +346,17 @@ def dump(self, obj: Any, target: None | str | TargetOnKart = None, custom_labels
if isinstance(obj, pd.DataFrame) and obj.empty:
raise EmptyDumpError()

required_task_outputs = map_flattenable_items(
lambda task: map_flattenable_items(lambda output: RequiredTaskOutput(task_name=task.get_task_family(), output_path=output.path()), task.output()),
self.requires(),
)

self._get_output_target(target).dump(
obj,
lock_at_dump=self._lock_at_dump,
task_params=super().to_str_params(only_significant=True, only_public=True),
custom_labels=custom_labels,
required_task_outputs=required_task_outputs,
)

@staticmethod
Expand Down
Loading