templates_ensemble.py
import itertools
import os
import warnings
from collections import defaultdict

from evaluate import evaluate_setup
from models import load_generator
from utils import parse_args, get_results_torch, save_results_torch
from templates import get_templates


if __name__ == "__main__":
    args = parse_args()
    for model in args.models:
        # Load each generator once and reuse it across all dataset/seed/method combinations.
        generator = load_generator(model, cache_dir=args.cache_dir, precision=args.precision,
                                   local_files_only=args.local_files_only, device_map=args.device_map,
                                   )
        for dataset, seed, prediction_method, selection_method in itertools.product(
                args.dataset, args.seed, args.prediction_method, args.examples_selection_method):
            if selection_method == "0-shot":
                num_shots_range = [0]
            else:
                num_shots_range = args.num_shots

            # Channel and calibrated prediction only make sense with the loss restricted to label tokens.
            if prediction_method in ["channel", "calibrate"]:
                if not args.labels_loss:
                    warnings.warn(f"Using {prediction_method} with labels_loss set to False is highly discouraged, "
                                  f"setting to True.")
                labels_loss = True
            else:
                labels_loss = args.labels_loss
            method_name = f"{prediction_method}_{labels_loss}"

            for num_shots in num_shots_range:
                templates = get_templates(dataset, num_shots, args.num_templates, args.templates_path, seed)
                # Skip already computed scores: resume from any previously saved results file.
                save_dir = os.path.join(args.save_dir, dataset)
                name = f"{num_shots}_shot_ensembles.out"
                results = get_results_torch(save_dir=save_dir, name=name)
                if model not in results:
                    results[model] = defaultdict(dict)
                if method_name not in results[model]:
                    results[model][method_name] = defaultdict(dict)
                if seed not in results[model][method_name]:
                    results[model][method_name][seed] = defaultdict(list)

                # Only evaluate templates that have not been scored yet.
                num_evaluated_templates = len(results[model][method_name][seed]["scores"])
                for template in templates[num_evaluated_templates:]:
                    evaluation_result = evaluate_setup(dataset=dataset, generator=generator, seed=seed,
                                                       template=template,
                                                       num_shots=num_shots, selection_method=selection_method,
                                                       example_ids=args.example_ids, examples_path=args.examples_path,
                                                       prediction_method=prediction_method, labels_loss=labels_loss,
                                                       batch_size=args.eval_batch_size, cache_dir=args.cache_dir,
                                                       )
                    for key in ["scores", "predicts", "probs"]:
                        results[model][method_name][seed][key].append(evaluation_result[key])
                    # Save after every template so an interrupted run can resume where it stopped.
                    save_results_torch(res_obj=results, name=name, save_dir=save_dir)
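
For reference, a minimal sketch of how the saved results could be loaded for inspection afterwards, reusing get_results_torch from utils. The model name, dataset directory, seed, and file name below are placeholders; the actual values depend on how the script was invoked.

# Hypothetical inspection snippet: paths, model name, method name, and seed are assumptions.
from utils import get_results_torch

results = get_results_torch(save_dir="results/sst2", name="4_shot_ensembles.out")
# The file is nested as results[model][f"{prediction_method}_{labels_loss}"][seed][key],
# where key is one of "scores", "predicts", "probs" and each value is a list with
# one entry per evaluated template.
scores = results["gpt2"]["channel_True"][0]["scores"]
print(f"{len(scores)} templates evaluated; scores: {scores}")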