evaluate.py
import os
import time
import hydra
import wandb
from hydra.utils import instantiate
from omegaconf import OmegaConf
from tqdm import tqdm
import torch
import torch.nn as nn
import quest.utils.utils as utils
from pyinstrument import Profiler
from moviepy.editor import ImageSequenceClip
import json
OmegaConf.register_new_resolver("eval", eval, replace=True)


@hydra.main(config_path="config", config_name='evaluate', version_base=None)
def main(cfg):
    device = cfg.device
    seed = cfg.seed
    torch.manual_seed(seed)
    train_cfg = cfg.training
    OmegaConf.resolve(cfg)

    # Create the output directory for this evaluation run
    save_dir, _ = utils.get_experiment_dir(cfg, evaluate=True)
    os.makedirs(save_dir)

    if cfg.checkpoint_path is None:
        # If no checkpoint path is provided, resolve the experiment directory for the
        # given experiment/variant name and load its latest checkpoint
        checkpoint_path, _ = utils.get_experiment_dir(cfg, evaluate=False, allow_overlap=True)
        checkpoint_path = utils.get_latest_checkpoint(checkpoint_path)
    else:
        checkpoint_path = utils.get_latest_checkpoint(cfg.checkpoint_path)
    state_dict = utils.load_state(checkpoint_path)

    # Instantiate the policy from the config stored in the checkpoint if available,
    # otherwise fall back to the current config
    if 'config' in state_dict:
        print('autoloading based on saved parameters')
        model = instantiate(state_dict['config']['algo']['policy'],
                            shape_meta=cfg.task.shape_meta)
    else:
        model = instantiate(cfg.algo.policy,
                            shape_meta=cfg.task.shape_meta)
    model.to(device)
    model.eval()
    model.load_state_dict(state_dict['model'])

    env_runner = instantiate(cfg.task.env_runner)
    print('Saving to:', save_dir)
    print('Running evaluation...')

    def save_video_fn(video_chw, env_name, idx):
        # video_chw is a (T, C, H, W) array; convert it to a list of HWC frames for moviepy
        video_dir = os.path.join(save_dir, 'videos', env_name)
        os.makedirs(video_dir, exist_ok=True)
        save_path = os.path.join(video_dir, f'{idx}.mp4')
        clip = ImageSequenceClip(list(video_chw.transpose(0, 2, 3, 1)), fps=24)
        clip.write_videofile(save_path, fps=24, verbose=False, logger=None)

    # Optionally profile the rollout loop with pyinstrument
    if train_cfg.do_profile:
        profiler = Profiler()
        profiler.start()
    rollout_results = env_runner.run(model, n_video=cfg.rollout.n_video, do_tqdm=train_cfg.use_tqdm, save_video_fn=save_video_fn)
    if train_cfg.do_profile:
        profiler.stop()
        profiler.print()

    print(
        f"[info] success rate: {rollout_results['rollout']['overall_success_rate']:1.3f} "
        f"| environments solved: {rollout_results['rollout']['environments_solved']}")

    # Dump the rollout metrics alongside the saved videos
    with open(os.path.join(save_dir, 'data.json'), 'w') as f:
        json.dump(rollout_results, f)
if __name__ == "__main__":
main()
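
# Example invocation (illustrative only; the exact config groups and override keys
# depend on the Hydra configs under config/, which are not shown here):
#
#   python evaluate.py checkpoint_path=/path/to/experiment_or_checkpoint device=cuda:0 seed=0
#
# If checkpoint_path is omitted, the script falls back to the latest checkpoint found
# under the experiment/variant directory resolved from the current config.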