Skip to content

Commit

Permalink
chore: add tests for gokart parameters (#389)
Browse files Browse the repository at this point in the history
* chore: add tests for gokart parameters

* fix: handle parameters of `gokart.TaskOnKart`
  • Loading branch information
hiro-o918 authored Aug 1, 2024
1 parent 222c40a commit 418f36b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
18 changes: 12 additions & 6 deletions gokart/mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,11 @@

class TaskOnKartPlugin(Plugin):
def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
if 'gokart.task.luigi.Task' in fullname:
# gather attibutes from gokart.TaskOnKart
# the transformation does not affect because the class has `__init__` method
# The following gathers attributes from gokart.TaskOnKart such as `workspace_directory`
# the transformation does not affect because the class has `__init__` method of `gokart.TaskOnKart`.
#
# NOTE: `gokart.task.luigi.Task` condition is required for the release of luigi versions without py.typed
if fullname in {'gokart.task.luigi.Task', 'luigi.task.Task'}:
return self._task_on_kart_class_maker_callback

sym = self.lookup_fully_qualified(fullname)
Expand Down Expand Up @@ -209,7 +211,6 @@ def transform(self) -> bool:
if ('__init__' not in info.names or info.names['__init__'].plugin_generated) and attributes:
args = [attr.to_argument(info, of='__init__') for attr in attributes]
add_method_to_class(self._api, self._cls, '__init__', args=args, return_type=NoneType())

info.metadata[METADATA_TAG] = {
'attributes': [attr.serialize() for attr in attributes],
}
Expand Down Expand Up @@ -330,6 +331,7 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]:
info=cls.info,
api=self._api,
)

return list(found_attrs.values())

def _collect_parameter_args(self, expr: Expression) -> tuple[bool, dict[str, Expression]]:
Expand Down Expand Up @@ -404,9 +406,13 @@ def is_parameter_call(expr: Expression) -> bool:
type_info = callee.node
if type_info is None and isinstance(callee.expr, NameExpr):
return PARAMETER_FULLNAME_MATCHER.match(f'{callee.expr.name}.{callee.name}') is not None
if isinstance(type_info, TypeInfo) and PARAMETER_FULLNAME_MATCHER.match(type_info.fullname):
return True
elif isinstance(callee, NameExpr):
type_info = callee.node
else:
return False

if isinstance(type_info, TypeInfo):
return PARAMETER_FULLNAME_MATCHER.match(type_info.fullname) is not None
return False


Expand Down
15 changes: 12 additions & 3 deletions test/test_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@ class MyTask(gokart.TaskOnKart):
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
foo: int = luigi.IntParameter() # type: ignore
bar: str = luigi.Parameter() # type: ignore
baz: bool = gokart.ExplicitBoolParameter()
MyTask(foo=1, bar='bar')
# TaskOnKart parameters:
# - `complete_check_at_run`
MyTask(foo=1, bar='bar', baz=False, complete_check_at_run=False)
"""

with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
Expand All @@ -37,15 +41,20 @@ class MyTask(gokart.TaskOnKart):
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
foo: int = luigi.IntParameter() # type: ignore
bar: str = luigi.Parameter() # type: ignore
baz: bool = gokart.ExplicitBoolParameter()
# issue: foo is int
# not issue: bar is missing, because it can be set by config file.
MyTask(foo='1')
# TaskOnKart parameters:
# - `complete_check_at_run`
MyTask(foo='1', baz='not bool', complete_check_at_run='not bool')
"""

with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
test_file.write(test_code.encode('utf-8'))
test_file.flush()
result = api.run(['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
self.assertIn('error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', result[0])
self.assertIn('Found 1 error in 1 file (checked 1 source file)', result[0])
self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0])
self.assertIn('error: Argument "complete_check_at_run" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', result[0])
self.assertIn('Found 3 errors in 1 file (checked 1 source file)', result[0])

0 comments on commit 418f36b

Please sign in to comment.