# run.py
import os

import numpy as np
import torch
import torch.nn as nn
import wandb

# Project-local modules; train_cnn and test_cnn are pulled in via these imports.
from utils import *
from train import *
from _test import *
def multi_eval(model, testloaders, log, args):
    """Evaluate on every test loader; return the mean average accuracy and
    the median worst-group accuracy across loaders."""
    avg_accs = []
    worst_accs = []
    for i, testloader in enumerate(testloaders):
        print(f'Set{i}:')
        acc, worst_acc = test_cnn(testloader, model, log=log, args=args,
                                  inferred_groups=True)
        avg_accs.append(acc)
        worst_accs.append(worst_acc)
    print(worst_accs)
    return np.mean(avg_accs), np.median(worst_accs)
def run_last_layer_experiment(model, device, balanced_dataloader, testloaders, exp_name,
                              optimizer, l1_lambda, scheduler, dataset='waterbirds',
                              epochs=30, log=False, inspect_loader=None, seed=1, args=None):
    """Train and checkpoint the best-average and best-worst-group models;
    returns the path of the best-worst-group checkpoint."""
    curr_worst = 0
    curr_avg = 0
    if log:
        wandb.init(project=exp_name,
                   entity='username',
                   config={
                       "learning_rate": optimizer.state_dict()['param_groups'][0]['initial_lr'],
                       "epochs": epochs,
                       "step_size": scheduler.step_size,
                       "gamma": scheduler.gamma
                   },
                   reinit=True)
    cnn_optimizer = optimizer
    cnn_scheduler = scheduler
    global_step = 0
    saved_model = None
    best_model = None
    save_dir = os.path.join(
        args.output_path,
        f"{args.experiment}_{args.comments}_{args.dataset}_LR{args.learning_rate}"
        f"_step{args.step_size}_gamma{args.gamma}_seed{args.seed}"
        f"_samples{args.sample_size}_l1{args.l1}/")
    os.makedirs(save_dir, exist_ok=True)  # torch.save does not create directories
    for epoch in range(epochs):
        try:
            print("=========================")
            print("epoch:", epoch)
            print("=========================")
            global_step = train_cnn(balanced_dataloader, model, cnn_optimizer,
                                    cnn_scheduler, global_step, device, l1_lambda, log)
            print('----> [Val/Test]')
            with torch.no_grad():
                inv_acc, worst_acc = multi_eval(model, testloaders, log, args)
            # Record whether the average improved before updating curr_avg,
            # so the worst-group tie-break below can still see it.
            avg_improved = inv_acc > curr_avg
            if avg_improved:
                curr_avg = inv_acc
                if best_model:
                    os.remove(best_model)
                best_model = os.path.join(save_dir, f"best_avg_epoch{epoch}.model")
                torch.save(model.state_dict(), best_model)
            # On a worst-group tie, prefer the checkpoint with the better average.
            if worst_acc == curr_worst and avg_improved:
                if saved_model:
                    os.remove(saved_model)
                saved_model = os.path.join(save_dir, f"best_worst_epoch{epoch}.model")
                torch.save(model.state_dict(), saved_model)
            if worst_acc > curr_worst:
                curr_worst = worst_acc
                if saved_model:
                    os.remove(saved_model)
                saved_model = os.path.join(save_dir, f"best_worst_epoch{epoch}.model")
                torch.save(model.state_dict(), saved_model)
            if log:
                wandb.log({"Test Mean Accuracy": inv_acc})
        except KeyboardInterrupt:
            print('Experiment Stopped')
            break
    last_model = os.path.join(save_dir, "last.model")
    torch.save(model.state_dict(), last_model)
    print(f'last model saved at {last_model}')
    return saved_model
def run_loss_inspect_experiment(model, device, spuriousity, balanced_dataloader, testloader, exp_name,
                                optimizer, l1_lambda, scheduler, dataset='waterbirds',
                                epochs=30, log=False, inspect_loader=None):
    """Train while recording, per epoch, the loss of every sample in
    inspect_loader; returns the trained model and the per-sample loss traces."""
    losses = [[] for _ in range(len(inspect_loader.dataset))]
    # reduction='none' keeps one loss per sample instead of a batch mean, so
    # each list in `losses` traces a single example across epochs.
    loss_fn = nn.CrossEntropyLoss(reduction='none')
    cnn_optimizer = optimizer
    lr = cnn_optimizer.state_dict()['param_groups'][0]['initial_lr']
    cnn_scheduler = scheduler
    step_size, gamma = cnn_scheduler.step_size, cnn_scheduler.gamma
    global_step = 0
    for epoch in range(epochs):
        try:
            print("=========================")
            print("epoch:", epoch)
            print("=========================")
            global_step = train_cnn(balanced_dataloader, model, cnn_optimizer,
                                    cnn_scheduler, global_step, device, l1_lambda, log)
            print("====== Calculating losses on inspect samples ======")
            with torch.no_grad():
                # Assumes inspect_loader iterates in a fixed order (no
                # shuffling), so `offset + j` is a stable sample index.
                offset = 0
                for x, y, g in inspect_loader:
                    y_pred = model(x.to(device))
                    batch_losses = loss_fn(y_pred, y.to(device))
                    for j, sample_loss in enumerate(batch_losses):
                        losses[offset + j].append(sample_loss.cpu().item())
                    offset += len(y)
        except KeyboardInterrupt:
            print('Experiment Stopped')
            break
    last_model = (f'last_{exp_name}_{dataset}_sp{spuriousity}_LR{lr}'
                  f'_epoch{global_step}_step{step_size}_gamma{gamma}.model')
    torch.save(model.state_dict(), last_model)
    print(f'last model saved at {last_model}')
    return model, losses
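# --------------------------------------------------------------------------
# Hypothetical usage sketch, not part of the original script: it wires the
# runner to a toy linear model and random tensors purely to illustrate the
# expected argument shapes. It assumes the repo's train_cnn/test_cnn accept
# loaders of (x, y, g) triples as unpacked above; every name and value below
# is an illustrative placeholder.
if __name__ == '__main__':
    from types import SimpleNamespace
    from torch.utils.data import DataLoader, TensorDataset

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = nn.Linear(8, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    args = SimpleNamespace(output_path='./runs', experiment='smoke', comments='',
                           dataset='toy', learning_rate=1e-3, step_size=10,
                           gamma=0.5, seed=1, sample_size=64, l1=0.0)
    # Loaders yield (x, y, g) triples: inputs, labels, group annotations.
    xs = torch.randn(64, 8)
    ys = torch.randint(0, 2, (64,))
    gs = torch.zeros(64, dtype=torch.long)
    loader = DataLoader(TensorDataset(xs, ys, gs), batch_size=16)
    run_last_layer_experiment(model, device, loader, [loader], 'smoke_test',
                              optimizer, 0.0, scheduler, epochs=1, args=args)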