Skip to content

Commit

Permalink
impl. working version
Browse files Browse the repository at this point in the history
  • Loading branch information
bkellenb committed Aug 3, 2022
1 parent df30896 commit fbc3ec3
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 33 deletions.
11 changes: 6 additions & 5 deletions configs/exp_resnet18.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@
# In Python, this is all read as a dict.

# environment/computational parameters
device: CUDA
device: cuda
num_workers: 4

# dataset parameters
data_root: /path/to/dataset
num_classes: 32
data_root: datasets/CaltechCT
num_classes: 16

# training hyperparameters
image_size: [224, 224]
num_epochs: 200
batch_size: 128
learning_rate: 1e-3
weight_decay: 1e-3
learning_rate: 0.001
weight_decay: 0.001
40 changes: 28 additions & 12 deletions ct_classifier/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,23 @@
import os
import json
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose, Resize, ToTensor
from PIL import Image


class CTDataset(Dataset):

def __init__(self, data_root, split='train'):
def __init__(self, cfg, split='train'):
'''
Constructor. Here, we collect and index the dataset inputs and
labels.
'''
self.data_root = data_root
self.data_root = cfg['data_root']
self.split = split
self.transform = ToTensor()
self.transform = Compose([ # Transforms. Here's where we could add data augmentation (see Björn's lecture on August 11).
Resize((cfg['image_size'])), # For now, we just resize the images to the same dimensions...
ToTensor() # ...and convert them to torch.Tensor.
])

# index data into list
self.data = []
Expand All @@ -40,6 +43,23 @@ def __init__(self, data_root, split='train'):
'train_annotations.json' if self.split=='train' else 'cis_val_annotations.json'
)
meta = json.load(open(annoPath, 'r'))

images = dict([[i['id'], i['file_name']] for i in meta['images']]) # image id to filename lookup
labels = dict([[c['id'], idx] for idx, c in enumerate(meta['categories'])]) # custom labelclass indices that start at zero

# since we're doing classification, we just take the first annotation per image and drop the rest
images_covered = set() # all those images for which we have already assigned a label
for anno in meta['annotations']:
imgID = anno['image_id']
if imgID in images_covered:
continue

# append image-label tuple to data
imgFileName = images[imgID]
label = anno['category_id']
labelIndex = labels[label]
self.data.append([imgFileName, labelIndex])
images_covered.add(imgID) # make sure image is only added once to dataset


def __len__(self):
Expand All @@ -54,17 +74,13 @@ def __getitem__(self, idx):
Returns a single data point at given idx.
Here's where we actually load the image.
'''
image_name, label = self.data[idx]
image_name, label = self.data[idx] # see line 57 above where we added these two items to the self.data list

# load image
image_path = os.path.join(self.data_root, image_path)
img = Image.open(image_path)
image_path = os.path.join(self.data_root, 'eccv_18_all_images_sm', image_name)
img = Image.open(image_path).convert('RGB') # the ".convert" makes sure we always get three bands in Red, Green, Blue order

# transform: convert to torch.Tensor
# here's where we could do data augmentation:
# https://pytorch.org/vision/stable/transforms.html
# see Björn's lecture on Thursday, August 11.
# For now, we only convert the image to torch.Tensor
# transform: see lines 31ff above where we define our transformations
img_tensor = self.transform(img)

return img_tensor, label
4 changes: 2 additions & 2 deletions ct_classifier/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def __init__(self, num_classes):
# replace the very last layer from the original, 1000-class output
# ImageNet to a new one that outputs num_classes
last_layer = self.feature_extractor.fc # tip: print(self.feature_extractor) to get info on how model is set up
num_features = last_layer.num_features
in_features = last_layer.in_features # number of input dimensions to last (classifier) layer
self.feature_extractor.fc = nn.Identity() # discard last layer...

self.classifier = nn.Linear(num_features, num_classes) # ...and create a new one
self.classifier = nn.Linear(in_features, num_classes) # ...and create a new one


def forward(self, x):
Expand Down
30 changes: 18 additions & 12 deletions ct_classifier/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
2022 Benjamin Kellenberger
'''

import os
import argparse
import yaml
import glob
Expand All @@ -16,8 +17,8 @@
from torch.optim import SGD

# let's import our own classes and functions!
from ct_classifier.dataset import CTDataset
from ct_classifier.model import CustomResNet18
from dataset import CTDataset
from model import CustomResNet18



Expand All @@ -26,7 +27,7 @@ def create_dataloader(cfg, split='train'):
Loads a dataset according to the provided split and wraps it in a
PyTorch DataLoader object.
'''
dataset_instance = CTDataset(cfg['data_root'], split) # create an object instance of our CTDataset class
dataset_instance = CTDataset(cfg, split) # create an object instance of our CTDataset class

dataLoader = DataLoader(
dataset=dataset_instance,
Expand Down Expand Up @@ -66,6 +67,9 @@ def load_model(cfg):


def save_model(epoch, model, stats):
# make sure save directory exists; create if not
os.makedirs('model_states', exist_ok=True)

# get model parameters and add to stats...
stats['model'] = model.state_dict()

Expand Down Expand Up @@ -100,11 +104,11 @@ def train(cfg, dataLoader, model, optimizer):
criterion = nn.CrossEntropyLoss()

# running averages
loss_total, oa_total = 0.0, 0.0 # for now, we just log the loss and overall accuracy (OA)
loss_total, oa_total = 0.0, 0.0 # for now, we just log the loss and overall accuracy (OA)

# iterate over dataLoader
progressBar = trange(len(dataLoader))
for idx, (data, labels) in enumerate(dataLoader):
for idx, (data, labels) in enumerate(dataLoader): # see the last line of file "dataset.py" where we return the image tensor (data) and label

# put data and labels on device
data, labels = data.to(device), labels.to(device)
Expand All @@ -125,18 +129,19 @@ def train(cfg, dataLoader, model, optimizer):
optimizer.step()

# log statistics
loss_total += loss.item() # the .item() command retrieves the value of a single-valued tensor, regardless of its data type and device of tensor
loss_total += loss.item() # the .item() command retrieves the value of a single-valued tensor, regardless of its data type and device of tensor

pred_label = torch.argmax(prediction) # the predicted label is the one at position (class index) with highest predicted value
pred_label = torch.argmax(prediction, dim=1) # the predicted label is the one at position (class index) with highest predicted value
oa = torch.mean((pred_label == labels).float()) # OA: number of correct predictions divided by batch size (i.e., average/mean)
oa_total += oa.item()

progressBar.set_description(
'[Train] Loss: {:.2f}; OA: {:.2f}'.format(
'[Train] Loss: {:.2f}; OA: {:.2f}%'.format(
loss_total/(idx+1),
oa_total/(idx+1)
100*oa_total/(idx+1)
)
)
progressBar.update(1)

# end of epoch; finalize
progressBar.close()
Expand Down Expand Up @@ -179,16 +184,17 @@ def validate(cfg, dataLoader, model):
# log statistics
loss_total += loss.item()

pred_label = torch.argmax(prediction)
pred_label = torch.argmax(prediction, dim=1)
oa = torch.mean((pred_label == labels).float())
oa_total += oa.item()

progressBar.set_description(
'[Val ] Loss: {:.2f}; OA: {:.2f}'.format(
'[Val ] Loss: {:.2f}; OA: {:.2f}%'.format(
loss_total/(idx+1),
oa_total/(idx+1)
100*oa_total/(idx+1)
)
)
progressBar.update(1)

# end of epoch; finalize
progressBar.close()
Expand Down
2 changes: 1 addition & 1 deletion license
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
MIT License

Copyright (c) Microsoft Corporation. All rights reserved.
Copyright (c) ECEO, EPFL. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
4 changes: 3 additions & 1 deletion readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ pip install -r requirements.txt

3. Download dataset

**NOTE:** Requires the [azcopy CLI](https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10) to be installed and set up on your machine.

```bash
./scripts/download_dataset.sh
sh scripts/download_dataset.sh
```

This downloads the [CCT20](https://lila.science/datasets/caltech-camera-traps) subset to the `datasets/CaltechCT` folder.
Expand Down

0 comments on commit fbc3ec3

Please sign in to comment.