diff --git a/Classification/cnns/config.py b/Classification/cnns/config.py
index 215a385..4245269 100755
--- a/Classification/cnns/config.py
+++ b/Classification/cnns/config.py
@@ -25,6 +25,10 @@ def get_parser(parser=None):
 
     def str_list(x):
         return x.split(',')
+
+    def str1_list(x):
+        # Drop empty pieces so that '' parses to [] rather than [''].
+        return [part for part in x.split('.') if part]
 
     def int_list(x):
         return list(map(int, x.split(',')))
@@ -53,7 +57,8 @@ def str2bool(v):
     parser.add_argument('--node_ips', type=str_list, default=['192.168.1.13', '192.168.1.14'],
                         help='nodes ip list for training, devided by ",", length >= num_nodes')
     parser.add_argument("--ctrl_port", type=int, default=50051, help='ctrl_port for multinode job')
-
+    parser.add_argument('--ssp_placement', type=str1_list, default=[],
+                        help='stage partition strategy list for placement, devided by ".", stage is 2 : "0:0-7,1:0-6"|"1-7')
     parser.add_argument("--model", type=str, default="resnet50",
                         help="resnet50")
     parser.add_argument(
diff --git a/Classification/cnns/job_function_util.py b/Classification/cnns/job_function_util.py
index 88675c0..92d7901 100755
--- a/Classification/cnns/job_function_util.py
+++ b/Classification/cnns/job_function_util.py
@@ -40,6 +40,15 @@ def get_train_config(args):
     else:
         train_config.cudnn_conv_heuristic_search_algo(False)
     train_config.enable_fuse_model_update_ops(True)
+
+    # Staged placement is opt-in: build one flow.stage per entry of
+    # --ssp_placement and skip the setup entirely when the list is empty.
+    if args.ssp_placement:
+        flow.env.init()
+        train_config.ssp_placement(
+            *[flow.stage(flow.scope.placement("gpu", device_name)) for device_name in args.ssp_placement]
+        )
+
     return train_config
 
 
diff --git a/Classification/cnns/of_ssp_cnn_train_val.py b/Classification/cnns/of_ssp_cnn_train_val.py
new file mode 100755
index 0000000..37da328
--- /dev/null
+++ b/Classification/cnns/of_ssp_cnn_train_val.py
@@ -0,0 +1,128 @@
+"""
+Copyright 2020 The OneFlow Authors. All rights reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+import os
+import math
+import oneflow as flow
+import ofrecord_util
+import optimizer_util
+import config as configs
+from util import Snapshot, Summary, InitNodes, Metric
+from job_function_util import get_train_config, get_val_config
+import resnet_model
+import resnext_model
+import vgg_model
+import alexnet_model
+import inception_model
+import mobilenet_v2_model
+
+parser = configs.get_parser()
+args = parser.parse_args()
+configs.print_args(args)
+
+total_device_num = args.num_nodes * args.gpu_num_per_node
+train_batch_size = total_device_num * args.batch_size_per_device
+val_batch_size = total_device_num * args.val_batch_size_per_device
+(C, H, W) = args.image_shape
+epoch_size = math.ceil(args.num_examples / train_batch_size)
+num_val_steps = int(args.num_val_examples / val_batch_size)
+
+
+model_dict = {
+    "resnet50": resnet_model.resnet50,
+    "vgg": vgg_model.vgg16bn,
+    "alexnet": alexnet_model.alexnet,
+    "inceptionv3": inception_model.inceptionv3,
+    "mobilenetv2": mobilenet_v2_model.Mobilenet,
+    "resnext50": resnext_model.resnext50,
+}
+
+
+flow.config.gpu_device_num(args.gpu_num_per_node)
+flow.config.enable_debug_mode(True)
+
+if args.use_fp16 and args.num_nodes * args.gpu_num_per_node > 1:
+    flow.config.collective_boxing.nccl_fusion_all_reduce_use_buffer(False)
+
+if args.nccl_fusion_threshold_mb:
+    flow.config.collective_boxing.nccl_fusion_threshold_mb(args.nccl_fusion_threshold_mb)
+
+if args.nccl_fusion_max_ops:
+    flow.config.collective_boxing.nccl_fusion_max_ops(args.nccl_fusion_max_ops)
+
+def label_smoothing(labels, classes, eta, dtype):
+    """Return one-hot labels smoothed by ``eta``.
+
+    The true class gets ``1 - eta + eta / classes`` and every other class
+    gets ``eta / classes``. ``classes`` must be positive and ``eta`` must
+    lie in ``[0, 1)``.
+    """
+    assert classes > 0
+    assert eta >= 0.0 and eta < 1.0
+    return flow.one_hot(labels, depth=classes, dtype=dtype,
+                        on_value=1 - eta + eta / classes, off_value=eta/classes)
+
+
+@flow.global_function("train", get_train_config(args))
+def TrainNet():
+    """Training job: load labels and images, run the model, set up the optimizer.
+
+    Returns a dict with the mean ``loss``, the softmax ``predictions`` and
+    the ``labels``.
+    """
+    if args.train_data_dir:
+        assert os.path.exists(args.train_data_dir)
+        print("Loading data from {}".format(args.train_data_dir))
+        (labels, images) = ofrecord_util.load_imagenet_for_training(args)
+
+    else:
+        print("Loading synthetic data.")
+        (labels, images) = ofrecord_util.load_synthetic(args)
+    logits = model_dict[args.model](images, args)
+    if args.label_smoothing > 0:
+        one_hot_labels = label_smoothing(labels, args.num_classes, args.label_smoothing, logits.dtype)
+        loss = flow.nn.softmax_cross_entropy_with_logits(one_hot_labels, logits, name="softmax_loss")
+    else:
+        loss = flow.nn.sparse_softmax_cross_entropy_with_logits(labels, logits, name="softmax_loss")
+
+    loss = flow.math.reduce_mean(loss)
+    predictions = flow.nn.softmax(logits)
+    outputs = {"loss": loss, "predictions": predictions, "labels": labels}
+
+    # set up warmup,learning rate and optimizer
+    optimizer_util.set_up_optimizer(loss, args)
+    return outputs
+
+
+def main():
+    """Run ``args.num_epochs`` epochs and save a snapshot after each one."""
+    InitNodes(args)
+    flow.env.log_dir(args.log_dir)
+
+    summary = Summary(args.log_dir, args)
+    snapshot = Snapshot(args.model_save_dir, args.model_load_dir)
+
+    for epoch in range(args.num_epochs):
+        metric = Metric(desc='train', calculate_batches=args.loss_print_every_n_iter,
+                        summary=summary, save_summary_steps=epoch_size,
+                        batch_size=train_batch_size, loss_key='loss')
+        for i in range(epoch_size):
+            TrainNet().async_get(metric.metric_cb(epoch, i))
+
+        snapshot.save('epoch_{}'.format(epoch))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/Classification/cnns/train_fp32_ssp.sh b/Classification/cnns/train_fp32_ssp.sh
new file mode 100755
index 0000000..4b382d0
--- /dev/null
+++ b/Classification/cnns/train_fp32_ssp.sh
@@ -0,0 +1,66 @@
+# Usage: train_fp32_ssp.sh [NUM_EPOCH] [DATA_ROOT] [BATCH_SIZE] [SSP_PLACEMENT] [MODEL_NAME]
+# Every argument is optional; omitted ones fall back to the defaults below.
+rm -rf core.*
+rm -rf ./output/snapshots/*
+
+if [ -n "$1" ]; then
+    NUM_EPOCH=$1
+else
+    NUM_EPOCH=50
+fi
+echo NUM_EPOCH=$NUM_EPOCH
+
+# training with imagenet
+if [ -n "$2" ]; then
+    DATA_ROOT=$2
+else
+    DATA_ROOT=/data/imagenet/ofrecord
+fi
+echo DATA_ROOT=$DATA_ROOT
+
+BATCH_SIZE=${3:-""}
+echo BATCH_SIZE=$BATCH_SIZE
+
+SSP_PLACEMENT=${4:-""}
+echo SSP_PLACEMENT=$SSP_PLACEMENT
+
+MODEL_NAME=${5:-"alexnet"}
+echo MODEL_NAME=$MODEL_NAME
+
+# Forward the optional settings only when they were given, so that an
+# empty value is never passed on the command line.
+BATCH_SIZE_ARG=${BATCH_SIZE:+--batch_size_per_device=$BATCH_SIZE}
+SSP_PLACEMENT_ARG=${SSP_PLACEMENT:+--ssp_placement=$SSP_PLACEMENT}
+
+LOG_FOLDER=../logs
+mkdir -p $LOG_FOLDER
+LOGFILE=$LOG_FOLDER/resnet_training.log
+
+export PYTHONUNBUFFERED=1
+echo PYTHONUNBUFFERED=$PYTHONUNBUFFERED
+export NCCL_LAUNCH_MODE=PARALLEL
+echo NCCL_LAUNCH_MODE=$NCCL_LAUNCH_MODE
+
+python3 of_ssp_cnn_train_val.py \
+  --train_data_dir=$DATA_ROOT/train \
+  --train_data_part_num=256 \
+  --num_nodes=2 \
+  --gpu_num_per_node=8 \
+  $SSP_PLACEMENT_ARG \
+  --optimizer="sgd" \
+  --momentum=0.875 \
+  --label_smoothing=0.1 \
+  --learning_rate=0.768 \
+  --loss_print_every_n_iter=100 \
+  $BATCH_SIZE_ARG \
+  --val_batch_size_per_device=50 \
+  --channel_last=False \
+  --fuse_bn_relu=True \
+  --fuse_bn_add_relu=True \
+  --nccl_fusion_threshold_mb=16 \
+  --nccl_fusion_max_ops=24 \
+  --gpu_image_decoder=True \
+  --num_epoch=$NUM_EPOCH \
+  --model=$MODEL_NAME 2>&1 | tee ${LOGFILE}
+
+echo "Writting log to ${LOGFILE}"