diff --git a/examples/logging.ini b/examples/logging.ini new file mode 100644 index 00000000..4710b6ff --- /dev/null +++ b/examples/logging.ini @@ -0,0 +1,40 @@ +[loggers] +keys=root,luigi,luigi-interface,gokart + +[handlers] +keys=stderrHandler + +[formatters] +keys=simpleFormatter + +[logger_root] +level=INFO +handlers=stderrHandler + +[logger_gokart] +level=INFO +handlers=stderrHandler +qualname=gokart +propagate=0 + +[logger_luigi] +level=INFO +handlers=stderrHandler +qualname=luigi +propagate=0 + +[logger_luigi-interface] +level=INFO +handlers=stderrHandler +qualname=luigi-interface +propagate=0 + +[handler_stderrHandler] +class=StreamHandler +formatter=simpleFormatter +args=(sys.stdout,) + +[formatter_simpleFormatter] +format=level=%(levelname)s time=%(asctime)s name=%(name)s file=%(filename)s line=%(lineno)d message=%(message)s +datefmt=%Y/%m/%d %H:%M:%S +class=logging.Formatter diff --git a/examples/param.ini b/examples/param.ini new file mode 100644 index 00000000..24a126ce --- /dev/null +++ b/examples/param.ini @@ -0,0 +1,6 @@ +[TaskOnKart] +workspace_directory=./resource +local_temporary_directory=./resource/tmp + +[core] +logging_conf_file=logging.ini \ No newline at end of file diff --git a/gokart/gcs_obj_metadata_client.py b/gokart/gcs_obj_metadata_client.py index 5b488693..954d5dd8 100644 --- a/gokart/gcs_obj_metadata_client.py +++ b/gokart/gcs_obj_metadata_client.py @@ -1,7 +1,9 @@ from __future__ import annotations import copy +import json import re +from collections.abc import Iterable from logging import getLogger from typing import Any from urllib.parse import urlsplit @@ -9,6 +11,8 @@ from googleapiclient.model import makepatch from gokart.gcs_config import GCSConfig +from gokart.required_task_output import RequiredTaskOutput +from gokart.utils import FlattenableItems logger = getLogger(__name__) @@ -21,7 +25,7 @@ class GCSObjectMetadataClient: @staticmethod def _is_log_related_path(path: str) -> bool: - return 
re.match(r'^log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+', path) is not None + return re.match(r'^gs://.+?/log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+', path) is not None # This is the copied method of luigi.gcs._path_to_bucket_and_key(path). @staticmethod @@ -32,7 +36,12 @@ def _path_to_bucket_and_key(path: str) -> tuple[str, str]: return netloc, path_without_initial_slash @staticmethod - def add_task_state_labels(path: str, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None: + def add_task_state_labels( + path: str, + task_params: dict[str, str] | None = None, + custom_labels: dict[str, Any] | None = None, + required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None, + ) -> None: if GCSObjectMetadataClient._is_log_related_path(path): return # In gokart/object_storage.get_time_stamp, could find same call. @@ -42,20 +51,18 @@ def add_task_state_labels(path: str, task_params: dict[str, str] | None = None, if _response is None: logger.error(f'failed to get object from GCS bucket {bucket} and object {obj}.') return - response: dict[str, Any] = dict(_response) original_metadata: dict[Any, Any] = {} if 'metadata' in response.keys(): _metadata = response.get('metadata') if _metadata is not None: original_metadata = dict(_metadata) - patched_metadata = GCSObjectMetadataClient._get_patched_obj_metadata( copy.deepcopy(original_metadata), task_params, custom_labels, + required_task_outputs if required_task_outputs else None, ) - if original_metadata != patched_metadata: # If we use update api, existing object metadata are removed, so should use patch api. # See the official document descriptions. 
@@ -71,7 +78,6 @@ def add_task_state_labels(path: str, task_params: dict[str, str] | None = None, ) .execute() ) - if update_response is None: logger.error(f'failed to patch object {obj} in bucket {bucket} and object {obj}.') @@ -84,13 +90,13 @@ def _get_patched_obj_metadata( metadata: Any, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None, + required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None, ) -> dict | Any: # If metadata from response when getting bucket and object information is not dictionary, # something wrong might be happened, so return original metadata, no patched. if not isinstance(metadata, dict): logger.warning(f'metadata is not a dict: {metadata}, something wrong was happened when getting response when get bucket and object information.') return metadata - if not task_params and not custom_labels: return metadata # Maximum size of metadata for each object is 8 KiB. @@ -101,23 +107,49 @@ def _get_patched_obj_metadata( # However, users who utilize custom_labels are no longer expected to search using the labels generated from task parameters. # Instead, users are expected to search using the labels they provided. # Therefore, in the event of a key conflict, the value registered by the user-provided labels will take precedence. 
@staticmethod
def _get_serialized_string(required_task_outputs: FlattenableItems[RequiredTaskOutput]) -> FlattenableItems[str]:
    """Serialize a (possibly nested) structure of required task outputs.

    Args:
        required_task_outputs: A single object exposing ``serialize()``, a dict of
            such structures, or any other iterable of such structures.

    Returns:
        * for a dict: a dict with the same keys and recursively serialized values;
        * for any other iterable: one flat list of the serialized leaves
          (nested lists/tuples are flattened, dict results are kept as single items);
        * for a single object: a one-element list holding ``obj.serialize()``.
    """

    def _iterable_flatten(nested_list: Iterable) -> list[Any]:
        flattened_list: list[Any] = []
        for item in nested_list:
            # Descend only into real sequences. The serialized leaves can be dicts, and
            # both a dict and its string keys are ``Iterable``; a one-character string
            # iterates to itself, so a generic ``Iterable`` check never terminates and
            # ends in RecursionError for every non-empty list of outputs.
            if isinstance(item, (list, tuple)):
                flattened_list.extend(_iterable_flatten(item))
            else:
                flattened_list.append(item)
        return flattened_list

    if isinstance(required_task_outputs, dict):
        return {k: GCSObjectMetadataClient._get_serialized_string(v) for k, v in required_task_outputs.items()}
    if isinstance(required_task_outputs, Iterable):
        return _iterable_flatten([GCSObjectMetadataClient._get_serialized_string(ro) for ro in required_task_outputs])
    # Leaf: wrapped in a list so the iterable branch above can flatten uniformly.
    return [required_task_outputs.serialize()]


@staticmethod
def _merge_custom_labels_and_task_params_labels(
    normalized_labels_list: list[dict[str, Any]],
) -> dict[str, str]:
    """Merge several label dicts into one, first occurrence of a key wins.

    The dicts are visited in list order, so labels from earlier dicts take
    precedence over later ones. Labels whose value is empty, and labels whose
    name was already merged, are skipped with a warning.

    Args:
        normalized_labels_list: Label dicts ordered from highest to lowest priority.

    Returns:
        A new dict holding the merged labels; the inputs are not modified.
    """
    merged_labels: dict[str, str] = {}
    for normalized_label in normalized_labels_list:
        for label_name, label_value in normalized_label.items():
            if len(label_value) == 0:
                logger.warning(f'value of label_name={label_name} is empty. So skip to add as a metadata.')
                continue
            if label_name in merged_labels:
                logger.warning(f'label_name={label_name} is already seen. So skip to add as a metadata.')
                continue
            merged_labels[label_name] = label_value
    return merged_labels
@dataclass
class RequiredTaskOutput:
    """Pairs the name of a required task with the path of one of its outputs."""

    # Name of the task that produced the output.
    task_name: str
    # Path of the output produced by that task.
    output_path: str

    def serialize(self) -> dict[str, str]:
        """Return this record as a plain dict keyed by gokart-reserved names."""
        serialized: dict[str, str] = {}
        serialized['__gokart_task_name'] = self.task_name
        serialized['__gokart_output_path'] = self.output_path
        return serialized
task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs) def remove(self) -> None: if self.exists(): @@ -61,7 +73,13 @@ def _load(self) -> Any: pass @abstractmethod - def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None: + def _dump( + self, + obj, + task_params: dict[str, str] | None = None, + custom_labels: dict[str, Any] | None = None, + required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None, + ) -> None: pass @abstractmethod @@ -98,11 +116,19 @@ def _load(self) -> Any: with self._target.open('r') as f: return self._processor.load(f) - def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None: + def _dump( + self, + obj, + task_params: dict[str, str] | None = None, + custom_labels: dict[str, Any] | None = None, + required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None, + ) -> None: with self._target.open('w') as f: self._processor.dump(obj, f) if self.path().startswith('gs://'): - GCSObjectMetadataClient.add_task_state_labels(path=self.path(), task_params=task_params, custom_labels=custom_labels) + GCSObjectMetadataClient.add_task_state_labels( + path=self.path(), task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs + ) def _remove(self) -> None: self._target.remove() @@ -142,10 +168,18 @@ def _load(self) -> Any: self._remove_temporary_directory() return model - def _dump(self, obj, task_params: dict[str, str] | None = None, custom_labels: dict[str, Any] | None = None) -> None: + def _dump( + self, + obj, + task_params: dict[str, str] | None = None, + custom_labels: dict[str, Any] | None = None, + required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None, + ) -> None: self._make_temporary_directory() self._save_function(obj, self._model_path()) - 
make_target(self._load_function_path()).dump(self._load_function, task_params=task_params) + make_target(self._load_function_path()).dump( + self._load_function, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs + ) self._zip_client.make_archive() self._remove_temporary_directory() diff --git a/gokart/task.py b/gokart/task.py index 5a671c2d..a46ac0d4 100644 --- a/gokart/task.py +++ b/gokart/task.py @@ -5,12 +5,21 @@ import inspect import os import random +import sys import types from collections.abc import Generator, Iterable from importlib import import_module from logging import getLogger from typing import Any, Callable, Generic, TypeVar, overload +from gokart.required_task_output import RequiredTaskOutput +from gokart.utils import map_flattenable_items + +if sys.version_info < (3, 13): + pass +else: + pass + import luigi import pandas as pd from luigi.parameter import ParameterVisibility @@ -337,11 +346,17 @@ def dump(self, obj: Any, target: None | str | TargetOnKart = None, custom_labels if isinstance(obj, pd.DataFrame) and obj.empty: raise EmptyDumpError() + required_task_outputs = map_flattenable_items( + lambda task: map_flattenable_items(lambda output: RequiredTaskOutput(task_name=task.get_task_family(), output_path=output.path()), task.output()), + self.requires(), + ) + self._get_output_target(target).dump( obj, lock_at_dump=self._lock_at_dump, task_params=super().to_str_params(only_significant=True, only_public=True), custom_labels=custom_labels, + required_task_outputs=required_task_outputs, ) @staticmethod diff --git a/gokart/utils.py b/gokart/utils.py index df8f53fa..06b84f4c 100644 --- a/gokart/utils.py +++ b/gokart/utils.py @@ -4,7 +4,7 @@ import sys from collections.abc import Iterable from io import BytesIO -from typing import Any, Protocol, TypeVar, Union +from typing import Any, Callable, Protocol, TypeVar, Union import dill import luigi @@ -72,6 +72,21 @@ def flatten(targets: FlattenableItems[T]) 
K = TypeVar('K')


def map_flattenable_items(func: Callable[[T], K], items: FlattenableItems[T]) -> FlattenableItems[K]:
    """Apply ``func`` to every leaf of a nested structure, keeping its shape.

    Strings and non-iterable values are leaves and are passed to ``func``
    directly. Dicts keep their keys, tuples stay tuples, and every other
    iterable is returned as a list; children are mapped recursively.
    """
    # Strings are iterable but are treated as atomic values, like non-iterables.
    if isinstance(items, str) or not isinstance(items, Iterable):
        return func(items)  # type: ignore

    def _apply(child):
        return map_flattenable_items(func, child)

    if isinstance(items, dict):
        return {key: _apply(value) for key, value in items.items()}

    mapped = [_apply(child) for child in items]
    return tuple(mapped) if isinstance(items, tuple) else mapped
str(x), {'a': 1, 'b': 2}), {'a': '1', 'b': '2'}) + self.assertEqual( + map_flattenable_items(lambda x: str(x), (1, 2, 3, (4, 5, (6, 7, {'a': (8, 9, 0)})))), + ('1', '2', '3', ('4', '5', ('6', '7', {'a': ('8', '9', '0')}))), + ) + self.assertEqual( + map_flattenable_items( + lambda x: str(x), + {'a': [1, 2, 3, '4'], 'b': {'c': True, 'd': {'e': 5}}}, + ), + {'a': ['1', '2', '3', '4'], 'b': {'c': 'True', 'd': {'e': '5'}}}, + )