29
29
context_manager : Optional [ContextManager ] = None
30
30
write_line : Optional [Callable [[dict ], None ]] = None
31
31
32
+ FUNCTION_CALL_OP_NAMES = {
33
+ "CALL_METHOD" ,
34
+ "CALL_FUNCTION" ,
35
+ "CALL_FUNCTION_KW" ,
36
+ "CALL_FUNCTION_EX" ,
37
+ "LOAD_ATTR" ,
38
+ "BINARY_SUBSCR" ,
39
+ }
40
+
32
41
33
42
def get_tracer () -> Tracer :
34
43
global TRACER
@@ -273,6 +282,7 @@ def log_call(
273
282
fn : Callable ,
274
283
args : Iterable = (),
275
284
kwargs : Mapping [str , Any ] = {},
285
+ return_type : Any = None ,
276
286
) -> None :
277
287
bound = Bound .create (fn , args , kwargs )
278
288
line : Dict = {"location" : location , "function" : preprocess (fn )}
@@ -284,6 +294,8 @@ def log_call(
284
294
line ["params" ]["kwargs" ] = {k : preprocess (v ) for k , v in kwargs .items ()}
285
295
else :
286
296
line ["bound_params" ] = bound .as_dict ()
297
+ if return_type :
298
+ line ['return_type' ] = return_type
287
299
assert write_line
288
300
write_line (line )
289
301
@@ -295,11 +307,16 @@ class Stack:
295
307
NULL : ClassVar [object ] = object ()
296
308
current_i : int = dataclasses .field (init = False , default = 0 )
297
309
opcode : int = dataclasses .field (init = False )
310
+ previous_stack : Optional [Stack ] = None
311
+ log_call_args : Tuple = ()
298
312
299
313
def __post_init__ (self ):
300
314
self .op_stack = get_stack .OpStack (self .frame )
301
315
self .opcode = self .frame .f_code .co_code [self .frame .f_lasti ]
302
316
317
+ if self .previous_stack and self .previous_stack .previous_stack :
318
+ del self .previous_stack .previous_stack
319
+
303
320
@property
304
321
def oparg (self ):
305
322
# sort of replicates logic in dis._unpack_opargs but doesn't account for extended
@@ -360,14 +377,24 @@ def pop_n(self, n: int) -> List:
360
377
return l
361
378
362
379
def process (
363
- self , keyed_args : Tuple , fn : Callable , args : Iterable , kwargs : Mapping = {}
380
+ self ,
381
+ keyed_args : Tuple ,
382
+ fn : Callable ,
383
+ args : Iterable ,
384
+ kwargs : Mapping = {},
385
+ delay : bool = False
364
386
) -> None :
365
- # Note: This take args as an iterable, instead of as a varargs, so that if we don't trace we don't have to expand the iterable
387
+
388
+ # Note: This take args as an iterable, instead of as a varargs, so that if
389
+ # we don't trace we don't have to expand the iterable
366
390
if self .tracer .should_trace (* keyed_args ):
367
391
filename = self .frame .f_code .co_filename
368
392
line = self .frame .f_lineno
369
393
# Don't pass kwargs if not used, so we can more easily test mock calls
370
- log_call (f"{ filename } :{ line } " , fn , tuple (args ), * ((kwargs ,) if kwargs else ()))
394
+ if not delay :
395
+ log_call (f"{ filename } :{ line } " , fn , tuple (args ), * ((kwargs ,) if kwargs else ()))
396
+ else :
397
+ self .log_call_args = (filename , line , fn , tuple (args ), kwargs )
371
398
372
399
def __call__ (self ) -> None :
373
400
"""
@@ -383,14 +410,34 @@ def __call__(self) -> None:
383
410
(self .TOS , self .TOS1 ), BINARY_OPS [opname ], (self .TOS1 , self .TOS )
384
411
)
385
412
413
+ if self .previous_stack and self .previous_stack .opname in FUNCTION_CALL_OP_NAMES :
414
+ self .log_called_method ()
415
+
386
416
method_name = f"op_{ opname } "
387
417
if hasattr (self , method_name ):
388
418
getattr (self , method_name )()
389
419
return None
390
420
421
+ def log_called_method (self ):
422
+ if self .previous_stack .log_call_args :
423
+ tos = self .TOS
424
+ if type (tos ) is type and issubclass (tos , Exception ):
425
+ # Don't record exception
426
+ return
427
+ return_type = type (tos ) if type (tos ) != type else tos
428
+ filename , line , fn , args , * kwargs = self .previous_stack .log_call_args
429
+ kwargs = kwargs [0 ] if kwargs else {}
430
+ log_call (
431
+ f"{ filename } :{ line } " ,
432
+ fn ,
433
+ tuple (args ),
434
+ * ((kwargs ,) if kwargs else ()),
435
+ return_type = return_type ,
436
+ )
437
+
391
438
# special case subscr b/c we only check first arg, not both
392
439
def op_BINARY_SUBSCR (self ):
393
- self .process ((self .TOS1 ,), op .getitem , (self .TOS1 , self .TOS ))
440
+ self .process ((self .TOS1 ,), op .getitem , (self .TOS1 , self .TOS ), delay = True )
394
441
395
442
def op_STORE_SUBSCR (self ):
396
443
self .process ((self .TOS1 ,), op .setitem , (self .TOS1 , self .TOS , self .TOS2 ))
@@ -399,7 +446,7 @@ def op_DELETE_SUBSCR(self):
399
446
self .process ((self .TOS1 ,), op .delitem , (self .TOS1 , self .TOS ))
400
447
401
448
def op_LOAD_ATTR (self ):
402
- self .process ((self .TOS ,), getattr , (self .TOS , self .opvalname ))
449
+ self .process ((self .TOS ,), getattr , (self .TOS , self .opvalname ), delay = True )
403
450
404
451
def op_STORE_ATTR (self ):
405
452
self .process ((self .TOS ,), setattr , (self .TOS , self .opvalname , self .TOS1 ))
@@ -458,7 +505,7 @@ def op_COMPARE_OP(self):
458
505
def op_CALL_FUNCTION (self ):
459
506
args = self .pop_n (self .oparg )
460
507
fn = self .pop ()
461
- self .process ((fn ,), fn , args )
508
+ self .process ((fn ,), fn , args , delay = True )
462
509
463
510
def op_CALL_FUNCTION_KW (self ):
464
511
kwargs_keys = self .pop ()
@@ -468,7 +515,7 @@ def op_CALL_FUNCTION_KW(self):
468
515
args = self .pop_n (self .oparg - n_kwargs )
469
516
fn = self .pop ()
470
517
471
- self .process ((fn ,), fn , args , kwargs )
518
+ self .process ((fn ,), fn , args , kwargs , delay = True )
472
519
473
520
def op_CALL_FUNCTION_EX (self ):
474
521
has_kwarg = self .oparg & int ("01" , 2 )
@@ -482,20 +529,21 @@ def op_CALL_FUNCTION_EX(self):
482
529
fn = self .pop ()
483
530
if inspect .isgenerator (args ):
484
531
return
485
- self .process ((fn ,), fn , args , kwargs )
532
+ self .process ((fn ,), fn , args , kwargs , delay = True )
486
533
487
534
def op_CALL_METHOD (self ):
488
535
args = self .pop_n (self .oparg )
489
536
function_or_self = self .pop ()
490
537
null_or_method = self .pop ()
491
538
if null_or_method is self .NULL :
492
539
function = function_or_self
493
- self .process ((function ,), function , args )
540
+ self .process ((function ,), function , args , delay = True )
494
541
else :
495
542
self_ = function_or_self
496
543
method = null_or_method
497
544
self .process (
498
545
(self_ ,), method , itertools .chain ((self_ ,), args ),
546
+ delay = True
499
547
)
500
548
501
549
@@ -548,6 +596,7 @@ class Tracer:
548
596
calls_to_modules : List [str ]
549
597
# the modules we should trace calls from
550
598
calls_from_modules : List [str ]
599
+ previous_stack : Optional [Stack ] = None
551
600
552
601
def __enter__ (self ):
553
602
sys .settrace (self )
@@ -577,7 +626,13 @@ def __call__(self, frame, event, arg) -> Optional[Tracer]:
577
626
return None
578
627
579
628
if self .should_trace_frame (frame ):
580
- Stack (self , frame )()
629
+ stack = Stack (
630
+ self ,
631
+ frame ,
632
+ previous_stack = self .previous_stack ,
633
+ )
634
+ stack ()
635
+ self .previous_stack = stack if stack .log_call_args else None
581
636
return None
582
637
583
638
def should_trace_frame (self , frame ) -> bool :
0 commit comments