|
| 1 | +import ast |
| 2 | +import sys |
| 3 | +from typing import Optional, Type, TypeVar |
| 4 | + |
| 5 | +T = TypeVar('T', covariant=True) |
| 6 | + |
| 7 | + |
| 8 | +def find(nodes: list, tpe: Type[T], attr: str, name: str) -> Optional[T]: |
| 9 | + for n in nodes: |
| 10 | + if type(n) is tpe and getattr(n, attr) == name: |
| 11 | + return n |
| 12 | + return None |
| 13 | + |
| 14 | + |
| 15 | +def check(file: str, predictor: str) -> None: |
| 16 | + with open(file, 'r') as f: |
| 17 | + root = ast.parse(f.read()) |
| 18 | + |
| 19 | + p = find(root.body, ast.ClassDef, 'name', predictor) |
| 20 | + if p is None: |
| 21 | + return |
| 22 | + fn = find(p.body, ast.FunctionDef, 'name', 'predict') |
| 23 | + if fn is None: |
| 24 | + fn = find(p.body, ast.AsyncFunctionDef, 'name', 'predict') # type: ignore |
| 25 | + args_and_defauts = zip(fn.args.args[-len(fn.args.defaults) :], fn.args.defaults) # type: ignore |
| 26 | + for a, d in args_and_defauts: |
| 27 | + if type(a.annotation) is not ast.Name: |
| 28 | + continue |
| 29 | + if type(d) is not ast.Call or d.func.id != 'Input': # type: ignore |
| 30 | + continue |
| 31 | + v = find(d.keywords, ast.keyword, 'arg', 'default') |
| 32 | + if v is None or type(v.value) is not ast.Constant: |
| 33 | + continue |
| 34 | + if v.value.value is None: |
| 35 | + pos = f'{file}:{a.lineno}:{a.col_offset}' |
| 36 | + print( |
| 37 | + f'{pos}: Input(default=None, ...) without Optional type hint is ambiguous and deprecated' |
| 38 | + ) |
| 39 | + src = f'{ast.unparse(a)} = {ast.unparse(d)}' |
| 40 | + d.keywords = [k for k in d.keywords if k.arg != 'default'] |
| 41 | + dst = f'{a.arg}: Optional[{a.annotation.id}] = {ast.unparse(d)}' |
| 42 | + print(f'- {src}') |
| 43 | + print(f'+ {dst}') |
| 44 | + |
| 45 | + |
| 46 | +check(sys.argv[1], sys.argv[2]) |
0 commit comments