Skip to content

Commit cb9f5ad

Browse files
Merge pull request #78 from data-apis/record-returned-value
Record function returned types
2 parents 8694ef8 + 914972c commit cb9f5ad

File tree

5 files changed

+143
-42
lines changed

5 files changed

+143
-42
lines changed

record_api/apis.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@ class Signature(BaseModel):
374374
var_kw: typing.Optional[typing.Tuple[str, Type]] = None
375375

376376
metadata: typing.Dict[str, int] = pydantic.Field(default_factory=dict)
377+
return_type: OutputType = None
377378

378379
@pydantic.validator("pos_only_required")
379380
@classmethod
@@ -408,6 +409,13 @@ def validate_keys_unique(cls, values) -> None:
408409
raise ValueError(repr(all_keys))
409410
return values
410411

412+
@property
413+
def return_type_annotation(self) -> typing.Optional[cst.Annotation]:
414+
return_type_annotation = None
415+
if self.return_type:
416+
return_type_annotation = cst.Annotation(self.return_type.annotation)
417+
return return_type_annotation
418+
411419
def function_def(
412420
self,
413421
name: str,
@@ -427,6 +435,7 @@ def function_def(
427435
[cst.SimpleStatementLine([s]) for s in self.body(indent)]
428436
),
429437
decorators,
438+
self.return_type_annotation
430439
)
431440

432441
def body(self, indent: int) -> typing.Iterable[cst.BaseSmallStatement]:
@@ -508,13 +517,17 @@ def initial_args(self) -> typing.Iterator[Type]:
508517

509518
@classmethod
510519
def from_params(
511-
cls, args: typing.List[object] = [], kwargs: typing.Dict[str, object] = {}
520+
cls,
521+
args: typing.List[object] = [],
522+
kwargs: typing.Dict[str, object] = {},
523+
return_type: typing.Optional[typing.Dict[str, typing.Union[str, typing.Dict]]] = {},
512524
) -> Signature:
513525
# If we don't know what the args/kwargs are, assume the args are positional only
514526
# and the kwargs and keyword only
515527
return cls(
516528
pos_only_required={f"_{i}": create_type(v) for i, v in enumerate(args)},
517529
kw_only_required={k: create_type(v) for k, v in kwargs.items()},
530+
return_type=create_type(return_type) if return_type else None
518531
)
519532

520533
@classmethod
@@ -525,6 +538,7 @@ def from_bound_params(
525538
var_pos: typing.Optional[typing.Tuple[str, typing.List[object]]] = None,
526539
kw_only: typing.Dict[str, object] = {},
527540
var_kw: typing.Optional[typing.Tuple[str, typing.Dict[str, object]]] = None,
541+
return_type: typing.Optional[typing.Dict[str, typing.Union[str, typing.Dict]]] = {},
528542
) -> Signature:
529543
return cls(
530544
pos_only_required={k: create_type(v) for k, v in pos_only},
@@ -538,6 +552,7 @@ def from_bound_params(
538552
if var_kw
539553
else None
540554
),
555+
return_type=create_type(return_type) if return_type else None
541556
)
542557

543558
def content_equal(self, other: Signature) -> bool:
@@ -557,6 +572,7 @@ def __ior__(self, other: Signature) -> Signature:
557572
self._copy_var_pos(other)
558573
self._copy_kw_only(other)
559574
self._copy_var_kw(other)
575+
self._copy_return_type(other)
560576

561577
update_add(self.metadata, other.metadata)
562578
self._trim_positional_only_args()
@@ -718,6 +734,13 @@ def _copy_var_kw(self, other: Signature) -> None:
718734
else (self.var_kw or other.var_kw)
719735
)
720736

737+
def _copy_return_type(self, other: Signature) -> None:
738+
self.return_type = (
739+
unify((self.return_type, other.return_type,))
740+
if self.return_type and other.return_type
741+
else (self.return_type or other.return_type)
742+
)
743+
721744

722745
API.update_forward_refs()
723746
Class.update_forward_refs()

record_api/core.py

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,15 @@
2929
context_manager: Optional[ContextManager] = None
3030
write_line: Optional[Callable[[dict], None]] = None
3131

32+
FUNCTION_CALL_OP_NAMES = {
33+
"CALL_METHOD",
34+
"CALL_FUNCTION",
35+
"CALL_FUNCTION_KW",
36+
"CALL_FUNCTION_EX",
37+
"LOAD_ATTR",
38+
"BINARY_SUBSCR",
39+
}
40+
3241

3342
def get_tracer() -> Tracer:
3443
global TRACER
@@ -273,6 +282,7 @@ def log_call(
273282
fn: Callable,
274283
args: Iterable = (),
275284
kwargs: Mapping[str, Any] = {},
285+
return_type: Any = None,
276286
) -> None:
277287
bound = Bound.create(fn, args, kwargs)
278288
line: Dict = {"location": location, "function": preprocess(fn)}
@@ -284,6 +294,8 @@ def log_call(
284294
line["params"]["kwargs"] = {k: preprocess(v) for k, v in kwargs.items()}
285295
else:
286296
line["bound_params"] = bound.as_dict()
297+
if return_type:
298+
line['return_type'] = return_type
287299
assert write_line
288300
write_line(line)
289301

@@ -295,11 +307,16 @@ class Stack:
295307
NULL: ClassVar[object] = object()
296308
current_i: int = dataclasses.field(init=False, default=0)
297309
opcode: int = dataclasses.field(init=False)
310+
previous_stack: Optional[Stack] = None
311+
log_call_args: Tuple = ()
298312

299313
def __post_init__(self):
300314
self.op_stack = get_stack.OpStack(self.frame)
301315
self.opcode = self.frame.f_code.co_code[self.frame.f_lasti]
302316

317+
if self.previous_stack and self.previous_stack.previous_stack:
318+
del self.previous_stack.previous_stack
319+
303320
@property
304321
def oparg(self):
305322
# sort of replicates logic in dis._unpack_opargs but doesn't account for extended
@@ -360,14 +377,24 @@ def pop_n(self, n: int) -> List:
360377
return l
361378

362379
def process(
363-
self, keyed_args: Tuple, fn: Callable, args: Iterable, kwargs: Mapping = {}
380+
self,
381+
keyed_args: Tuple,
382+
fn: Callable,
383+
args: Iterable,
384+
kwargs: Mapping = {},
385+
delay: bool = False
364386
) -> None:
365-
# Note: This take args as an iterable, instead of as a varargs, so that if we don't trace we don't have to expand the iterable
387+
388+
# Note: This take args as an iterable, instead of as a varargs, so that if
389+
# we don't trace we don't have to expand the iterable
366390
if self.tracer.should_trace(*keyed_args):
367391
filename = self.frame.f_code.co_filename
368392
line = self.frame.f_lineno
369393
# Don't pass kwargs if not used, so we can more easily test mock calls
370-
log_call(f"{filename}:{line}", fn, tuple(args), *((kwargs,) if kwargs else ()))
394+
if not delay:
395+
log_call(f"{filename}:{line}", fn, tuple(args), *((kwargs,) if kwargs else ()))
396+
else:
397+
self.log_call_args = (filename, line, fn, tuple(args), kwargs)
371398

372399
def __call__(self) -> None:
373400
"""
@@ -383,14 +410,34 @@ def __call__(self) -> None:
383410
(self.TOS, self.TOS1), BINARY_OPS[opname], (self.TOS1, self.TOS)
384411
)
385412

413+
if self.previous_stack and self.previous_stack.opname in FUNCTION_CALL_OP_NAMES:
414+
self.log_called_method()
415+
386416
method_name = f"op_{opname}"
387417
if hasattr(self, method_name):
388418
getattr(self, method_name)()
389419
return None
390420

421+
def log_called_method(self):
422+
if self.previous_stack.log_call_args:
423+
tos = self.TOS
424+
if type(tos) is type and issubclass(tos, Exception):
425+
# Don't record exception
426+
return
427+
return_type = type(tos) if type(tos) != type else tos
428+
filename, line, fn, args, *kwargs = self.previous_stack.log_call_args
429+
kwargs = kwargs[0] if kwargs else {}
430+
log_call(
431+
f"{filename}:{line}",
432+
fn,
433+
tuple(args),
434+
*((kwargs,) if kwargs else ()),
435+
return_type=return_type,
436+
)
437+
391438
# special case subscr b/c we only check first arg, not both
392439
def op_BINARY_SUBSCR(self):
393-
self.process((self.TOS1,), op.getitem, (self.TOS1, self.TOS))
440+
self.process((self.TOS1,), op.getitem, (self.TOS1, self.TOS), delay=True)
394441

395442
def op_STORE_SUBSCR(self):
396443
self.process((self.TOS1,), op.setitem, (self.TOS1, self.TOS, self.TOS2))
@@ -399,7 +446,7 @@ def op_DELETE_SUBSCR(self):
399446
self.process((self.TOS1,), op.delitem, (self.TOS1, self.TOS))
400447

401448
def op_LOAD_ATTR(self):
402-
self.process((self.TOS,), getattr, (self.TOS, self.opvalname))
449+
self.process((self.TOS,), getattr, (self.TOS, self.opvalname), delay=True)
403450

404451
def op_STORE_ATTR(self):
405452
self.process((self.TOS,), setattr, (self.TOS, self.opvalname, self.TOS1))
@@ -458,7 +505,7 @@ def op_COMPARE_OP(self):
458505
def op_CALL_FUNCTION(self):
459506
args = self.pop_n(self.oparg)
460507
fn = self.pop()
461-
self.process((fn,), fn, args)
508+
self.process((fn,), fn, args, delay=True)
462509

463510
def op_CALL_FUNCTION_KW(self):
464511
kwargs_keys = self.pop()
@@ -468,7 +515,7 @@ def op_CALL_FUNCTION_KW(self):
468515
args = self.pop_n(self.oparg - n_kwargs)
469516
fn = self.pop()
470517

471-
self.process((fn,), fn, args, kwargs)
518+
self.process((fn,), fn, args, kwargs, delay=True)
472519

473520
def op_CALL_FUNCTION_EX(self):
474521
has_kwarg = self.oparg & int("01", 2)
@@ -482,20 +529,21 @@ def op_CALL_FUNCTION_EX(self):
482529
fn = self.pop()
483530
if inspect.isgenerator(args):
484531
return
485-
self.process((fn,), fn, args, kwargs)
532+
self.process((fn,), fn, args, kwargs, delay=True)
486533

487534
def op_CALL_METHOD(self):
488535
args = self.pop_n(self.oparg)
489536
function_or_self = self.pop()
490537
null_or_method = self.pop()
491538
if null_or_method is self.NULL:
492539
function = function_or_self
493-
self.process((function,), function, args)
540+
self.process((function,), function, args, delay=True)
494541
else:
495542
self_ = function_or_self
496543
method = null_or_method
497544
self.process(
498545
(self_,), method, itertools.chain((self_,), args),
546+
delay=True
499547
)
500548

501549

@@ -548,6 +596,7 @@ class Tracer:
548596
calls_to_modules: List[str]
549597
# the modules we should trace calls from
550598
calls_from_modules: List[str]
599+
previous_stack: Optional[Stack] = None
551600

552601
def __enter__(self):
553602
sys.settrace(self)
@@ -577,7 +626,13 @@ def __call__(self, frame, event, arg) -> Optional[Tracer]:
577626
return None
578627

579628
if self.should_trace_frame(frame):
580-
Stack(self, frame)()
629+
stack = Stack(
630+
self,
631+
frame,
632+
previous_stack=self.previous_stack,
633+
)
634+
stack()
635+
self.previous_stack = stack if stack.log_call_args else None
581636
return None
582637

583638
def should_trace_frame(self, frame) -> bool:

record_api/infer_apis.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,16 @@ def __main__():
4242

4343

4444
def parse_line(
45-
n: int, function: object, params=None, bound_params=None,
45+
n: int,
46+
function: object,
47+
params=None,
48+
bound_params=None,
49+
return_type: typing.Optional[typing.Dict[str, typing.Union[str, typing.Dict]]] = None
4650
) -> typing.Optional[API]:
4751
if bound_params is not None:
48-
signature = Signature.from_bound_params(**bound_params)
52+
signature = Signature.from_bound_params(**bound_params, return_type=return_type)
4953
else:
50-
signature = Signature.from_params(**params)
54+
signature = Signature.from_params(**params, return_type=return_type)
5155
signature.metadata[f"usage.{LABEL}"] = n
5256
return process_function(create_type(function), s=signature)
5357

0 commit comments

Comments
 (0)