Skip to content

Commit 8ef337e

Browse files
committed
Add check for legacy default=None optional input
1 parent 3ad9e13 commit 8ef337e

File tree

2 files changed

+62
-0
lines changed

2 files changed

+62
-0
lines changed

python/coglet/check.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import ast
2+
import sys
3+
from typing import Optional, Type, TypeVar
4+
5+
T = TypeVar('T', covariant=True)
6+
7+
8+
def find(nodes: list, tpe: Type[T], attr: str, name: str) -> Optional[T]:
9+
for n in nodes:
10+
if type(n) is tpe and getattr(n, attr) == name:
11+
return n
12+
return None
13+
14+
15+
def check(file: str, predictor: str) -> None:
16+
with open(file, 'r') as f:
17+
root = ast.parse(f.read())
18+
19+
p = find(root.body, ast.ClassDef, 'name', predictor)
20+
if p is None:
21+
return
22+
fn = find(p.body, ast.FunctionDef, 'name', 'predict')
23+
if fn is None:
24+
fn = find(p.body, ast.AsyncFunctionDef, 'name', 'predict') # type: ignore
25+
args_and_defauts = zip(fn.args.args[-len(fn.args.defaults) :], fn.args.defaults) # type: ignore
26+
for a, d in args_and_defauts:
27+
if type(a.annotation) is not ast.Name:
28+
continue
29+
if type(d) is not ast.Call or d.func.id != 'Input': # type: ignore
30+
continue
31+
v = find(d.keywords, ast.keyword, 'arg', 'default')
32+
if v is None or type(v.value) is not ast.Constant:
33+
continue
34+
if v.value.value is None:
35+
pos = f'{file}:{a.lineno}:{a.col_offset}'
36+
print(
37+
f'{pos}: Input(default=None, ...) without Optional type hint is ambiguous and deprecated'
38+
)
39+
src = f'{ast.unparse(a)} = {ast.unparse(d)}'
40+
d.keywords = [k for k in d.keywords if k.arg != 'default']
41+
dst = f'{a.arg}: Optional[{a.annotation.id}] = {ast.unparse(d)}'
42+
print(f'- {src}')
43+
print(f'+ {dst}')
44+
45+
46+
check(sys.argv[1], sys.argv[2])

python/tests/cases/legacy_optional.py

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from typing import Optional
2+
3+
from cog import BasePredictor, Input
4+
5+
6+
class Predictor(BasePredictor):
7+
def predict(
8+
self,
9+
s1: str,
10+
s2: Optional[str],
11+
s3: str = Input(),
12+
s4: Optional[str] = Input(),
13+
s5: str = Input(default=None, description='x'),
14+
s6: Optional[str] = Input(default=None),
15+
) -> str:
16+
return f'{s1}:{s2}:{s3}:{s4}:{s5}:{s6}'

0 commit comments

Comments
 (0)