diff --git a/fairseq/progress_bar.py b/fairseq/progress_bar.py index f7a2154cd4..f72136d475 100644 --- a/fairseq/progress_bar.py +++ b/fairseq/progress_bar.py @@ -254,6 +254,13 @@ def print(self, stats, tag=None, step=None): self.tqdm.write('{} | {}'.format(self.tqdm.desc, postfix)) +try: + from tensorboardX import SummaryWriter + _tensorboard_writers = {} +except ImportError: + SummaryWriter = None + + class tensorboard_log_wrapper(progress_bar): """Log to tensorboard.""" @@ -262,27 +269,21 @@ def __init__(self, wrapped_bar, tensorboard_logdir, args): self.tensorboard_logdir = tensorboard_logdir self.args = args - try: - from tensorboardX import SummaryWriter - self.SummaryWriter = SummaryWriter - self._writers = {} - except ImportError: + if SummaryWriter is None: logger.warning( - "tensorboard or required dependencies not found, " - "please see README for using tensorboard. (e.g. pip install tensorboardX)" + "tensorboard or required dependencies not found, please see README " + "for using tensorboard. (e.g. pip install tensorboardX)" ) - self.SummaryWriter = None def _writer(self, key): - if self.SummaryWriter is None: + if SummaryWriter is None: return None - if key not in self._writers: - self._writers[key] = self.SummaryWriter( - os.path.join(self.tensorboard_logdir, key), - ) - self._writers[key].add_text('args', str(vars(self.args))) - self._writers[key].add_text('sys.argv', " ".join(sys.argv)) - return self._writers[key] + _writers = _tensorboard_writers + if key not in _writers: + _writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key)) + _writers[key].add_text('args', str(vars(self.args))) + _writers[key].add_text('sys.argv', " ".join(sys.argv)) + return _writers[key] def __iter__(self): return iter(self.wrapped_bar) @@ -297,11 +298,6 @@ def print(self, stats, tag=None, step=None): self._log_to_tensorboard(stats, tag, step) self.wrapped_bar.print(stats, tag=tag, step=step) - def __exit__(self, *exc): - for writer in getattr(self, '_writers', {}).values(): - writer.close() - return False - def _log_to_tensorboard(self, stats, tag=None, step=None): writer = self._writer(tag or '') if writer is None: