ue_estimator_sto.py
from tqdm import tqdm
import time
import logging

log = logging.getLogger()


class UeEstimatorSTO:
    """Uncertainty estimator based on stochastic inference: runs prediction
    `committee_size` times over the evaluation set and collects the sampled
    probabilities and answers from each pass."""

    def __init__(self, cls, ue_args, eval_metric, calibration_dataset, train_dataset):
        self.cls = cls
        self.ue_args = ue_args
        self.calibration_dataset = calibration_dataset
        self.eval_metric = eval_metric
        self.train_dataset = train_dataset

    def __call__(self, eval_dataset, true_labels=None):
        ue_args = self.ue_args
        eval_metric = self.eval_metric
        model = self.cls._auto_model

        start = time.time()
        log.info("******Perform stochastic inference...*******")

        if ue_args.use_cache:
            log.info("Caching enabled.")
            model.enable_cache()

        eval_results = {}
        eval_results["sampled_probabilities"] = []
        eval_results["sampled_answers"] = []

        log.info("****************Start runs**************")

        for i in tqdm(range(ue_args.committee_size)):
            if ue_args.calibrate:  # TODO: what is the purpose of calibration here?
                self.cls.predict(self.calibration_dataset, calibrate=True)
                log.info(f"Calibration temperature = {self.cls.temperature}")

            # One stochastic forward pass over the evaluation set.
            preds, probs = self.cls.predict(eval_dataset)[:2]

            eval_results["sampled_probabilities"].append(probs.tolist())
            eval_results["sampled_answers"].append(preds.tolist())

            if ue_args.eval_passes:
                eval_score = eval_metric.compute(
                    predictions=preds, references=true_labels
                )
                log.info(f"Eval score: {eval_score}")

        end = time.time()
        log.info("**************Done.********************")
        log.info(f"UE time: {end - start}")
        eval_results["ue_time"] = end - start

        return eval_results
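

# --- Hypothetical usage sketch (not part of the original module) ---
# A minimal illustration of how the estimator might be driven, assuming a
# `cls` wrapper that exposes `predict(dataset, calibrate=False)`, a
# `temperature` attribute, and an `_auto_model` with `enable_cache()`.
# All names below (`cls`, `eval_metric`, `calibration_dataset`,
# `train_dataset`, `eval_dataset`, `labels`) are placeholders, and the
# `ue_args` fields mirror the attributes read above.
#
# from types import SimpleNamespace
#
# ue_args = SimpleNamespace(
#     committee_size=10, use_cache=False, calibrate=False, eval_passes=False
# )
# estimator = UeEstimatorSTO(cls, ue_args, eval_metric, calibration_dataset, train_dataset)
# results = estimator(eval_dataset, true_labels=labels)
# # results["sampled_probabilities"] then holds one list of class
# # probabilities per stochastic run, and results["ue_time"] the wall-clock time.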