Skip to content

Commit

Permalink
Merge pull request #198 from cxhernandez/seed-estimator
Browse files (browse the repository at this point in the history)
pass random_seed to estimator
  • Loading branch information
cxhernandez authored Sep 1, 2016
2 parents 4aef876 + 344e4d8 commit 892d490
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 15 deletions.
2 changes: 1 addition & 1 deletion docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ New Features
+ Added random seed via CLI (`#196 <https://github.com/msmbuilder/osprey/pull/196>`_)
+ Added ``DSVDatasetLoader`` (`#175 <https://github.com/msmbuilder/osprey/pull/175>`_)
+ Added ``random_seed`` as a configurable parameter (`#164 <https://github.com/msmbuilder/osprey/pull/164>`_)
+

Bug Fixes
~~~~~~~~~~~~
+ Fixed issue where ``random_seed`` was not passed to estimator (`#198 <https://github.com/msmbuilder/osprey/pull/198>`_)
+ Fixed ``bokeh.io.vplot`` deprecation warning (`#192 <https://github.com/msmbuilder/osprey/pull/192>`_)
+ Fixed ungraceful failures when using GP with a single choice in
search space (`#191 <https://github.com/msmbuilder/osprey/pull/191>`_)
Expand Down
18 changes: 10 additions & 8 deletions osprey/execute_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,16 @@ def execute(args, parser):
print_header()

config = Config(args.config)
random_seed = args.seed if args.seed is not None else config.random_seed()
estimator = config.estimator()
if 'random_state' in estimator.get_params().keys():
estimator.set_params(random_state=random_seed)
np.random.seed(random_seed)
searchspace = config.search_space()
strategy = config.strategy()
config_sha1 = config.sha1()
scoring = config.scoring()
random_seed = args.seed if args.seed is not None else config.random_seed()

project_name = config.project_name()

if is_msmbuilder_estimator(estimator):
Expand Down Expand Up @@ -74,12 +78,11 @@ def signal_hander(signum, frame):

trial_id, params = initialize_trial(
strategy, searchspace, estimator, config_sha1=config_sha1,
project_name=project_name,
sessionbuilder=config.trialscontext)
project_name=project_name, sessionbuilder=config.trialscontext)

s = run_single_trial(
estimator=estimator, params=params, trial_id=trial_id,
scoring=scoring, random_seed=random_seed, X=X, y=y, cv=cv,
scoring=scoring, X=X, y=y, cv=cv,
sessionbuilder=config.trialscontext)

statuses[i] = s
Expand Down Expand Up @@ -123,15 +126,14 @@ def initialize_trial(strategy, searchspace, estimator, config_sha1,
return trial_id, params


def run_single_trial(estimator, params, trial_id, scoring, random_seed,
X, y, cv, sessionbuilder):
def run_single_trial(estimator, params, trial_id, scoring, X, y, cv,
sessionbuilder):

status = None

try:
score = fit_and_score_estimator(
estimator, params, cv=cv, scoring=scoring, random_seed=random_seed,
X=X, y=y, verbose=1)
estimator, params, cv=cv, scoring=scoring, X=X, y=y, verbose=1)
with sessionbuilder() as session:
trial = session.query(Trial).get(trial_id)
trial.mean_test_score = score['mean_test_score']
Expand Down
6 changes: 2 additions & 4 deletions osprey/fit_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@


def fit_and_score_estimator(estimator, parameters, cv, X, y=None, scoring=None,
random_seed=None, iid=True, n_jobs=1,
verbose=1, pre_dispatch='2*n_jobs'):
iid=True, n_jobs=1, verbose=1,
pre_dispatch='2*n_jobs'):
"""Fit and score an estimator with cross-validation
This function is basically a copy of sklearn's
Expand All @@ -40,8 +40,6 @@ def fit_and_score_estimator(estimator, parameters, cv, X, y=None, scoring=None,
score.
"""

np.random.seed(random_seed)

scorer = check_scoring(estimator, scoring=scoring)
n_samples = num_samples(X)
X, y = check_arrays(X, y, allow_lists=True, sparse_format='csr',
Expand Down
4 changes: 2 additions & 2 deletions osprey/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,8 @@ def _hyperopt_fmin_random_kwarg(random):
class GP(BaseStrategy):
short_name = 'gp'

def __init__(self, seeds=1, max_feval=5E4, max_iter=1E5):
def __init__(self, seed=None, seeds=1, max_feval=5E4, max_iter=1E5):
self.seed = seed
self.seeds = seeds
self.max_feval = max_feval
self.max_iter = max_iter
Expand Down Expand Up @@ -304,4 +305,3 @@ def suggest(self, history, searchspace):
# so user should pick correctly number of evaluations
self.current += 1
return self.param_grid[self.current % len(self.param_grid)]

0 comments on commit 892d490

Please sign in to comment.