Skip to content

Commit

Permalink
Use global Tensorboard SummaryWriter objects to avoid writing too man…
Browse files Browse the repository at this point in the history
…y files (facebookresearch#1745)

Summary: Pull Request resolved: facebookresearch#1745

Differential Revision: D20097902

Pulled By: myleott

fbshipit-source-id: cad90f0e074ac1e58a03846bb6fdc25703e6c8f8
  • Loading branch information
Myle Ott authored and facebook-github-bot committed Feb 25, 2020
1 parent f1a9ce8 commit d88b91b
Showing 1 changed file with 17 additions and 21 deletions.
38 changes: 17 additions & 21 deletions fairseq/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,13 @@ def print(self, stats, tag=None, step=None):
self.tqdm.write('{} | {}'.format(self.tqdm.desc, postfix))


try:
from tensorboardX import SummaryWriter
_tensorboard_writers = {}
except ImportError:
SummaryWriter = None


class tensorboard_log_wrapper(progress_bar):
"""Log to tensorboard."""

Expand All @@ -262,27 +269,21 @@ def __init__(self, wrapped_bar, tensorboard_logdir, args):
self.tensorboard_logdir = tensorboard_logdir
self.args = args

try:
from tensorboardX import SummaryWriter
self.SummaryWriter = SummaryWriter
self._writers = {}
except ImportError:
if SummaryWriter is None:
logger.warning(
"tensorboard or required dependencies not found, "
"please see README for using tensorboard. (e.g. pip install tensorboardX)"
"tensorboard or required dependencies not found, please see README "
"for using tensorboard. (e.g. pip install tensorboardX)"
)
self.SummaryWriter = None

def _writer(self, key):
if self.SummaryWriter is None:
if SummaryWriter is None:
return None
if key not in self._writers:
self._writers[key] = self.SummaryWriter(
os.path.join(self.tensorboard_logdir, key),
)
self._writers[key].add_text('args', str(vars(self.args)))
self._writers[key].add_text('sys.argv', " ".join(sys.argv))
return self._writers[key]
_writers = _tensorboard_writers
if key not in _writers:
_writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key))
_writers[key].add_text('args', str(vars(self.args)))
_writers[key].add_text('sys.argv', " ".join(sys.argv))
return _writers[key]

def __iter__(self):
return iter(self.wrapped_bar)
Expand All @@ -297,11 +298,6 @@ def print(self, stats, tag=None, step=None):
self._log_to_tensorboard(stats, tag, step)
self.wrapped_bar.print(stats, tag=tag, step=step)

def __exit__(self, *exc):
for writer in getattr(self, '_writers', {}).values():
writer.close()
return False

def _log_to_tensorboard(self, stats, tag=None, step=None):
writer = self._writer(tag or '')
if writer is None:
Expand Down

0 comments on commit d88b91b

Please sign in to comment.