-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathpytorchtools.py
29 lines (24 loc) · 974 Bytes
/
pytorchtools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np
import torch
# from pytorch_toolbelt import losses as L
class EarlyStopping:
    """Track a monitored score and checkpoint the model whenever it improves.

    Call the instance once per evaluation with the latest score and the model.
    The integer return value tells the caller what happened:

    * ``EarlyStopping.IMPROVED`` (1): this is the first score seen, or it beat
      the best score so far by more than ``delta``. The model's ``state_dict``
      was written to ``checkpoint_path``, ``best_score`` was updated and the
      patience counter was reset to 0.
    * ``EarlyStopping.NO_IMPROVEMENT`` (0): no improvement, and fewer than
      ``patience`` consecutive non-improving calls have happened so far.
    * ``EarlyStopping.STOP`` (2): ``patience`` or more consecutive calls without
      improvement. The caller should stop training. Further non-improving calls
      keep returning 2. An improving call still returns 1 and resets the
      counter.
    """

    # Return codes of __call__. They are plain ints, so existing callers that
    # compare against 0 / 1 / 2 keep working unchanged.
    NO_IMPROVEMENT = 0
    IMPROVED = 1
    STOP = 2

    def __init__(self, patience=5, delta=0, checkpoint_path='checkpoint.pt', is_maximize=True):
        """Configure the early-stopping policy.

        Args:
            patience: number of consecutive non-improving calls after which
                ``__call__`` starts returning ``STOP``.
            delta: minimum margin by which a score must beat ``best_score`` to
                count as an improvement (strict comparison).
            checkpoint_path: file the best model ``state_dict`` is saved to via
                ``torch.save`` (overwritten on every improvement).
            is_maximize: if True, higher scores are better; otherwise lower
                scores are better.
        """
        self.patience, self.delta, self.checkpoint_path = patience, delta, checkpoint_path
        self.counter, self.best_score = 0, None
        self.is_maximize = is_maximize

    def load_best_weights(self, model, map_location=None):
        """Load the last checkpointed ``state_dict`` from ``checkpoint_path`` into *model*.

        Args:
            model: module whose ``load_state_dict`` receives the saved weights.
            map_location: forwarded to ``torch.load``. For example, pass ``"cpu"``
                to restore a checkpoint that was saved from GPU tensors on a
                machine without that device. ``None`` keeps ``torch.load``'s
                default behaviour, which matches the previous behaviour.
        """
        model.load_state_dict(torch.load(self.checkpoint_path, map_location=map_location))

    def _is_improvement(self, score):
        """Return True if *score* beats ``best_score`` by more than ``delta``."""
        if self.is_maximize:
            return score > self.best_score + self.delta
        return score < self.best_score - self.delta

    def __call__(self, score, model):
        """Record *score*, checkpoint *model* on improvement and return a status code.

        Returns:
            ``IMPROVED`` (1), ``NO_IMPROVEMENT`` (0) or ``STOP`` (2); see the
            class docstring for the exact meaning of each code.
        """
        # The very first score always counts as the best so far.
        if self.best_score is None or self._is_improvement(score):
            torch.save(model.state_dict(), self.checkpoint_path)
            self.best_score, self.counter = score, 0
            return self.IMPROVED
        self.counter += 1
        if self.counter >= self.patience:
            return self.STOP
        return self.NO_IMPROVEMENT