Skip to content

Commit

Permalink
Allow for custom experiment names
Browse files Browse the repository at this point in the history
  • Loading branch information
ex4sperans committed Nov 12, 2020
1 parent 6f88482 commit cfb64d6
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
9 changes: 7 additions & 2 deletions mag/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Experiment:

def __init__(self, config=None, resume_from=None,
logfile_name="log", experiments_dir="./experiments",
implicit_resuming=False):
implicit_resuming=False, experiment_name=None):
"""Create a new Experiment instance.
Args:
Expand All @@ -42,8 +42,12 @@ def __init__(self, config=None, resume_from=None,
experiments_dir: str, a path where experiment will be saved
implicit_resuming: bool, whether to allow resuming
even if experiment already exists
experiment_name: str
a custom experiment name that is used instead of one
generated from config parameters
"""

self._custom_experiment_name = experiment_name
self.experiments_dir = experiments_dir
self.logfile_name = logfile_name

Expand Down Expand Up @@ -139,7 +143,8 @@ def _infer_experiments_dir(self, experiment_directory):

@property
def experiment_dir(self):
return os.path.join(self.experiments_dir, self.config.identifier)
experiment_name = self._custom_experiment_name or self.config.identifier
return os.path.join(self.experiments_dir, experiment_name)

@property
def config_file(self):
Expand Down
14 changes: 14 additions & 0 deletions tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,20 @@ def test_experiment_initialization(nested_dict_config, tmpdir):
assert config.to_dict() == nested_dict_config


def test_experiment_initialization_with_custom_name(nested_dict_config, tmpdir):

experiments_dir = tmpdir.join("experiments").strpath

experiment = Experiment(
nested_dict_config,
experiments_dir=experiments_dir,
experiment_name="custom"
)

assert os.path.isdir(os.path.join(experiments_dir, "custom"))
assert os.path.isfile(os.path.join(experiments_dir, "custom", "config.json"))


def test_experiment_restoration(nested_dict_config, tmpdir):

experiments_dir = tmpdir.join("experiments").strpath
Expand Down

0 comments on commit cfb64d6

Please sign in to comment.