Code of the QANet - A Quality Assurance Neural Network for Instance Segmentation, Assaf Arbelle, Eliav Elul, Michael Sidorov, Tammy Riklin Raviv (code by Michael Sidorov)
Execution Instructions:
- Build Train / Test Data:
Builds the data as a list of tuples (image, segmentation, augmented segmentation, Jaccard(segmentation, augmented_segmentation))
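To make the last element of each tuple concrete, here is a minimal sketch of a Jaccard (intersection over union) score between a segmentation and its augmented version. It is an illustration only: the function name jaccard is hypothetical, the masks are treated as binary foreground masks given as nested lists, and the repository may compute the score differently (e.g., per instance).

```python
def jaccard(mask_a, mask_b):
    """Jaccard index (intersection over union) of two equally shaped binary masks.

    The masks are nested lists; any non-zero pixel counts as foreground.
    """
    intersection = 0
    union = 0
    for row_a, row_b in zip(mask_a, mask_b):
        for a, b in zip(row_a, row_b):
            intersection += 1 if (a and b) else 0
            union += 1 if (a or b) else 0
    # Assumption: two empty masks are treated as a perfect match.
    return intersection / union if union else 1.0


segmentation = [[1, 1, 0],
                [0, 0, 0]]
augmented = [[1, 0, 0],
             [1, 0, 0]]
# One shared foreground pixel out of three foreground pixels in the union.
score = jaccard(segmentation, augmented)
```

Plain lists keep the sketch free of dependencies; with NumPy arrays the same quantity is np.logical_and(a, b).sum() / np.logical_or(a, b).sum().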
Expected Input Data Format:
1) The files are expected to be located in two separate directories (i.e., one for the images and one for the corresponding segmentations),
where the directory with the segmentations has the same name as the directory with the images, but with a 'MASK' postfix
(e.g., "01" with images and "01_MASK" with the corresponding masks).
2) Each image and the corresponding mask file should have the "image0" and "mask0" prefixes respectively (e.g., "image01.tif" and
"mask01.tif", where the image prefix is "image0" and the segmentation prefix is "mask0").**
EXAMPLE:
DATA ROOT -
01 -
image01.tif
image02.tif
image03.tif
...
01_MASK -
mask01.tif
mask02.tif
mask03.tif
...
02 -
image01.tif
image02.tif
image03.tif
...
02_MASK -
mask01.tif
mask02.tif
mask03.tif
...
...
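The layout above can be walked with a short helper that matches every image to its mask. This is a sketch under stated assumptions, not the repository's loader: the function name pair_images_with_masks is hypothetical, and the default values "_MASK", "image0" and "mask0" are read off the example rather than taken from the code.

```python
from pathlib import Path


def pair_images_with_masks(data_root, seg_dir_postfix="_MASK",
                           image_prefix="image0", seg_prefix="mask0"):
    """Return a sorted list of (image_path, mask_path) pairs found under data_root."""
    root = Path(data_root)
    pairs = []
    # Image directories are the ones that do not carry the segmentation postfix.
    image_dirs = sorted(
        p for p in root.iterdir()
        if p.is_dir() and not p.name.endswith(seg_dir_postfix)
    )
    for image_dir in image_dirs:
        mask_dir = root / (image_dir.name + seg_dir_postfix)
        if not mask_dir.is_dir():
            continue  # no matching segmentation directory
        for image_file in sorted(image_dir.glob(f"{image_prefix}*")):
            # What follows the prefix (e.g. "1.tif") identifies the frame.
            frame_id = image_file.name[len(image_prefix):]
            mask_file = mask_dir / (seg_prefix + frame_id)
            if mask_file.is_file():
                pairs.append((image_file, mask_file))
    return pairs
```

Images without a matching mask are skipped silently here; a stricter variant could raise an error instead.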
- (*) Flags:
--input_data_dir (str) - Path to the directory where the input images and segmentations are located.
--output_data_dir (str) - Path to the directory where the output images and segmentations will be placed
--seg_dir_postfix (str) - The postfix of the directories with the segmentations (as described above)
--image_prefix (str) - The prefix of the images (as described above)
--seg_prefix (str) - The prefix of the segmentations (as described above)
--n_samples (int) - Number of samples to generate
--min_j (int) - The minimal Jaccard a generated sample must have to be added to the samples
--max_j (int) - The maximal Jaccard a generated sample may have to be added to the samples
--plot_samples (flag) - Whether to plot the generated samples (takes a considerable amount of time)
- (**) To ensure your data format fits the format the models expect (the format of the data dir may be changed in './configs/general_configs.py'),
first run 'format_data_dir.py' with the flags:
- current_seg_dir_postfix (str) - the postfix of the folder(s) with the segmentations
- current_image_prefix (str) - the prefix of the images
- current_seg_prefix (str) - the prefix of the segmentations
- Train / Test / Infer:
- Before running one of the train / test procedures, the data must be built with:
python build_data.py
- Run the desired procedure and implementation by executing:
python tf/torch_PROCEDURE.py (tf for the TensorFlow and torch for the PyTorch implementations respectively)
- Some of the flags(*) that may be used (find out more in the get_arg_parser() function in utils/aux_funcs.py):
- General:
--gpu_id (int) - Choose the GPU to run on (i.e., integers 0-N_GPUs, or -1 to run on the CPU)
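One common way to honour a value like --gpu_id in both TensorFlow and PyTorch is to restrict the visible CUDA devices before the framework is imported. The sketch below illustrates that idea only; the function name select_device is hypothetical and the repository may handle device selection differently.

```python
import os


def select_device(gpu_id):
    """Restrict the visible CUDA devices according to a --gpu_id style value."""
    # A negative id hides every GPU, so the framework falls back to the CPU.
    visible = str(gpu_id) if gpu_id >= 0 else "-1"
    # Must be set before the deep learning framework initialises CUDA.
    os.environ["CUDA_VISIBLE_DEVICES"] = visible
    return visible
```

Because the environment variable is read when CUDA is initialised, calling this after the framework has already touched the GPU has no effect.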
- Train:
--train_continue (flag) - Whether to continue a previous training; if omitted, a new training is started
--epochs (int) - How many epochs to train the net
--batch_size (int) - The batch size for the train procedure
--val_prop (float in range [0., 1.]) - The proportion of the train data to use for validation
--lunch_tb (flag) - (TensorFlow only) Whether to launch TensorBoard
--train_data_file (str) - The path to the file generated by the ./data_gen.py with the train data
--test_data_file (str) - The path to the file generated by the ./data_gen.py with the test data
--output_dir (str) - Where to place all the outputs. The outputs will be placed in a dir
in the following format: output_dir/PROCEDURE/MODEL_LIB_TIME_STAMP (e.g., ./outputs/train/pytorch_2022-08-07_13-18-23)
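As an illustration of what --val_prop controls, the following sketch holds out a proportion of the train samples for validation. The function name split_train_val, the shuffling and the fixed seed are assumptions made for the example, not a description of the actual train scripts.

```python
import random


def split_train_val(samples, val_prop, seed=0):
    """Split samples into (train, val), where val holds about val_prop of the data."""
    if not 0.0 <= val_prop <= 1.0:
        raise ValueError("val_prop must lie in [0., 1.]")
    shuffled = list(samples)
    # A seeded generator keeps the split reproducible between runs.
    random.Random(seed).shuffle(shuffled)
    n_val = int(round(len(shuffled) * val_prop))
    return shuffled[n_val:], shuffled[:n_val]


# Hold out 20% of ten samples for validation.
train, val = split_train_val(range(10), val_prop=0.2)
```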
- Test:
--tr_checkpoint_file (str) - (PyTorch only) The path to the weights of the trained model
--tf_checkpoint_file (str) - (TensorFlow only) The path to the weights of the trained model
--tf_checkpoint_dir (str) - (TensorFlow only) The path to the directory with the weights of the trained model
--lunch_tb (flag) - (TensorFlow only) Whether to launch TensorBoard
--test_data_file (str) - The path to the file generated by the ./data_gen.py with the test data
--output_dir (str) - Where to place all the outputs. The outputs will be placed in a dir
in the following format: output_dir/PROCEDURE/MODEL_LIB_TIME_STAMP (e.g., ./outputs/test/pytorch_2022-08-07_13-18-23)
- Inference:
--tr_checkpoint_file (str) - (PyTorch only) The path to the weights of the trained model
--tf_checkpoint_file (str) - (TensorFlow only) The path to the weights of the trained model
--tf_checkpoint_dir (str) - (TensorFlow only) The path to the directory with the weights of the trained model
--lunch_tb (flag) - (TensorFlow only) Whether to launch TensorBoard
--inference_data_dir (str) - The path to the root directory with the data files to run inference on, in the format: inference_data_dir/images, inference_data_dir/segmentations
--output_dir (str) - Where to place all the outputs. The outputs will be placed in a dir
in the following format: output_dir/PROCEDURE/MODEL_LIB_TIME_STAMP (e.g., ./outputs/inference/pytorch_2022-08-07_13-18-23)
* For ease of use, all the flags have default configurations of the same name in ./configs/general_configs.py
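The output_dir/PROCEDURE/MODEL_LIB_TIME_STAMP convention described for --output_dir can be reproduced with a few lines. This is a sketch only: the function name make_output_dir is hypothetical, and the time stamp format is inferred from the example "pytorch_2022-08-07_13-18-23".

```python
from datetime import datetime
from pathlib import Path


def make_output_dir(output_dir, procedure, model_lib, now=None):
    """Build the path output_dir/PROCEDURE/MODEL_LIB_TIME_STAMP (without creating it)."""
    now = now or datetime.now()
    # Format inferred from the example: year-month-day_hour-minute-second.
    time_stamp = now.strftime("%Y-%m-%d_%H-%M-%S")
    return Path(output_dir) / procedure / f"{model_lib}_{time_stamp}"


example = make_output_dir("./outputs", "train", "pytorch",
                          now=datetime(2022, 8, 7, 13, 18, 23))
```

Passing the time explicitly, as in the example, makes the result reproducible; omitting it uses the current time.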