Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add skorch compat #244

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions devtools/conda-recipe/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ test:
- nose
- nose-timer
- gpy
- skorch
- pytorch
- msmbuilder
- msmb_data
- mdtraj
Expand Down
3 changes: 1 addition & 2 deletions devtools/travis-ci/build_docs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,17 @@ conda create --yes -n docenv python=$CONDA_PY
source activate docenv
conda install -yq --use-local osprey


# Install doc requirements
conda install --yes --file docs/requirements.txt


# We don't use conda for these:
# sphinx_rtd_theme's latest releases are not available
# neither is msmb_theme
# neither is sphinx > 1.3.1 (fix #1892 autodoc problem)
pip install -I sphinx==1.3.5
pip install -I sphinx_rtd_theme==0.1.9 msmb_theme==1.2.0


# Make docs
cd docs && make html && cd -

Expand Down
38 changes: 38 additions & 0 deletions osprey/data/torch_skeleton_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
estimator:
eval: Pipeline([
('scale', RobustScaler()),
('classifier', NeuralNetClassifier(nn.Sequential(nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, 10),
nn.Softmax(dim=1)),
max_epochs=10)),
])
eval_scope: ['sklearn', 'torch']

scoring: accuracy

strategy:
name: gp
params:
seeds: 5

search_space:
classifier__lr:
min: 1e-3
max: 1e-1
num: 10
type: jump
var_type: float
warp: log

cv: 5

dataset_loader:
name: sklearn_dataset
params:
method: load_digits

trials:
uri: sqlite:///osprey-trials.db

random_seed: 42
16 changes: 15 additions & 1 deletion osprey/eval_scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from sklearn.base import BaseEstimator

__all__ = ['msmbuilder', 'import_all_estimators']
__all__ = ['msmbuilder', 'torch', 'import_all_estimators']


def msmbuilder():
Expand All @@ -21,6 +21,20 @@ def msmbuilder():
return scope


def torch():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
import torch
from torch import nn
import skorch
from sklearn.pipeline import Pipeline

scope = import_all_estimators(skorch)
scope.update({'nn': nn})
scope['Pipeline'] = Pipeline
return scope


def import_all_estimators(pkg):
def estimator_in_module(mod):
for name, obj in inspect.getmembers(mod):
Expand Down
3 changes: 2 additions & 1 deletion osprey/execute_skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
'random_example': 'random_example.yaml',
'bayes_example': 'sklearn_skeleton_config.yaml',
'grid_example': 'grid_example.yaml',
'msmb_feat_select': 'msmb_feat_select_skeleton_config.yaml'}
'msmb_feat_select': 'msmb_feat_select_skeleton_config.yaml',
'torch': 'torch_skeleton_config.yaml'}


def execute(args, parser):
Expand Down
31 changes: 31 additions & 0 deletions osprey/tests/test_cli_worker_and_dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@
except:
HAVE_MSMBUILDER = False

try:
__import__('skorch')
HAVE_SKORCH = True
except:
HAVE_SKORCH = False


OSPREY_BIN = find_executable('osprey')


Expand Down Expand Up @@ -151,6 +158,30 @@ def test_gp_example():
shutil.rmtree(dirname)


@skipif(not HAVE_SKORCH, 'this test requires Skorch')
def test_torch_example():
assert OSPREY_BIN is not None
cwd = os.path.abspath(os.curdir)
dirname = tempfile.mkdtemp()

try:
os.chdir(dirname)
subprocess.check_call([OSPREY_BIN, 'skeleton', '-t', 'torch',
'-f', 'config.yaml'])
subprocess.check_call([OSPREY_BIN, 'worker', 'config.yaml', '-n', '1'])
assert os.path.exists('osprey-trials.db')

subprocess.check_call([OSPREY_BIN, 'current_best', 'config.yaml'])

yield _test_dump_1

yield _test_plot_1

finally:
os.chdir(cwd)
shutil.rmtree(dirname)


def test_grid_example():
assert OSPREY_BIN is not None
cwd = os.path.abspath(os.curdir)
Expand Down