forked from CV4EcologySchool/ct_classifier
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
=
committed
Aug 3, 2022
0 parents
commit 4782c7e
Showing
11 changed files
with
580 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
datasets | ||
model_states | ||
|
||
__pycache__ | ||
*.pyc | ||
|
||
.vscode | ||
.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# Here's where you define experiment-specific hyperparameters.
# You can also create lists and group parameters together into nested sub-parts.
# In Python, this is all read as a dict.

# environment/computational parameters
# NOTE: PyTorch device strings are lowercase ("cuda", "cpu"); "CUDA" is rejected.
device: cuda
num_workers: 4

# dataset parameters
data_root: /path/to/dataset
num_classes: 32

# training hyperparameters
num_epochs: 200
batch_size: 128
# NOTE: use plain decimal notation here. YAML 1.1 loaders (e.g., PyYAML) only
# recognize scientific notation as a float if the mantissa contains a decimal
# point, so "1e-3" would be loaded as the string "1e-3" rather than a number.
learning_rate: 0.001
weight_decay: 0.001
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
''' | ||
2022 Benjamin Kellenberger | ||
''' | ||
|
||
# Info: these "__init__.py" files go into every folder and subfolder that | ||
# contains Python code. They are required for Python to find all the scripts you | ||
# created and for you to be able to import them. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
''' | ||
PyTorch dataset class for COCO-CT-formatted datasets. Note that you could | ||
use the official PyTorch MS-COCO wrappers: | ||
https://pytorch.org/vision/master/generated/torchvision.datasets.CocoDetection.html | ||
We just hack our way through the COCO JSON files here for demonstration | ||
purposes. | ||
See also the MS-COCO format on the official Web page: | ||
https://cocodataset.org/#format-data | ||
2022 Benjamin Kellenberger | ||
''' | ||
|
||
import os | ||
import json | ||
from torch.utils.data import Dataset | ||
from torchvision.transforms import ToTensor | ||
from PIL import Image | ||
|
||
|
||
class CTDataset(Dataset):
    '''
        Dataset that serves (image tensor, label index) pairs described by a
        COCO-style annotation file located below "data_root".
    '''

    def __init__(self, data_root, split='train'):
        '''
            Constructor. Here, we collect and index the dataset inputs and
            labels.

            Args:
                data_root: base directory of the dataset; must contain the
                           folder "eccv_18_annotation_files".
                split:     'train' loads "train_annotations.json"; any other
                           value loads "cis_val_annotations.json".

            Raises:
                FileNotFoundError: if the annotation file does not exist.
                KeyError: if the annotation file lacks the expected COCO
                          fields (see "_build_index").
        '''
        self.data_root = data_root
        self.split = split
        self.transform = ToTensor()

        # load annotation file
        annoPath = os.path.join(
            self.data_root,
            'eccv_18_annotation_files',
            'train_annotations.json' if self.split == 'train' else 'cis_val_annotations.json'
        )
        # context manager makes sure the file handle is closed again
        with open(annoPath, 'r') as f:
            meta = json.load(f)

        # index data into list of (file name, label index) tuples
        self.data = self._build_index(meta)

    @staticmethod
    def _build_index(meta):
        '''
            Turns parsed COCO metadata into a list of (file name, label index)
            tuples.

            Expects the standard COCO keys: "images" (entries with "id" and
            "file_name"), "categories" (entries with "id") and "annotations"
            (entries with "image_id" and "category_id").

            Label indices are the zero-based positions of the categories in
            meta["categories"], so that they form a contiguous range even if
            the original category IDs do not.

            Since we're doing classification, only the first annotation of
            each image is used; further annotations of the same image are
            dropped.
        '''
        # image ID to file name lookup
        file_names = {img['id']: img['file_name'] for img in meta['images']}
        # category ID to zero-based label index lookup
        label_indices = {cat['id']: idx for idx, cat in enumerate(meta['categories'])}

        data = []
        images_covered = set()      # images that already received a label
        for anno in meta['annotations']:
            image_id = anno['image_id']
            if image_id in images_covered:
                continue
            data.append((file_names[image_id], label_indices[anno['category_id']]))
            images_covered.add(image_id)
        return data

    def __len__(self):
        '''
            Returns the length of the dataset.
        '''
        return len(self.data)

    def __getitem__(self, idx):
        '''
            Returns a single data point at given idx.
            Here's where we actually load the image.

            Returns:
                Tuple of (image as torch.Tensor, label index).
        '''
        image_name, label = self.data[idx]

        # load image
        # NOTE(review): assumes the file names in the annotation file are
        # relative to data_root - insert the image subfolder here if the
        # images live in one.
        image_path = os.path.join(self.data_root, image_name)
        with Image.open(image_path) as raw:
            # ".convert" forces three bands (Red, Green, Blue), so grayscale
            # or RGBA files still yield tensors of identical channel count.
            # It also loads the pixel data before the file is closed.
            img = raw.convert('RGB')

        # transform: convert to torch.Tensor
        # here's where we could do data augmentation:
        # https://pytorch.org/vision/stable/transforms.html
        # For now, we only convert the image to torch.Tensor
        img_tensor = self.transform(img)

        return img_tensor, label
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
''' | ||
Model implementation. | ||
We'll be using a "simple" ResNet-18 for image classification here. | ||
2022 Benjamin Kellenberger | ||
''' | ||
|
||
import torch.nn as nn | ||
from torchvision.models import resnet | ||
|
||
|
||
class CustomResNet18(nn.Module):
    '''
        ResNet-18 backbone whose original 1000-class ImageNet output layer is
        replaced by a new linear classifier with num_classes outputs.
    '''

    def __init__(self, num_classes):
        '''
            Constructor of the model. Here, we initialize the model's
            architecture (layers).

            Args:
                num_classes: number of outputs of the final classifier layer.
        '''
        super(CustomResNet18, self).__init__()

        self.feature_extractor = resnet.resnet18(pretrained=True)       # "pretrained": use weights pre-trained on ImageNet

        # replace the very last layer from the original, 1000-class output
        # ImageNet to a new one that outputs num_classes
        last_layer = self.feature_extractor.fc                          # tip: print(self.feature_extractor) to get info on how model is set up
        # nn.Linear exposes its input width as "in_features" (there is no
        # "num_features" attribute on linear layers)
        in_features = last_layer.in_features
        self.feature_extractor.fc = nn.Identity()                       # discard last layer...

        self.classifier = nn.Linear(in_features, num_classes)           # ...and create a new one

    def forward(self, x):
        '''
            Forward pass. Here, we define how to apply our model. It's basically
            applying our modified ResNet-18 on the input tensor ("x") and then
            apply the final classifier layer on the ResNet-18 output to get our
            num_classes prediction.
        '''
        # x.size(): [B x 3 x W x H]
        # the backbone's own "fc" is an identity, so it returns the input of
        # the former last layer
        features = self.feature_extractor(x)        # features.size(): [B x in_features]
        prediction = self.classifier(features)      # prediction.size(): [B x num_classes]

        return prediction
Oops, something went wrong.