From 8f4b84ed8d22780b5276bbb1ec7eea2c77788c32 Mon Sep 17 00:00:00 2001
From: acai66 <1779864536@qq.com>
Date: Tue, 2 Apr 2024 17:10:16 +0800
Subject: [PATCH 01/11] Feature: Add .gitignore
---
.gitignore | 168 +++++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 168 insertions(+)
create mode 100644 .gitignore
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..dd954d7
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,168 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# custom
+data/
+deploy/
+nuitka-crash-report.xml
+outputs/
+pretrained/
+
From 44c848eec01d1d105bce93f4a86704a584ec12f0 Mon Sep 17 00:00:00 2001
From: acai66 <1779864536@qq.com>
Date: Tue, 2 Apr 2024 17:10:35 +0800
Subject: [PATCH 02/11] Feature: Add requirements.txt
---
requirements.txt | 6 ++++++
1 file changed, 6 insertions(+)
create mode 100644 requirements.txt
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..b1feabe
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,6 @@
+torch # == 1.8.1
+segmentation-models-pytorch
+torchmetrics
+albumentations
+loguru
+tqdm
From 7b1d960b3e24c4d93ee763cfad650eabb5dfb0fa Mon Sep 17 00:00:00 2001
From: acai66 <1779864536@qq.com>
Date: Tue, 2 Apr 2024 17:13:24 +0800
Subject: [PATCH 03/11] Feature: Exporting ONNX models for DDRNet and STDC
---
models/ddrnet.py | 5 +++++
models/stdc.py | 5 +++++
2 files changed, 10 insertions(+)
diff --git a/models/ddrnet.py b/models/ddrnet.py
index 7184b63..e777dae 100644
--- a/models/ddrnet.py
+++ b/models/ddrnet.py
@@ -52,6 +52,11 @@ def forward(self, x, is_training=False):
x = self.seg_head(x)
x = F.interpolate(x, size, mode='bilinear', align_corners=True)
+ if torch.onnx.is_in_onnx_export():
+ # output_data = x.softmax(dim=1)
+ max_probs, predictions = x.max(1, keepdim=True)
+ return predictions.to(torch.int8)
+
if self.use_aux and is_training:
return x, (x_aux,)
else:
diff --git a/models/stdc.py b/models/stdc.py
index 259d9a8..45ab052 100644
--- a/models/stdc.py
+++ b/models/stdc.py
@@ -87,6 +87,11 @@ def forward(self, x, is_training=False):
x = self.seg_head(x)
x = F.interpolate(x, size, mode='bilinear', align_corners=True)
+ if torch.onnx.is_in_onnx_export():
+ # output_data = x.softmax(dim=1)
+ max_probs, predictions = x.max(1, keepdim=True)
+ return predictions.to(torch.int8)
+
if self.use_detail_head and is_training:
x_detail = self.detail_head(x3)
return x, x_detail
From bfc309826ded1081cd8f1deebe821057ae6bfa9a Mon Sep 17 00:00:00 2001
From: acai66 <1779864536@qq.com>
Date: Tue, 2 Apr 2024 17:15:09 +0800
Subject: [PATCH 04/11] Feature: Add ResizeToSquare transforms
---
utils/transforms.py | 38 +++++++++++++++++++++++++++++++++++++-
1 file changed, 37 insertions(+), 1 deletion(-)
diff --git a/utils/transforms.py b/utils/transforms.py
index f0db895..e22188c 100644
--- a/utils/transforms.py
+++ b/utils/transforms.py
@@ -1,5 +1,6 @@
import numpy as np
import albumentations as AT
+import torch.nn.functional as F
def to_numpy(array):
@@ -29,4 +30,39 @@ def __call__(self, image, mask=None):
augmented = aug(image=img)
else:
augmented = aug(image=img, mask=msk)
- return augmented
\ No newline at end of file
+ return augmented
+
+
+class ResizeToSquare:
+ def __init__(self, size, interpolation=1, p=1, is_testing=False):
+ self.size = size
+ self.interpolation = interpolation
+ self.p = p
+ self.is_testing = is_testing
+
+ def __call__(self, image, mask=None):
+ img = to_numpy(image)
+
+ h, w, _ = img.shape
+ max_wh = np.max([w, h])
+ hp = int((max_wh - w) / 2)
+ vp = int((max_wh - h) / 2)
+ padding = ((vp, vp), (hp, hp), (0, 0))
+ img = np.pad(img, padding, mode='constant', constant_values=0)
+
+ if not self.is_testing:
+ msk = to_numpy(mask)
+ msk_h, msk_w, _ = img.shape
+ msk_max_wh = np.max([msk_w, msk_h])
+ msk_hp = int((msk_max_wh - msk_w) / 2)
+ msk_vp = int((msk_max_wh - msk_h) / 2)
+ msk_padding = ((msk_vp, msk_vp), (msk_hp, msk_hp))
+ msk = np.pad(msk, msk_padding, mode='constant', constant_values=0)
+
+ aug = AT.Resize(height=self.size, width=self.size, interpolation=self.interpolation, p=self.p)
+
+ if self.is_testing:
+ augmented = aug(image=img)
+ else:
+ augmented = aug(image=img, mask=msk)
+ return augmented
From ffa5740a7efdf1234350327548efa9a649b7429e Mon Sep 17 00:00:00 2001
From: acai66 <1779864536@qq.com>
Date: Tue, 2 Apr 2024 17:16:05 +0800
Subject: [PATCH 05/11] Feature: Add support for importing BaseConfig
---
configs/__init__.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/configs/__init__.py b/configs/__init__.py
index c7dc363..9e2a940 100644
--- a/configs/__init__.py
+++ b/configs/__init__.py
@@ -1,2 +1,3 @@
+from .base_config import BaseConfig
from .my_config import MyConfig
from .parser import load_parser
\ No newline at end of file
From 13bc83b8670a34fe7fb0aab7b8b916d430b081bc Mon Sep 17 00:00:00 2001
From: acai66 <1779864536@qq.com>
Date: Tue, 2 Apr 2024 17:19:05 +0800
Subject: [PATCH 06/11] Feature: Add support for creating save_dir
---
core/seg_trainer.py | 6 +++++-
1 file changed, 5 insertions(+), 1 deletion(-)
diff --git a/core/seg_trainer.py b/core/seg_trainer.py
index 841a080..63f15d3 100644
--- a/core/seg_trainer.py
+++ b/core/seg_trainer.py
@@ -158,6 +158,10 @@ def predict(self, config):
self.logger.info('\nStart predicting...\n')
+ save_images_dir = os.path.join(config.save_dir, 'predicts')
+ if not os.path.exists(save_images_dir):
+ os.makedirs(save_images_dir)
+
self.model.eval() # Put model in evalation mode
for (images, images_aug, img_names) in tqdm(self.test_loader):
@@ -171,7 +175,7 @@ def predict(self, config):
# Saving results
for i in range(preds.shape[0]):
- save_path = os.path.join(config.save_dir, img_names[i])
+ save_path = os.path.join(save_images_dir, img_names[i])
save_suffix = img_names[i].split('.')[-1]
pred = Image.fromarray(preds[i].astype(np.uint8))
From 73cd07e1b8692bbdcfd0050dc0d39c5f4c38a229 Mon Sep 17 00:00:00 2001
From: acai66 <1779864536@qq.com>
Date: Tue, 2 Apr 2024 17:21:40 +0800
Subject: [PATCH 07/11] Feature: Add support for custom datasets
---
configs/parser.py | 2 +-
datasets/__init__.py | 6 +++-
datasets/custom.py | 85 ++++++++++++++++++++++++++++++++++++++++++++
3 files changed, 91 insertions(+), 2 deletions(-)
create mode 100644 datasets/custom.py
diff --git a/configs/parser.py b/configs/parser.py
index da285c8..b51240f 100644
--- a/configs/parser.py
+++ b/configs/parser.py
@@ -16,7 +16,7 @@ def load_parser(config):
def get_parser():
parser = argparse.ArgumentParser()
# Dataset
- parser.add_argument('--dataset', type=str, default=None, choices=['cityscapes'],
+ parser.add_argument('--dataset', type=str, default=None, choices=['cityscapes', 'custom'],
help='choose which dataset you want to use')
parser.add_argument('--dataroot', type=str, default=None,
help='path to your dataset')
diff --git a/datasets/__init__.py b/datasets/__init__.py
index 5a1b07c..90bf60e 100644
--- a/datasets/__init__.py
+++ b/datasets/__init__.py
@@ -1,7 +1,11 @@
from torch.utils.data import DataLoader
from .cityscapes import Cityscapes
+from .custom import Custom
-dataset_hub = {'cityscapes':Cityscapes,}
+dataset_hub = {
+ 'cityscapes':Cityscapes,
+ 'custom': Custom,
+}
def get_dataset(config):
diff --git a/datasets/custom.py b/datasets/custom.py
new file mode 100644
index 0000000..8cd73ef
--- /dev/null
+++ b/datasets/custom.py
@@ -0,0 +1,85 @@
+import os
+from collections import namedtuple
+import yaml
+import numpy as np
+from PIL import Image
+from torch.utils.data import Dataset
+import albumentations as AT
+from albumentations.pytorch import ToTensorV2
+from utils import transforms
+
+
+class Custom(Dataset):
+ '''
+ demo for load custom datasets.
+ '''
+
+ def __init__(self, config, mode='train'):
+ data_root = os.path.expanduser(config.data_root)
+ dataset_config_filepath = os.path.join(data_root, 'data.yaml')
+ if not os.path.exists(dataset_config_filepath):
+ raise Exception(f"{dataset_config_filepath} not exists.")
+ with open(dataset_config_filepath, 'r', encoding='utf-8') as yaml_file:
+ dataset_config = yaml.safe_load(yaml_file)
+ print('dataset_config: ', dataset_config)
+ data_root = dataset_config['path']
+ # self.num_classes = len(dataset_config['names'])
+ self.id_to_train_id = dict()
+ for i in range(len(dataset_config['names'])):
+ self.id_to_train_id[i] = i
+ # self.train_id_to_name = dict()
+ # for k, v in dataset_config['names'].items():
+ # self.train_id_to_name[k] = str(v)
+
+ img_dir = os.path.join(data_root, mode, 'imgs')
+ msk_dir = os.path.join(data_root, mode, 'masks')
+
+ if not os.path.isdir(img_dir):
+ raise RuntimeError(f'Image directory: {img_dir} does not exist.')
+
+ if not os.path.isdir(msk_dir):
+ raise RuntimeError(f'Mask directory: {msk_dir} does not exist.')
+
+ if mode == 'train':
+ self.transform = AT.Compose([
+ transforms.ResizeToSquare(size=config.train_size),
+ transforms.Scale(scale=config.scale),
+ AT.RandomScale(scale_limit=config.randscale),
+ AT.PadIfNeeded(min_height=config.crop_h, min_width=config.crop_w, value=(114,114,114), mask_value=(0,0,0)),
+ AT.RandomCrop(height=config.crop_h, width=config.crop_w),
+ AT.ColorJitter(brightness=config.brightness, contrast=config.contrast, saturation=config.saturation),
+ AT.HorizontalFlip(p=config.h_flip),
+ AT.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)),
+ ToTensorV2(),
+ ])
+
+ elif mode == 'val':
+ self.transform = AT.Compose([
+ transforms.ResizeToSquare(size=config.test_size),
+ transforms.Scale(scale=config.scale),
+ AT.Normalize(mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)),
+ ToTensorV2(),
+ ])
+
+ self.images = []
+ self.masks = []
+
+ for img_file_name in os.listdir(img_dir):
+ img_file_basename = os.path.splitext(img_file_name)[0]
+
+ self.images.append(os.path.join(img_dir, img_file_name))
+ self.masks.append(os.path.join(msk_dir, img_file_basename + '.png'))
+
+ def __len__(self):
+ return len(self.images)
+
+ def __getitem__(self, index):
+ image = np.asarray(Image.open(self.images[index]).convert('RGB'))
+ mask = np.asarray(Image.open(self.masks[index]).convert('L'))
+
+ # Perform augmentation and normalization
+ augmented = self.transform(image=image, mask=mask)
+ image, mask = augmented['image'], augmented['mask']
+
+ return image, mask
+
\ No newline at end of file
From cf383aa40f0e6020fab3b4a064cd274e597e7a9d Mon Sep 17 00:00:00 2001
From: acai66 <1779864536@qq.com>
Date: Wed, 3 Apr 2024 14:05:35 +0800
Subject: [PATCH 08/11] update requirements.txt
---
requirements.txt | 4 ++++
1 file changed, 4 insertions(+)
diff --git a/requirements.txt b/requirements.txt
index b1feabe..051e7f2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,3 +4,7 @@ torchmetrics
albumentations
loguru
tqdm
+tensorboard
+onnx
+onnxsim
+onnxruntime
From 68821dfedbd36bc8542123b0b9963c08abee7413 Mon Sep 17 00:00:00 2001
From: acai66 <1779864536@qq.com>
Date: Wed, 3 Apr 2024 14:12:51 +0800
Subject: [PATCH 09/11] Fix: Segformer backbone compatibility with mit_b*
encoders
---
models/__init__.py | 10 +++++++++-
1 file changed, 9 insertions(+), 1 deletion(-)
diff --git a/models/__init__.py b/models/__init__.py
index daf6056..8cd38c9 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -68,7 +68,15 @@ def get_model(config):
if config.decoder not in decoder_hub:
raise ValueError(f"Unsupported decoder type: {config.decoder}")
- model = decoder_hub[config.decoder](encoder_name=config.encoder,
+ if config.encoder.startswith('mit_b') and config.decoder in ['pan']:
+ model = decoder_hub[config.decoder](encoder_name=config.encoder,
+ encoder_weights=config.encoder_weights,
+ encoder_output_stride=32,
+ in_channels=3, classes=config.num_class)
+ elif config.encoder.startswith('mit_b') and config.decoder in ['deeplabv3', 'deeplabv3p', 'linknet', 'unetpp']:
+ raise ValueError("Encoder `{}` is not supported for `{}".format(config.encoder, config.decoder))
+ else:
+ model = decoder_hub[config.decoder](encoder_name=config.encoder,
encoder_weights=config.encoder_weights,
in_channels=3, classes=config.num_class)
From 154d4db5013e51dc79e4e90bf2708aaf1f2195f8 Mon Sep 17 00:00:00 2001
From: acai66 <1779864536@qq.com>
Date: Wed, 3 Apr 2024 14:15:49 +0800
Subject: [PATCH 10/11] Feature: Add script to convert Labelme labels to custom
dataset format
---
requirements.txt | 1 +
utils/check_datasets.py | 112 ++++++++++++++++++++++++++++++++++++++++
2 files changed, 113 insertions(+)
create mode 100644 utils/check_datasets.py
diff --git a/requirements.txt b/requirements.txt
index 051e7f2..6ea7569 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,4 @@ tensorboard
onnx
onnxsim
onnxruntime
+labelme
diff --git a/utils/check_datasets.py b/utils/check_datasets.py
new file mode 100644
index 0000000..529709a
--- /dev/null
+++ b/utils/check_datasets.py
@@ -0,0 +1,112 @@
+import os
+import shutil
+import argparse
+import random
+random.seed(0)
+
+import cv2
+from tqdm import tqdm
+import numpy as np
+import labelme
+from PIL import Image
+
+
+def check_semantic_segmentation_datasets(datasets_path):
+ out_path = datasets_path
+ labels_path = os.path.join(out_path, 'labels')
+ if not os.path.exists(labels_path):
+ print('Error: %s not found' % (labels_path))
+ return
+ datasets_root = os.path.join(datasets_path, 'out')
+
+ datasets_train = os.path.join(datasets_root, 'train')
+ datasets_val = os.path.join(datasets_root, 'val')
+
+ datasets_train_imgs = os.path.join(datasets_train, 'imgs')
+ datasets_train_masks = os.path.join(datasets_train, 'masks')
+ datasets_val_imgs = os.path.join(datasets_val, 'imgs')
+ datasets_val_masks = os.path.join(datasets_val, 'masks')
+
+ if os.path.exists(datasets_root):
+ shutil.rmtree(datasets_root)
+
+ for d in [datasets_root, datasets_train, datasets_val, datasets_train_imgs, datasets_train_masks, datasets_val_imgs, datasets_val_masks]:
+ if not os.path.exists(d):
+ os.makedirs(d)
+
+ all_data = [i for i in os.listdir(labels_path) if os.path.splitext(i)[1] == '.json']
+ print('all_data: ', len(all_data))
+
+ print(all_data[:5])
+ random.shuffle(all_data)
+ print(all_data[:5])
+
+ train_factor = 0.95
+ train_num = round(train_factor * len(all_data))
+
+ class_name_to_id = {'_background': 0}
+
+ for i in tqdm(all_data):
+ filename = os.path.splitext(os.path.basename(i))[0]
+ label_file = labelme.LabelFile(filename=os.path.join(labels_path, i))
+ for shape in label_file.shapes:
+ if shape.get('shape_type', '') == 'polygon':
+ if shape.get('label', 'None') not in class_name_to_id.keys():
+ class_name_to_id[shape.get('label', 'None')] = len(class_name_to_id)
+ print(class_name_to_id)
+
+ for i in tqdm(all_data[:train_num]):
+ filename = os.path.splitext(os.path.basename(i))[0]
+ label_file = labelme.LabelFile(filename=os.path.join(labels_path, i))
+ img = labelme.utils.img_data_to_arr(label_file.imageData)
+ if img.ndim == 3:
+ img = img[:, :, ::-1]
+ lbl, _ = labelme.utils.shapes_to_label(
+ img_shape=img.shape,
+ shapes=label_file.shapes,
+ label_name_to_value=class_name_to_id,
+ )
+ #cv2.imwrite(os.path.join(imgages_path, filename+'.png'), img)
+ cv2.imencode('.png', img)[1].tofile(os.path.join(datasets_train_imgs, filename+'.png'))
+ #cv2.imwrite(os.path.join(masks_path, filename+'.png'), lbl)
+ cv2.imencode('.png', lbl)[1].tofile(os.path.join(datasets_train_masks, filename+'.png'))
+
+ for i in tqdm(all_data[train_num:]):
+ filename = os.path.splitext(os.path.basename(i))[0]
+ label_file = labelme.LabelFile(filename=os.path.join(labels_path, i))
+ img = labelme.utils.img_data_to_arr(label_file.imageData)
+ if img.ndim == 3:
+ img = img[:, :, ::-1]
+ lbl, _ = labelme.utils.shapes_to_label(
+ img_shape=img.shape,
+ shapes=label_file.shapes,
+ label_name_to_value=class_name_to_id,
+ )
+ #cv2.imwrite(os.path.join(imgages_path, filename+'.png'), img)
+ cv2.imencode('.png', img)[1].tofile(os.path.join(datasets_val_imgs, filename+'.png'))
+ #cv2.imwrite(os.path.join(masks_path, filename+'.png'), lbl)
+ cv2.imencode('.png', lbl)[1].tofile(os.path.join(datasets_val_masks, filename+'.png'))
+
+ with open(os.path.join(datasets_root, 'data.yaml'), 'w+', encoding='UTF-8') as yaml_file:
+ yaml_file.write('path: ')
+ yaml_file.write(os.path.abspath(datasets_root))
+ yaml_file.write('\n')
+
+ yaml_file.write('names: \n')
+ for i in range(len(class_name_to_id)):
+ for k, v in class_name_to_id.items():
+ if v == i:
+ yaml_file.write(' %d: %s\n' % (i, k))
+
+
+def parse_opt():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--datasets_root', type=str, default="", help="path to datasets root dir.")
+
+ opt = parser.parse_args()
+ return opt
+
+
+if __name__ == '__main__':
+ opt = parse_opt()
+ check_semantic_segmentation_datasets(opt.datasets_root)
\ No newline at end of file
From 2470fa3a7118bd63b81f13db3bb6d0f15911565a Mon Sep 17 00:00:00 2001
From: acai66 <1779864536@qq.com>
Date: Wed, 3 Apr 2024 14:18:40 +0800
Subject: [PATCH 11/11] Fix: fix typo
---
README.md | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/README.md b/README.md
index 4c31fa5..60c7636 100644
--- a/README.md
+++ b/README.md
@@ -196,11 +196,11 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 python main.py
## Knowledge distillation
-| Model | Encoder | Decoder | kd_training | mIoU(200 epoch) | mIoU(800 epoch) |
-|:-----:|:-------------:|:-----------------------:|:-----------:|:---------------:|:---------------:|
-| SMP | DeepLabv3Plus | ResNet-101
teacher | - | 78.10 | 79.20 |
-| SMP | DeepLabv3Plus | ResNet-18
student | False | 73.97 | 75.90 |
-| SMP | DeepLabv3Plus | ResNet-18
student | True | 75.20 | 76.41 |
+| Model | Encoder | Decoder | kd_training | mIoU(200 epoch) | mIoU(800 epoch) |
+|:-----:|:-----------------------:|:-------------:|:-----------:|:---------------:|:---------------:|
+| SMP | ResNet-101
teacher | DeepLabv3Plus | - | 78.10 | 79.20 |
+| SMP | ResNet-18
student | DeepLabv3Plus | False | 73.97 | 75.90 |
+| SMP | ResNet-18
student | DeepLabv3Plus | True | 75.20 | 76.41 |
# Prepare the dataset