-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathanalyze.py
60 lines (49 loc) · 1.46 KB
/
analyze.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import scienceplots
from util import (
select_project,
select_group,
select_seed,
select_device,
load_model,
load_data,
load_study,
load_best_model,
)
def test_model(model, dl_val, device):
    """Evaluate ``model`` on ``dl_val`` and collect its predictions and targets.

    The model is switched to eval mode (and left in it). Inference runs
    under ``torch.no_grad()``.

    Args:
        model: Module to evaluate; it is expected to already live on ``device``.
        dl_val: Iterable of ``(x, y)`` tensor batches (e.g. a ``DataLoader``).
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(val_loss, preds, targets)``:

        - ``val_loss`` is the MSE averaged over samples. Each batch's mean
          loss is weighted by its batch size, so a smaller final batch is not
          over-weighted as it would be by a plain mean of batch means.
        - ``preds`` and ``targets`` are lists with one entry per sample: the
          rows, along the first axis, of each batch's NumPy array.

    Raises:
        ValueError: If ``dl_val`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    n_samples = 0
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for x, y in dl_val:
            x, y = x.to(device), y.to(device)
            y_pred = model(x)
            # NOTE(review): if y_pred and y differ in shape (e.g. (B, 1) vs (B,)),
            # mse_loss broadcasts them and computes the wrong loss (PyTorch only
            # emits a UserWarning). Confirm the shapes the model and data produce.
            loss = F.mse_loss(y_pred, y)
            batch_size = y.shape[0]
            # mse_loss returns the batch *mean*. Weight it by the batch size so
            # the final value is a mean over samples, not over batches.
            total_loss += loss.item() * batch_size
            n_samples += batch_size
            all_preds.extend(y_pred.cpu().numpy())
            all_targets.extend(y.cpu().numpy())
    if n_samples == 0:
        raise ValueError("dl_val yielded no samples; cannot compute a validation loss")
    return total_loss / n_samples, all_preds, all_targets


# The name starts with "test_", so pytest would collect this helper as a test
# whenever it is imported into a test module. It would then fail on the missing
# `model` / `dl_val` / `device` fixtures. Opt it out of collection explicitly.
test_model.__test__ = False
def main(study_name: str = "Optimize_Template"):
    """Evaluate the best model of a study on the validation set.

    Selects a project with ``select_project()`` and loads the best model
    recorded for ``study_name`` with ``load_best_model``. The model is then
    moved to the device returned by ``select_device()``, evaluated on the
    validation loader from ``load_data()`` via ``test_model``, and the
    resulting validation loss is printed.

    To evaluate one specific run instead of the study's best model, pick the
    run with ``select_group(project)`` and ``select_seed(project, group_name)``.
    Then load it with ``load_model(project, group_name, seed)`` in place of
    ``load_best_model``.

    Args:
        study_name: Name of the study whose best model is evaluated. Defaults
            to ``"Optimize_Template"``, the value this script always used.
    """
    project = select_project()
    model, config = load_best_model(project, study_name)
    device = select_device()
    model = model.to(device)

    # Only the second item returned by load_data() (the validation loader)
    # is needed here; the first is discarded.
    _, dl_val = load_data()
    val_loss, preds, targets = test_model(model, dl_val, device)
    print(f"Validation Loss: {val_loss}")

    # `config`, `preds` and `targets` are kept for further custom analysis
    # (e.g. plotting predictions against targets).
if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script,
    # not when it is imported (e.g. to reuse `test_model`).
    main()