Skip to content

Commit

Permalink
Merge pull request #78 from data-apis/record-returned-value
Browse files Browse the repository at this point in the history
Record function returned types
  • Loading branch information
saulshanabrook authored Oct 5, 2020
2 parents 8694ef8 + 914972c commit cb9f5ad
Show file tree
Hide file tree
Showing 5 changed files with 143 additions and 42 deletions.
25 changes: 24 additions & 1 deletion record_api/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ class Signature(BaseModel):
var_kw: typing.Optional[typing.Tuple[str, Type]] = None

metadata: typing.Dict[str, int] = pydantic.Field(default_factory=dict)
return_type: OutputType = None

@pydantic.validator("pos_only_required")
@classmethod
Expand Down Expand Up @@ -408,6 +409,13 @@ def validate_keys_unique(cls, values) -> None:
raise ValueError(repr(all_keys))
return values

@property
def return_type_annotation(self) -> typing.Optional[cst.Annotation]:
return_type_annotation = None
if self.return_type:
return_type_annotation = cst.Annotation(self.return_type.annotation)
return return_type_annotation

def function_def(
self,
name: str,
Expand All @@ -427,6 +435,7 @@ def function_def(
[cst.SimpleStatementLine([s]) for s in self.body(indent)]
),
decorators,
self.return_type_annotation
)

def body(self, indent: int) -> typing.Iterable[cst.BaseSmallStatement]:
Expand Down Expand Up @@ -508,13 +517,17 @@ def initial_args(self) -> typing.Iterator[Type]:

@classmethod
def from_params(
cls, args: typing.List[object] = [], kwargs: typing.Dict[str, object] = {}
cls,
args: typing.List[object] = [],
kwargs: typing.Dict[str, object] = {},
return_type: typing.Optional[typing.Dict[str, typing.Union[str, typing.Dict]]] = {},
) -> Signature:
# If we don't know what the args/kwargs are, assume the args are positional only
# and the kwargs and keyword only
return cls(
pos_only_required={f"_{i}": create_type(v) for i, v in enumerate(args)},
kw_only_required={k: create_type(v) for k, v in kwargs.items()},
return_type=create_type(return_type) if return_type else None
)

@classmethod
Expand All @@ -525,6 +538,7 @@ def from_bound_params(
var_pos: typing.Optional[typing.Tuple[str, typing.List[object]]] = None,
kw_only: typing.Dict[str, object] = {},
var_kw: typing.Optional[typing.Tuple[str, typing.Dict[str, object]]] = None,
return_type: typing.Optional[typing.Dict[str, typing.Union[str, typing.Dict]]] = {},
) -> Signature:
return cls(
pos_only_required={k: create_type(v) for k, v in pos_only},
Expand All @@ -538,6 +552,7 @@ def from_bound_params(
if var_kw
else None
),
return_type=create_type(return_type) if return_type else None
)

def content_equal(self, other: Signature) -> bool:
Expand All @@ -557,6 +572,7 @@ def __ior__(self, other: Signature) -> Signature:
self._copy_var_pos(other)
self._copy_kw_only(other)
self._copy_var_kw(other)
self._copy_return_type(other)

update_add(self.metadata, other.metadata)
self._trim_positional_only_args()
Expand Down Expand Up @@ -718,6 +734,13 @@ def _copy_var_kw(self, other: Signature) -> None:
else (self.var_kw or other.var_kw)
)

def _copy_return_type(self, other: Signature) -> None:
self.return_type = (
unify((self.return_type, other.return_type,))
if self.return_type and other.return_type
else (self.return_type or other.return_type)
)


API.update_forward_refs()
Class.update_forward_refs()
Expand Down
75 changes: 65 additions & 10 deletions record_api/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@
context_manager: Optional[ContextManager] = None
write_line: Optional[Callable[[dict], None]] = None

FUNCTION_CALL_OP_NAMES = {
"CALL_METHOD",
"CALL_FUNCTION",
"CALL_FUNCTION_KW",
"CALL_FUNCTION_EX",
"LOAD_ATTR",
"BINARY_SUBSCR",
}


def get_tracer() -> Tracer:
global TRACER
Expand Down Expand Up @@ -273,6 +282,7 @@ def log_call(
fn: Callable,
args: Iterable = (),
kwargs: Mapping[str, Any] = {},
return_type: Any = None,
) -> None:
bound = Bound.create(fn, args, kwargs)
line: Dict = {"location": location, "function": preprocess(fn)}
Expand All @@ -284,6 +294,8 @@ def log_call(
line["params"]["kwargs"] = {k: preprocess(v) for k, v in kwargs.items()}
else:
line["bound_params"] = bound.as_dict()
if return_type:
line['return_type'] = return_type
assert write_line
write_line(line)

Expand All @@ -295,11 +307,16 @@ class Stack:
NULL: ClassVar[object] = object()
current_i: int = dataclasses.field(init=False, default=0)
opcode: int = dataclasses.field(init=False)
previous_stack: Optional[Stack] = None
log_call_args: Tuple = ()

def __post_init__(self):
self.op_stack = get_stack.OpStack(self.frame)
self.opcode = self.frame.f_code.co_code[self.frame.f_lasti]

if self.previous_stack and self.previous_stack.previous_stack:
del self.previous_stack.previous_stack

@property
def oparg(self):
# sort of replicates logic in dis._unpack_opargs but doesn't account for extended
Expand Down Expand Up @@ -360,14 +377,24 @@ def pop_n(self, n: int) -> List:
return l

def process(
self, keyed_args: Tuple, fn: Callable, args: Iterable, kwargs: Mapping = {}
self,
keyed_args: Tuple,
fn: Callable,
args: Iterable,
kwargs: Mapping = {},
delay: bool = False
) -> None:
# Note: This take args as an iterable, instead of as a varargs, so that if we don't trace we don't have to expand the iterable

# Note: This take args as an iterable, instead of as a varargs, so that if
# we don't trace we don't have to expand the iterable
if self.tracer.should_trace(*keyed_args):
filename = self.frame.f_code.co_filename
line = self.frame.f_lineno
# Don't pass kwargs if not used, so we can more easily test mock calls
log_call(f"{filename}:{line}", fn, tuple(args), *((kwargs,) if kwargs else ()))
if not delay:
log_call(f"{filename}:{line}", fn, tuple(args), *((kwargs,) if kwargs else ()))
else:
self.log_call_args = (filename, line, fn, tuple(args), kwargs)

def __call__(self) -> None:
"""
Expand All @@ -383,14 +410,34 @@ def __call__(self) -> None:
(self.TOS, self.TOS1), BINARY_OPS[opname], (self.TOS1, self.TOS)
)

if self.previous_stack and self.previous_stack.opname in FUNCTION_CALL_OP_NAMES:
self.log_called_method()

method_name = f"op_{opname}"
if hasattr(self, method_name):
getattr(self, method_name)()
return None

def log_called_method(self):
if self.previous_stack.log_call_args:
tos = self.TOS
if type(tos) is type and issubclass(tos, Exception):
# Don't record exception
return
return_type = type(tos) if type(tos) != type else tos
filename, line, fn, args, *kwargs = self.previous_stack.log_call_args
kwargs = kwargs[0] if kwargs else {}
log_call(
f"{filename}:{line}",
fn,
tuple(args),
*((kwargs,) if kwargs else ()),
return_type=return_type,
)

# special case subscr b/c we only check first arg, not both
def op_BINARY_SUBSCR(self):
self.process((self.TOS1,), op.getitem, (self.TOS1, self.TOS))
self.process((self.TOS1,), op.getitem, (self.TOS1, self.TOS), delay=True)

def op_STORE_SUBSCR(self):
self.process((self.TOS1,), op.setitem, (self.TOS1, self.TOS, self.TOS2))
Expand All @@ -399,7 +446,7 @@ def op_DELETE_SUBSCR(self):
self.process((self.TOS1,), op.delitem, (self.TOS1, self.TOS))

def op_LOAD_ATTR(self):
self.process((self.TOS,), getattr, (self.TOS, self.opvalname))
self.process((self.TOS,), getattr, (self.TOS, self.opvalname), delay=True)

def op_STORE_ATTR(self):
self.process((self.TOS,), setattr, (self.TOS, self.opvalname, self.TOS1))
Expand Down Expand Up @@ -458,7 +505,7 @@ def op_COMPARE_OP(self):
def op_CALL_FUNCTION(self):
args = self.pop_n(self.oparg)
fn = self.pop()
self.process((fn,), fn, args)
self.process((fn,), fn, args, delay=True)

def op_CALL_FUNCTION_KW(self):
kwargs_keys = self.pop()
Expand All @@ -468,7 +515,7 @@ def op_CALL_FUNCTION_KW(self):
args = self.pop_n(self.oparg - n_kwargs)
fn = self.pop()

self.process((fn,), fn, args, kwargs)
self.process((fn,), fn, args, kwargs, delay=True)

def op_CALL_FUNCTION_EX(self):
has_kwarg = self.oparg & int("01", 2)
Expand All @@ -482,20 +529,21 @@ def op_CALL_FUNCTION_EX(self):
fn = self.pop()
if inspect.isgenerator(args):
return
self.process((fn,), fn, args, kwargs)
self.process((fn,), fn, args, kwargs, delay=True)

def op_CALL_METHOD(self):
args = self.pop_n(self.oparg)
function_or_self = self.pop()
null_or_method = self.pop()
if null_or_method is self.NULL:
function = function_or_self
self.process((function,), function, args)
self.process((function,), function, args, delay=True)
else:
self_ = function_or_self
method = null_or_method
self.process(
(self_,), method, itertools.chain((self_,), args),
delay=True
)


Expand Down Expand Up @@ -548,6 +596,7 @@ class Tracer:
calls_to_modules: List[str]
# the modules we should trace calls from
calls_from_modules: List[str]
previous_stack: Optional[Stack] = None

def __enter__(self):
sys.settrace(self)
Expand Down Expand Up @@ -577,7 +626,13 @@ def __call__(self, frame, event, arg) -> Optional[Tracer]:
return None

if self.should_trace_frame(frame):
Stack(self, frame)()
stack = Stack(
self,
frame,
previous_stack=self.previous_stack,
)
stack()
self.previous_stack = stack if stack.log_call_args else None
return None

def should_trace_frame(self, frame) -> bool:
Expand Down
10 changes: 7 additions & 3 deletions record_api/infer_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,16 @@ def __main__():


def parse_line(
n: int, function: object, params=None, bound_params=None,
n: int,
function: object,
params=None,
bound_params=None,
return_type: typing.Optional[typing.Dict[str, typing.Union[str, typing.Dict]]] = None
) -> typing.Optional[API]:
if bound_params is not None:
signature = Signature.from_bound_params(**bound_params)
signature = Signature.from_bound_params(**bound_params, return_type=return_type)
else:
signature = Signature.from_params(**params)
signature = Signature.from_params(**params, return_type=return_type)
signature.metadata[f"usage.{LABEL}"] = n
return process_function(create_type(function), s=signature)

Expand Down
Loading

0 comments on commit cb9f5ad

Please sign in to comment.