Diabetic Retinopathy (DR) has become one of the leading causes of vision impairment in working-age adults and is a severe problem worldwide. DR severity grades are ordinal, yet most prior works ignore the ordinal information among labels. In this project, we propose MTCSNN, a Multi-task Clinical Siamese Neural Network for DR severity prediction. The novelty of this project is to exploit the ordinal information among labels by adding a regression task, which helps the model learn more discriminative feature embeddings for fine-grained classification. We perform comprehensive experiments on RetinaMNIST, comparing MTCSNN with ResNet-18, ResNet-34, and ResNet-50. Our results indicate that MTCSNN outperforms these benchmark models in terms of AUC and accuracy on the test set.
L1 is the standard cross-entropy loss for the classification task, while L2 is the mean squared error (MSE) loss for the difference regression task, which also acts as a form of regularization.

- Model architecture implementation based on the code provided by `torchvision.models.resnet`:
  - `resnet18.py`
  - `resnet34.py`
  - `resnet50.py`
- Dataset from MedMNIST:
  - `dataset.py`: PyTorch datasets and dataloaders of MedMNIST
  - `evaluator.py`: Standardized evaluation functions
  - `info.py`: Dataset information `dict` for each subset of MedMNIST