diff --git a/mag/experiment.py b/mag/experiment.py index 1ec66e1..63dfd0e 100644 --- a/mag/experiment.py +++ b/mag/experiment.py @@ -27,7 +27,7 @@ class Experiment: def __init__(self, config=None, resume_from=None, logfile_name="log", experiments_dir="./experiments", - implicit_resuming=False): + implicit_resuming=False, experiment_name=None): """Create a new Experiment instance. Args: @@ -42,8 +42,12 @@ def __init__(self, config=None, resume_from=None, experiments_dir: str, a path where experiment will be saved implicit_resuming: bool, whether to allow resuming even if experiment already exists + experiment_name: str + a custom experiment name that is used instead of one + generated from config parameters """ + self._custom_experiment_name = experiment_name self.experiments_dir = experiments_dir self.logfile_name = logfile_name @@ -139,7 +143,8 @@ def _infer_experiments_dir(self, experiment_directory): @property def experiment_dir(self): - return os.path.join(self.experiments_dir, self.config.identifier) + experiment_name = self._custom_experiment_name or self.config.identifier + return os.path.join(self.experiments_dir, experiment_name) @property def config_file(self): diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 953b481..3b28648 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -47,6 +47,20 @@ def test_experiment_initialization(nested_dict_config, tmpdir): assert config.to_dict() == nested_dict_config +def test_experiment_initialization_with_custom_name(nested_dict_config, tmpdir): + + experiments_dir = tmpdir.join("experiments").strpath + + experiment = Experiment( + nested_dict_config, + experiments_dir=experiments_dir, + experiment_name="custom" + ) + + assert os.path.isdir(os.path.join(experiments_dir, "custom")) + assert os.path.isfile(os.path.join(experiments_dir, "custom", "config.json")) + + def test_experiment_restoration(nested_dict_config, tmpdir): experiments_dir = tmpdir.join("experiments").strpath