From 95f6c6d954e361960fd1d044250f308968a7b797 Mon Sep 17 00:00:00 2001 From: Carlos Hernandez Date: Sun, 24 Jun 2018 17:35:19 -0700 Subject: [PATCH] add skorch compat --- osprey/data/torch_skeleton_config.yaml | 38 ++++++++++++++++++++++++++ osprey/eval_scopes.py | 16 ++++++++++- osprey/execute_skeleton.py | 3 +- 3 files changed, 55 insertions(+), 2 deletions(-) create mode 100644 osprey/data/torch_skeleton_config.yaml diff --git a/osprey/data/torch_skeleton_config.yaml b/osprey/data/torch_skeleton_config.yaml new file mode 100644 index 0000000..631f39c --- /dev/null +++ b/osprey/data/torch_skeleton_config.yaml @@ -0,0 +1,38 @@ +estimator: + eval: Pipeline([ + ('scale', RobustScaler()), + ('classifier', NeuralNetClassifier(nn.Sequential(nn.Linear(64, 32), + nn.ReLU(), + nn.Linear(32, 10), + nn.Softmax(dim=1)), + max_epochs=10)), + ]) + eval_scope: ['sklearn', 'torch'] + +scoring: accuracy + +strategy: + name: gp + params: + seeds: 5 + +search_space: + classifier__lr: + min: 1e-3 + max: 1e-1 + num: 10 + type: jump + var_type: float + warp: log + +cv: 5 + +dataset_loader: + name: sklearn_dataset + params: + method: load_digits + +trials: + uri: sqlite:///osprey-trials.db + +random_seed: 42 diff --git a/osprey/eval_scopes.py b/osprey/eval_scopes.py index b55509f..e2af20b 100644 --- a/osprey/eval_scopes.py +++ b/osprey/eval_scopes.py @@ -8,7 +8,7 @@ from sklearn.base import BaseEstimator -__all__ = ['msmbuilder', 'import_all_estimators'] +__all__ = ['msmbuilder', 'torch', 'import_all_estimators'] def msmbuilder(): @@ -22,6 +22,20 @@ def msmbuilder(): return scope +def torch(): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + import torch + from torch import nn + import skorch + from sklearn.pipeline import Pipeline + + scope = import_all_estimators(skorch) + scope.update({'nn': nn}) + scope['Pipeline'] = Pipeline + return scope + + def import_all_estimators(pkg): def estimator_in_module(mod): diff --git a/osprey/execute_skeleton.py b/osprey/execute_skeleton.py index 0b9684c..d15cc53 100644 --- a/osprey/execute_skeleton.py +++ b/osprey/execute_skeleton.py @@ -8,7 +8,8 @@ 'random_example': 'random_example.yaml', 'gp_example': 'sklearn_skeleton_config.yaml', 'grid_example': 'grid_example.yaml', - 'msmb_feat_select': 'msmb_feat_select_skeleton_config.yaml'} + 'msmb_feat_select': 'msmb_feat_select_skeleton_config.yaml', + 'torch': 'torch_skeleton_config.yaml'} def execute(args, parser):