@@ -594,6 +594,8 @@ def __init__(self, bridge, learning_rate=0.3, max_iters=50, max_depth=6,
594
594
self ._role = self ._bridge .role
595
595
self ._bridge .connect ()
596
596
self ._make_key_pair ()
597
+ else :
598
+ self ._role = 'local'
597
599
598
600
@property
599
601
def loss (self ):
@@ -616,7 +618,7 @@ def _verify_params(self, example_ids, is_training, validation=False):
616
618
return
617
619
618
620
self ._bridge .start (self ._bridge .new_iter_id ())
619
- if self ._bridge . role == 'leader' :
621
+ if self ._role == 'leader' :
620
622
msg = tree_pb2 .VerifyParams (
621
623
example_ids = example_ids ,
622
624
learning_rate = self ._learning_rate ,
@@ -704,7 +706,7 @@ def batch_predict(self, features, get_raw_score=False, example_ids=None):
704
706
return self ._batch_predict_local (features , get_raw_score )
705
707
706
708
self ._verify_params (example_ids , False )
707
- if self ._bridge . role == 'leader' :
709
+ if self ._role == 'leader' :
708
710
return self ._batch_predict_leader (features , get_raw_score )
709
711
return self ._batch_predict_follower (features , get_raw_score )
710
712
@@ -830,7 +832,7 @@ def fit(self, features, labels=None,
830
832
tree , raw_prediction = self ._fit_one_round_local (
831
833
sum_prediction , binned , labels )
832
834
sum_prediction += raw_prediction
833
- elif self ._bridge . role == 'leader' :
835
+ elif self ._role == 'leader' :
834
836
tree , raw_prediction = self ._fit_one_round_leader (
835
837
sum_prediction , binned , labels )
836
838
sum_prediction += raw_prediction
@@ -853,7 +855,7 @@ def fit(self, features, labels=None,
853
855
self .save_model (filename )
854
856
855
857
# save output
856
- if self ._bridge . role != 'follower' and output_path is not None :
858
+ if self ._role != 'follower' and output_path is not None :
857
859
pred = self ._loss .predict (sum_prediction )
858
860
metrics = self ._loss .metrics (pred , labels )
859
861
self ._write_training_log (
@@ -863,7 +865,7 @@ def fit(self, features, labels=None,
863
865
if validation_features is not None :
864
866
val_pred = self .batch_predict (
865
867
validation_features , example_ids = validation_example_ids )
866
- if self ._bridge . role != 'follower' :
868
+ if self ._role != 'follower' :
867
869
metrics = self ._loss .metrics (val_pred , validation_labels )
868
870
logging .info (
869
871
"Validation metrics for iter %d: %s" , num_iter , metrics )
0 commit comments