Skip to content

Commit

Permalink
add skorch compat
Browse files Browse the repository at this point in the history
  • Loading branch information
Carlos Hernandez committed Jun 25, 2018
1 parent af9f1a7 commit 95f6c6d
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 2 deletions.
38 changes: 38 additions & 0 deletions osprey/data/torch_skeleton_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
estimator:
eval: Pipeline([
('scale', RobustScaler()),
('classifier', NeuralNetClassifier(nn.Sequential(nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, 10),
nn.Softmax(dim=1)),
max_epochs=10)),
])
eval_scope: ['sklearn', 'torch']

scoring: accuracy

strategy:
name: gp
params:
seeds: 5

search_space:
classifier__lr:
min: 1e-3
max: 1e-1
num: 10
type: jump
var_type: float
warp: log

cv: 5

dataset_loader:
name: sklearn_dataset
params:
method: load_digits

trials:
uri: sqlite:///osprey-trials.db

random_seed: 42
16 changes: 15 additions & 1 deletion osprey/eval_scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from sklearn.base import BaseEstimator


__all__ = ['msmbuilder', 'import_all_estimators']
__all__ = ['msmbuilder', 'torch', 'import_all_estimators']


def msmbuilder():
Expand All @@ -22,6 +22,20 @@ def msmbuilder():
return scope


def torch():
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
import torch
from torch import nn
import skorch
from sklearn.pipeline import Pipeline

scope = import_all_estimators(skorch)
scope.update({'nn': nn})
scope['Pipeline'] = Pipeline
return scope


def import_all_estimators(pkg):

def estimator_in_module(mod):
Expand Down
3 changes: 2 additions & 1 deletion osprey/execute_skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
'random_example': 'random_example.yaml',
'gp_example': 'sklearn_skeleton_config.yaml',
'grid_example': 'grid_example.yaml',
'msmb_feat_select': 'msmb_feat_select_skeleton_config.yaml'}
'msmb_feat_select': 'msmb_feat_select_skeleton_config.yaml',
'torch': 'torch_skeleton_config.yaml'}


def execute(args, parser):
Expand Down

0 comments on commit 95f6c6d

Please sign in to comment.