Skip to content

Feat: Create InMemoryTarget from TaskOnKart #441

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions gokart/in_memory/target.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any
from typing import Any, Optional

from gokart.in_memory.repository import InMemoryCacheRepository
from gokart.target import TargetOnKart, TaskLockParams
Expand Down Expand Up @@ -41,5 +41,12 @@ def _path(self) -> str:
return self._data_key


# NOTE(review): this page is a diff view (the file header reports 10 additions & 3 deletions).
# This two-line definition takes `target_key` and appears to be the pre-change version; a second
# definition with the same name follows below and would replace this one if both were executed.
# Confirm which side is intended before relying on this signature.
def make_in_memory_target(target_key: str, task_lock_params: TaskLockParams) -> InMemoryTarget:
return InMemoryTarget(target_key, task_lock_params)
def _make_data_key(data_key: str, unique_id: Optional[str] = None) -> str:
    """Build the key string from a base key and an optional unique id.

    Args:
        data_key: Base key.
        unique_id: Optional suffix. Any falsy value (``None`` or ``''``) means "no suffix".

    Returns:
        ``data_key`` unchanged when no suffix is given, otherwise
        ``'<data_key>_<unique_id>'``.
    """
    if not unique_id:
        return data_key
    # str.join keeps the original strictness: non-str parts still raise TypeError.
    return '_'.join((data_key, unique_id))


def make_in_memory_target(data_key: str, task_lock_params: TaskLockParams, unique_id: Optional[str] = None) -> InMemoryTarget:
    """Create an ``InMemoryTarget`` keyed by ``data_key`` and, optionally, ``unique_id``.

    The final key is produced by ``_make_data_key``; ``task_lock_params`` is handed
    to the target unchanged.
    """
    return InMemoryTarget(_make_data_key(data_key, unique_id), task_lock_params)
25 changes: 23 additions & 2 deletions gokart/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@

import gokart
import gokart.target
from gokart.conflict_prevention_lock.task_lock import make_task_lock_params, make_task_lock_params_for_run
from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params, make_task_lock_params_for_run
from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_run_with_lock
from gokart.file_processor import FileProcessor
from gokart.in_memory.target import InMemoryTarget, make_in_memory_target
from gokart.pandas_type_config import PandasTypeConfigMap
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter
from gokart.target import TargetOnKart
Expand Down Expand Up @@ -105,6 +106,9 @@ class TaskOnKart(luigi.Task, Generic[T]):
default=True, description='Check if output file exists at run. If exists, run() will be skipped.', significant=False
)
should_lock_run: bool = ExplicitBoolParameter(default=False, significant=False, description='Whether to use redis lock or not at task run.')
# Opt-in switch for in-memory output. `__init__` reads this flag when binding
# `self.make_default_target`: `make_cache_target` if True, `make_target` otherwise.
# Declared with significant=False and default=False.
cache_in_memory_by_default: bool = ExplicitBoolParameter(
default=False, significant=False, description='If `True`, output is stored on a memory instead of files unless specified.'
)

@property
def priority(self):
Expand Down Expand Up @@ -134,11 +138,13 @@ def __init__(self, *args, **kwargs):
task_lock_params = make_task_lock_params_for_run(task_self=self)
self.run = wrap_run_with_lock(run_func=self.run, task_lock_params=task_lock_params) # type: ignore

self.make_default_target = self.make_target if not self.cache_in_memory_by_default else self.make_cache_target

def input(self) -> FlattenableItems[TargetOnKart]:
    """Return the parent class's ``input()`` result unchanged.

    The override adds only the ``FlattenableItems[TargetOnKart]`` return annotation.
    """
    targets = super().input()
    return targets

def output(self) -> FlattenableItems[TargetOnKart]:
    """Return this task's default output target.

    Delegates to ``self.make_default_target``, which ``__init__`` binds to either
    ``make_target`` or ``make_cache_target`` depending on
    ``cache_in_memory_by_default``. The stale ``return self.make_target()`` that
    preceded this call is removed: it made the default-target selection unreachable.
    """
    return self.make_default_target()

def requires(self) -> FlattenableItems['TaskOnKart']:
tasks = self.make_task_instance_dictionary()
Expand Down Expand Up @@ -229,6 +235,21 @@ def make_target(self, relative_file_path: Optional[str] = None, use_unique_id: b
file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather
)

def make_cache_target(self, data_key: Optional[str] = None, use_unique_id: bool = True) -> InMemoryTarget:
    """Create an in-memory target for this task.

    Args:
        data_key: Base key for the target. When falsy, the key is the task's module
            path (dots replaced by ``/``) joined with its class name.
        use_unique_id: When true, ``self.make_unique_id()`` is passed on as the unique id.

    Returns:
        The ``InMemoryTarget`` built by ``make_in_memory_target``.
    """
    if data_key:
        key = data_key
    else:
        module_path = self.__module__.replace('.', '/')
        key = os.path.join(module_path, type(self).__name__)

    unique_id = None
    if use_unique_id:
        unique_id = self.make_unique_id()

    # TODO: combine with redis
    # Locking is switched off here: no redis connection details and should_task_lock=False.
    lock_params = TaskLockParams(
        redis_host=None,
        redis_port=None,
        redis_timeout=None,
        redis_key='redis_key',
        should_task_lock=False,
        raise_task_lock_exception_on_collision=False,
        lock_extend_seconds=-1,
    )
    return make_in_memory_target(key, lock_params, unique_id)

def make_large_data_frame_target(self, relative_file_path: Optional[str] = None, use_unique_id: bool = True, max_byte=int(2**26)) -> TargetOnKart:
formatted_relative_file_path = (
relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.zip')
Expand Down
2 changes: 1 addition & 1 deletion test/in_memory/test_in_memory_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def task_lock_params(self) -> TaskLockParams:

@pytest.fixture
def target(self, task_lock_params: TaskLockParams) -> InMemoryTarget:
    """Provide an ``InMemoryTarget`` keyed by ``'dummy_key'``.

    Uses the ``data_key`` keyword. The stale call with ``target_key=`` is removed:
    it came first, so it shadowed this one, and the current
    ``make_in_memory_target`` signature has no such parameter.
    """
    return make_in_memory_target(data_key='dummy_key', task_lock_params=task_lock_params)

@pytest.fixture(autouse=True)
def clear_repo(self) -> None:
Expand Down
118 changes: 118 additions & 0 deletions test/in_memory/test_task_cached_in_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from typing import Optional, Type, Union

import luigi
import pytest

import gokart
from gokart.in_memory import InMemoryCacheRepository, InMemoryTarget
from gokart.target import SingleFileTarget


class DummyTask(gokart.TaskOnKart):
    """Minimal task whose output is its own ``param`` value."""

    task_namespace = __name__
    param: str = luigi.Parameter()

    def run(self):
        # The dumped value is exactly the parameter the task was created with.
        value = self.param
        self.dump(value)


class DummyTaskWithDependencies(gokart.TaskOnKart):
    """Task that dumps the comma-joined result of ``self.load()``."""

    task_namespace = __name__
    task: list[gokart.TaskOnKart[str]] = gokart.ListTaskInstanceParameter()

    def run(self):
        # Join the loaded values into one comma-separated string and dump it.
        self.dump(','.join(self.load()))


class DumpIntTask(gokart.TaskOnKart[int]):
    """Task whose output is its integer ``value`` parameter."""

    task_namespace = __name__
    value: int = luigi.IntParameter()

    def run(self):
        number = self.value
        self.dump(number)


class AddTask(gokart.TaskOnKart[Union[int, float]]):
    """Task that loads the outputs of tasks ``a`` and ``b`` and dumps their sum."""

    a: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter()
    b: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter()

    def requires(self):
        # Both operand tasks are required, keyed by their parameter names.
        return {'a': self.a, 'b': self.b}

    def run(self):
        lhs = self.load(self.a)
        rhs = self.load(self.b)
        self.dump(lhs + rhs)


class TestTaskOnKartWithCache:
    """Tests for ``TaskOnKart.make_cache_target`` and the ``cache_in_memory_by_default`` flag."""

    @pytest.fixture(autouse=True)
    def clear_repository(self) -> None:
        # Clear the in-memory cache repository before every test in this class.
        InMemoryCacheRepository().clear()

    @pytest.mark.parametrize('data_key', ['sample_key', None])
    @pytest.mark.parametrize('use_unique_id', [True, False])
    def test_key_identity(self, data_key: Optional[str], use_unique_id: bool):
        """The cache target's path equals the file target's path minus workspace prefix and extension."""
        task = DummyTask(param='param')
        ext = '.pkl'
        relative_file_path = data_key + ext if data_key else None
        target = task.make_target(relative_file_path=relative_file_path, use_unique_id=use_unique_id)
        cached_target = task.make_cache_target(data_key=data_key, use_unique_id=use_unique_id)

        # Strip the workspace directory, the extension and surrounding slashes before comparing.
        target_path = target.path().removeprefix(task.workspace_directory).removesuffix(ext).strip('/')
        assert cached_target.path() == target_path

    def test_make_cached_target(self):
        """``make_cache_target`` returns an ``InMemoryTarget``."""
        task = DummyTask(param='param')
        target = task.make_cache_target()
        assert isinstance(target, InMemoryTarget)

    @pytest.mark.parametrize(['cache_in_memory_by_default', 'target_type'], [[True, InMemoryTarget], [False, SingleFileTarget]])
    def test_make_default_target(self, cache_in_memory_by_default: bool, target_type: Type[Union[InMemoryTarget, SingleFileTarget]]):
        """``output()`` yields the target class selected by ``cache_in_memory_by_default``."""
        # `target_type` is a target class (see the parametrize values), not a task class.
        task = DummyTask(param='param', cache_in_memory_by_default=cache_in_memory_by_default)
        target = task.output()
        assert isinstance(target, target_type)

    def test_complete_with_cache_in_memory_flag(self, tmpdir):
        """With the flag on, only a dump to the cache target makes the task complete."""
        # Pass the temporary directory as a plain string rather than a path object.
        task = DummyTask(param='param', cache_in_memory_by_default=True, workspace_directory=str(tmpdir))
        assert not task.complete()
        file_target = task.make_target()
        file_target.dump('data')
        assert not task.complete()
        cache_target = task.make_cache_target()
        cache_target.dump('data')
        assert task.complete()

    def test_complete_without_cache_in_memory_flag(self, tmpdir):
        """With the flag off, only a dump to the file target makes the task complete."""
        task = DummyTask(param='param', workspace_directory=str(tmpdir))
        assert not task.complete()
        cache_target = task.make_cache_target()
        cache_target.dump('data')
        assert not task.complete()
        file_target = task.make_target()
        file_target.dump('data')
        assert task.complete()

    def test_dump_with_cache_in_memory_flag(self, tmpdir):
        """With the flag on, ``task.dump`` writes to the cache target and not to the file target."""
        task = DummyTask(param='param', cache_in_memory_by_default=True, workspace_directory=str(tmpdir))
        file_target = task.make_target()
        cache_target = task.make_cache_target()
        task.dump('data')
        assert not file_target.exists()
        assert cache_target.exists()

    def test_dump_without_cache_in_memory_flag(self, tmpdir):
        """With the flag off, ``task.dump`` writes to the file target and not to the cache target."""
        task = DummyTask(param='param', workspace_directory=str(tmpdir))
        file_target = task.make_target()
        cache_target = task.make_cache_target()
        task.dump('data')
        assert file_target.exists()
        assert not cache_target.exists()

    def test_gokart_build(self):
        """``gokart.build`` returns the sum when every task in the graph has the flag on."""
        task = AddTask(
            a=DumpIntTask(value=2, cache_in_memory_by_default=True), b=DumpIntTask(value=3, cache_in_memory_by_default=True), cache_in_memory_by_default=True
        )
        output = gokart.build(task, reset_register=False)
        assert output == 5
Loading