-
Notifications
You must be signed in to change notification settings - Fork 8
/
reflow_distill.sh
113 lines (85 loc) · 3.14 KB
/
reflow_distill.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
#!/bin/bash
# Print the command-line synopsis followed by a one-line description of each
# positional argument to stdout, then terminate the script with status 1.
usage() {
  # One printf call emits every line; each argument becomes its own line.
  printf '%s\n' \
    "Usage: $0 pretrained_dir config generate_dir log_dir" \
    " pretrained_dir Path to the directory containing pretrained model" \
    " config Configuration file" \
    " generate_dir Directory to save generated files" \
    " log_dir Directory to save log files"
  exit 1
}
#######################################
# Write validation and training file lists for a directory of generated data.
#
# Every regular file named *.th found (recursively) under the directory is
# listed. The list is sorted bytewise so the split is reproducible across
# runs (find's own output order is unspecified). The first num_valid paths
# are written to "<dir>/filelist.valid" and the remaining paths to
# "<dir>/filelist.train".
#
# Arguments:
#   $1 - directory to scan; the two file lists are created inside it
#   $2 - optional number of validation entries (default: 1000)
# Returns:
#   0 on success, 2 if $2 is not a non-negative integer, 1 if the temporary
#   file cannot be created, otherwise the status of the first failing step.
#######################################
split_file_list() {
  local generate_dir=$1
  local num_valid=${2:-1000}
  local temp_file_list
  local status=0

  if [[ ! "$num_valid" =~ ^[0-9]+$ ]]; then
    echo "Error: validation count '$num_valid' is not a non-negative integer." >&2
    return 2
  fi

  # Assigned separately from 'local' so a mktemp failure is not masked.
  temp_file_list=$(mktemp) || return 1

  # 10# forces base 10 so a value with leading zeros is not read as octal.
  find "$generate_dir" -type f -name "*.th" > "$temp_file_list" \
    && LC_ALL=C sort -o "$temp_file_list" "$temp_file_list" \
    && head -n "$((10#$num_valid))" "$temp_file_list" \
      > "$generate_dir/filelist.valid" \
    && tail -n "+$((10#$num_valid + 1))" "$temp_file_list" \
      > "$generate_dir/filelist.train" \
    || status=$?

  # Always remove the temporary list, whether or not the split succeeded.
  rm -f -- "$temp_file_list"
  return "$status"
}
# ---------------------------------------------------------------------------
# Argument handling and precondition checks.
#
# Requires exactly four positional arguments (see usage). Verifies that the
# pretrained directory and the config file exist, creates the generate and
# log directories, and locates a 'last.ckpt' under the pretrained directory.
# Every failure prints a diagnostic to stderr and exits with status 1.
# ---------------------------------------------------------------------------
if [[ $# -ne 4 ]]; then
  echo "Error: Missing required arguments." >&2
  # usage exits with status 1; on this error path its text goes to stderr.
  usage >&2
fi

pretrained_dir=$1
config=$2
generate_dir=$3
log_dir=$4
# Value handed to reflow/generate_data.py via --num_pairs.
num_pairs=10000

if [[ ! -d "$pretrained_dir" ]]; then
  echo "Error: Pretrained directory '$pretrained_dir' does not exist." >&2
  exit 1
fi

if [[ ! -f "$config" ]]; then
  echo "Error: Configuration file '$config' does not exist." >&2
  exit 1
fi

mkdir -p "$generate_dir"
mkdir -p "$log_dir"

# Re-check after mkdir so a failed creation is reported with a clear message.
if [[ ! -d "$generate_dir" ]]; then
  echo "Error: Failed to create generate directory '$generate_dir'." >&2
  exit 1
fi

if [[ ! -d "$log_dir" ]]; then
  echo "Error: Failed to create log directory '$log_dir'." >&2
  exit 1
fi

# Take the first 'last.ckpt' found anywhere below the pretrained directory.
last_ckpt_path=$(find "$pretrained_dir" -name "last.ckpt" -print -quit)
if [[ -z "$last_ckpt_path" ]]; then
  echo "Error: 'last.ckpt' not found in '$pretrained_dir' or its subdirectories." >&2
  exit 1
else
  echo "Found 'last.ckpt' at: $last_ckpt_path"
fi
# ---------------------------------------------------------------------------
# Per-stage working directories.
#   rf1_generate_dir - output of the first generate_data.py run
#   rf2_generate_dir - output of the second generate_data.py run
#   rf2_log_dir      - logger save_dir of the first train.py run
#   rfd_log_dir      - logger save_dir of the second (one_step) train.py run
# ---------------------------------------------------------------------------
rf1_generate_dir="${generate_dir}/rf1"
rf2_generate_dir="${generate_dir}/rf2"
rf2_log_dir="${log_dir}/rf2"
rfd_log_dir="${log_dir}/rfd"

# Stop immediately if any directory cannot be created, instead of letting a
# later stage fail on a missing path.
for stage_dir in "$rf1_generate_dir" "$rf2_generate_dir" \
  "$rf2_log_dir" "$rfd_log_dir"; do
  if ! mkdir -p "$stage_dir"; then
    echo "Error: Failed to create directory '$stage_dir'." >&2
    exit 1
  fi
done
unset stage_dir

# Prepend the current working directory to PYTHONPATH. The ':' separator is
# added only when PYTHONPATH already has a value, so no trailing empty entry
# is left behind. Assignment and export are separate statements so the
# command substitution's status is not hidden by 'export'.
PYTHONPATH="$(pwd)${PYTHONPATH:+:${PYTHONPATH}}"
export PYTHONPATH
# ---------------------------------------------------------------------------
# Pipeline: generate rf1 data -> train (logs to rf2_log_dir) -> generate rf2
# data -> train with one_step enabled (logs to rfd_log_dir).
# Each stage is checked; the script stops with status 1 at the first failure
# so a later stage never runs on missing or partial output.
# ---------------------------------------------------------------------------

# Stage 1: generate data with the pretrained model.
python3 reflow/generate_data.py --model_dir "$pretrained_dir" \
  --save_dir "$rf1_generate_dir" \
  --num_pairs "$num_pairs" || {
  echo "Error: rf1 data generation failed." >&2
  exit 1
}
echo "generate rf1 data successfully."

split_file_list "$rf1_generate_dir" || {
  echo "Error: Failed to build file lists in '$rf1_generate_dir'." >&2
  exit 1
}

# Stage 2: train on the rf1 file lists, starting from the pretrained checkpoint.
python3 train.py --config "$config" \
  --model.init_args.pretrained_ckpt_path "$last_ckpt_path" \
  --trainer.logger.init_args.save_dir "$rf2_log_dir" \
  --data.init_args.train_filelist "$rf1_generate_dir/filelist.train" \
  --data.init_args.val_filelist "$rf1_generate_dir/filelist.valid" || {
  echo "Error: rf2 training failed." >&2
  exit 1
}

# Stage 3: generate data with the model trained in stage 2.
# Use the data configuration from the pretrained model. Note that the data
# configuration of the ReFlow model is designed for paired data and is not
# compatible with generate_data.py.
python3 reflow/generate_data.py --model_dir "$rf2_log_dir" \
  --save_dir "$rf2_generate_dir" \
  --num_pairs "$num_pairs" \
  --data_config "${pretrained_dir}/config.yaml" || {
  echo "Error: rf2 data generation failed." >&2
  exit 1
}
echo "generate rf2 data successfully."

split_file_list "$rf2_generate_dir" || {
  echo "Error: Failed to build file lists in '$rf2_generate_dir'." >&2
  exit 1
}

# Stage 4 needs the checkpoint written by stage 2; refuse to continue with an
# empty path.
rf2_last_ckpt_path=$(find "$rf2_log_dir" -name "last.ckpt" -print -quit)
if [[ -z "$rf2_last_ckpt_path" ]]; then
  echo "Error: 'last.ckpt' not found in '$rf2_log_dir' or its subdirectories." >&2
  exit 1
fi

# Stage 4: train with one_step enabled on the rf2 file lists.
python3 train.py --config "$config" \
  --model.init_args.one_step "true" \
  --model.init_args.pretrained_ckpt_path "$rf2_last_ckpt_path" \
  --trainer.logger.init_args.save_dir "$rfd_log_dir" \
  --data.init_args.train_filelist "$rf2_generate_dir/filelist.train" \
  --data.init_args.val_filelist "$rf2_generate_dir/filelist.valid" || {
  echo "Error: rfd training failed." >&2
  exit 1
}