diff --git a/devtools/conda-recipe/meta.yaml b/devtools/conda-recipe/meta.yaml index b97d318..3f9ec5f 100644 --- a/devtools/conda-recipe/meta.yaml +++ b/devtools/conda-recipe/meta.yaml @@ -33,6 +33,7 @@ test: - nose - nose-timer - gpy + - skorch - msmbuilder - msmb_data - mdtraj diff --git a/osprey/tests/test_cli_worker_and_dump.py b/osprey/tests/test_cli_worker_and_dump.py index c71f105..21416f6 100644 --- a/osprey/tests/test_cli_worker_and_dump.py +++ b/osprey/tests/test_cli_worker_and_dump.py @@ -15,6 +15,13 @@ except: HAVE_MSMBUILDER = False +try: + __import__('skorch') + HAVE_SKORCH = True +except: + HAVE_SKORCH = False + + OSPREY_BIN = find_executable('osprey') @@ -136,6 +143,30 @@ def test_gp_example(): shutil.rmtree(dirname) +@skipif(not HAVE_SKORCH, 'this test requires Skorch') +def test_torch_example(): + assert OSPREY_BIN is not None + cwd = os.path.abspath(os.curdir) + dirname = tempfile.mkdtemp() + + try: + os.chdir(dirname) + subprocess.check_call([OSPREY_BIN, 'skeleton', '-t', 'torch', + '-f', 'config.yaml']) + subprocess.check_call([OSPREY_BIN, 'worker', 'config.yaml', '-n', '1']) + assert os.path.exists('osprey-trials.db') + + subprocess.check_call([OSPREY_BIN, 'current_best', 'config.yaml']) + + yield _test_dump_1 + + yield _test_plot_1 + + finally: + os.chdir(cwd) + shutil.rmtree(dirname) + + def test_grid_example(): assert OSPREY_BIN is not None cwd = os.path.abspath(os.curdir)