diff --git a/torchx/runner/events/__init__.py b/torchx/runner/events/__init__.py index 8fab92a10..a61f84464 100644 --- a/torchx/runner/events/__init__.py +++ b/torchx/runner/events/__init__.py @@ -90,6 +90,7 @@ def __init__( app_metadata: Optional[Dict[str, str]] = None, runcfg: Optional[str] = None, workspace: Optional[str] = None, + log_full_trace_on_error: bool = False, ) -> None: self._torchx_event: TorchxEvent = self._generate_torchx_event( api, @@ -103,6 +104,7 @@ def __init__( self._start_cpu_time_ns = 0 self._start_wall_time_ns = 0 self._start_epoch_time_usec = 0 + self.log_full_trace_on_error = log_full_trace_on_error def __enter__(self) -> "log_event": self._start_cpu_time_ns = time.process_time_ns() @@ -125,15 +127,23 @@ def __exit__( ) // 1000 if traceback_type: self._torchx_event.raw_exception = traceback.format_exc() + typ, value, tb = sys.exc_info() if tb: last_frame = traceback.extract_tb(tb)[-1] + + exception_info = { + "filename": last_frame.filename, + "lineno": last_frame.lineno, + "name": last_frame.name, + } + if self.log_full_trace_on_error: + frames = traceback.extract_stack()[:-1] + exception_info["stacktrace"] = "".join( + traceback.format_list(frames) + ) self._torchx_event.exception_source_location = json.dumps( - { - "filename": last_frame.filename, - "lineno": last_frame.lineno, - "name": last_frame.name, - } + exception_info ) if exec_type: self._torchx_event.exception_type = exec_type.__name__