Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Data augmentation #141

Merged
merged 18 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions crabs/detection_tracking/config/faster_rcnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,29 @@ batch_size_test: 4
# -------------------
# Data augmentation
# -------------------
transform_brightness: 0.5
transform_hue: 0.3
gaussian_blur_params:
gaussian_blur:
kernel_size:
- 5
- 9
sigma:
- 0.1
- 5.0

color_jitter:
brightness: 0.5
hue: 0.3
random_horizontal_flip:
p: 0.5
random_rotation:
degrees: [-10.0, 10.0]
random_adjust_sharpness:
p: 0.5
sharpness_factor: 0.5
random_autocontrast:
p: 0.5
random_equalize:
p: 0.5
clamp_and_sanitize_bboxes:
min_size: 1.0
# ----------------------------
# Hyperparameter optimisation
# -----------------------------
Expand Down
78 changes: 68 additions & 10 deletions crabs/detection_tracking/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,65 @@ def __init__(
list_annotation_files: list[str],
config: dict,
split_seed: Optional[int] = None,
no_data_augmentation: bool = False,
):
super().__init__()
self.list_img_dirs = list_img_dirs
self.list_annotation_files = list_annotation_files
self.split_seed = split_seed
self.config = config
self.no_data_augmentation = no_data_augmentation

def _transform_str_to_operator(self, transform_str):
"""Get transform operator from its name in snake case"""

def snake_to_camel_case(snake_str):
return "".join(
x.capitalize() for x in snake_str.lower().split("_")
)

transform_callable = getattr(
transforms, snake_to_camel_case(transform_str)
)

return transform_callable(**self.config[transform_str])

def _compute_list_of_transforms(self) -> list[torchvision.transforms.v2]:
"""Read transforms from config and add to list"""

# Initialise list
train_data_augm: list[torchvision.transforms.v2] = []

# Apply standard transforms if defined
for transform_str in [
"gaussian_blur",
"color_jitter",
"random_horizontal_flip",
"random_rotation",
"random_adjust_sharpness",
"random_autocontrast",
"random_equalize",
]:
if transform_str in self.config:
transform_operator = self._transform_str_to_operator(
transform_str
)
train_data_augm.append(transform_operator)

# Apply clamp and sanitize bboxes if defined
# See https://pytorch.org/vision/main/generated/torchvision.transforms.v2.SanitizeBoundingBoxes.html#torchvision.transforms.v2.SanitizeBoundingBoxes
if "clamp_and_sanitize_bboxes" in self.config:
# Clamp bounding boxes
train_data_augm.append(transforms.ClampBoundingBoxes())

# Sanitize
sanitize = transforms.SanitizeBoundingBoxes(
min_size=self.config["clamp_and_sanitize_bboxes"]["min_size"],
labels_getter=None, # only bboxes are sanitized
)
train_data_augm.append(sanitize)

return train_data_augm

def _get_train_transform(self) -> torchvision.transforms:
"""Define data augmentation transforms for the train set.
Expand All @@ -38,17 +91,22 @@ def _get_train_transform(self) -> torchvision.transforms:
https://pytorch.org/vision/stable/transforms.html#v1-or-v2-which-one-should-i-use
https://pytorch.org/vision/main/auto_examples/transforms/plot_transforms_e2e.html#transforms

ToDtype is the recommended replacement for ConvertImageDtype(dtype)
https://pytorch.org/vision/0.17/generated/torchvision.transforms.v2.ToDtype.html#torchvision.transforms.v2.ToDtype

"""
jitter = transforms.ColorJitter(
brightness=self.config["transform_brightness"],
hue=self.config["transform_hue"],
)
gauss = transforms.GaussianBlur(
kernel_size=self.config["gaussian_blur_params"]["kernel_size"],
sigma=self.config["gaussian_blur_params"]["sigma"],
)
todtype = transforms.ToDtype(torch.float32, scale=True)
train_transforms = [transforms.ToImage(), jitter, gauss, todtype]
# Compute list of transforms to apply
if self.no_data_augmentation:
train_data_augm = []
else:
train_data_augm = self._compute_list_of_transforms()

# Define a Compose transform with them
train_transforms = [
transforms.ToImage(),
*train_data_augm,
transforms.ToDtype(torch.float32, scale=True),
]
return transforms.Compose(train_transforms)

def _get_test_val_transform(self) -> torchvision.transforms:
Expand Down
36 changes: 35 additions & 1 deletion crabs/detection_tracking/detection_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import argparse
import datetime
import logging
import os
from pathlib import Path
from typing import Any
from typing import Any, Optional

import torch
from lightning.pytorch.loggers import MLFlowLogger

DEFAULT_ANNOTATIONS_FILENAME = "VIA_JSON_combined_coco_gen.json"
Expand Down Expand Up @@ -240,3 +242,35 @@ def slurm_logs_as_artifacts(logger, slurm_job_id):
logger.run_id,
f"{log_filename}.{ext}",
)


def log_data_augm_as_artifacts(logger, data_module):
"""Log data augmentation transforms as artifacts in MLflow."""
for transform_str in ["train_transform", "test_val_transform"]:
logger.experiment.log_text(
text=str(getattr(data_module, f"_get_{transform_str}")()),
artifact_file=f"{transform_str}.txt",
run_id=logger.run_id,
)


def get_checkpoint_type(checkpoint_path: Optional[str]) -> Optional[str]:
"""Get checkpoint type (full or weights) from the checkpoint path."""
checkpoint = torch.load(checkpoint_path) # fails if path doesn't exist
if all(
[
param in checkpoint
for param in ["optimizer_states", "lr_schedulers"]
]
):
checkpoint_type = "full" # for resuming training
logging.info(
f"Resuming training from checkpoint at: {checkpoint_path}"
)
else:
checkpoint_type = "weights" # for fine tuning
logging.info(
f"Fine-tuning training from checkpoint at: {checkpoint_path}"
)

return checkpoint_type
8 changes: 4 additions & 4 deletions crabs/detection_tracking/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ def evaluate_model(self) -> None:
"""
# Create datamodule
data_module = CrabsDataModule(
self.images_dirs,
self.annotation_files,
self.config,
self.seed_n,
list_img_dirs=self.images_dirs,
list_annotation_files=self.annotation_files,
split_seed=self.seed_n,
config=self.config,
)

# Get trained model
Expand Down
91 changes: 40 additions & 51 deletions crabs/detection_tracking/train_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
import logging
import os
import sys
from pathlib import Path
Expand All @@ -12,6 +11,8 @@

from crabs.detection_tracking.datamodules import CrabsDataModule
from crabs.detection_tracking.detection_utils import (
get_checkpoint_type,
log_data_augm_as_artifacts,
prep_annotation_files,
prep_img_directories,
set_mlflow_run_name,
Expand Down Expand Up @@ -57,6 +58,7 @@ def __init__(self, args):
self.fast_dev_run = args.fast_dev_run
self.limit_train_batches = args.limit_train_batches

# Restart from checkpoint
self.checkpoint_path = args.checkpoint_path

def load_config_yaml(self):
Expand Down Expand Up @@ -158,67 +160,44 @@ def core_training(self) -> lightning.Trainer:
"""
# Create data module
data_module = CrabsDataModule(
self.images_dirs,
self.annotation_files,
self.config,
self.seed_n,
list_img_dirs=self.images_dirs,
list_annotation_files=self.annotation_files,
split_seed=self.seed_n,
config=self.config,
no_data_augmentation=self.args.no_data_augmentation,
)

# Get checkpoint type
if self.checkpoint_path and os.path.exists(self.checkpoint_path):
checkpoint = torch.load(self.checkpoint_path)
if all(
[
param in checkpoint
for param in ["optimizer_states", "lr_schedulers"]
]
):
checkpoint_type = "full" # for resuming training
logging.info(
f"Resuming training from checkpoint at: {self.checkpoint_path}"
)
else:
checkpoint_type = "weights" # for fine tuning
logging.info(
f"Fine-tuning training from checkpoint at: {self.checkpoint_path}"
)
else:
checkpoint_type = None

# Get model
if checkpoint_type == "weights":
# Note: weights-only checkpoint contains hyperparameters
# see https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
lightning_model = FasterRCNN.load_from_checkpoint(
self.checkpoint_path,
config=self.config,
optuna_log=self.args.optuna,
# overwrite checkpoint hyperparameters with config ones
# otherwise ckpt hyperparameters are logged to MLflow, but yaml hyperparameters are used
)
else:
if not self.checkpoint_path:
lightning_model = FasterRCNN(
self.config, optuna_log=self.args.optuna
)
checkpoint_type = None
else:
checkpoint_type = get_checkpoint_type(self.checkpoint_path)
if checkpoint_type == "weights":
lightning_model = FasterRCNN.load_from_checkpoint(
self.checkpoint_path,
config=self.config, # overwrite hparams from ckpt with config
optuna_log=self.args.optuna,
) # a 'weights' checkpoint is one saved with `save_weights_only=True`

# Get trainer
trainer = self.setup_trainer()
if self.args.log_data_augmentation:
log_data_augm_as_artifacts(trainer.logger, data_module)

# Run training
# Resume from full checkpoint if available
# (automatically restores model, epoch, step, LR schedulers, etc...)
# https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
if checkpoint_type == "full":
trainer.fit(
lightning_model,
data_module,
ckpt_path=self.checkpoint_path, # needs to having been saved with `save_weights_only=False`
)
else: # for "weights" or no checkpoint
trainer.fit(
lightning_model,
data_module,
)
trainer.fit(
lightning_model,
data_module,
ckpt_path=(
self.checkpoint_path if checkpoint_type == "full" else None
),
# a 'full' checkpoint is one saved with `save_weights_only=False`
# (automatically restores model, epoch, step, LR schedulers, etc...)
# see https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
)

return trainer

Expand Down Expand Up @@ -344,6 +323,16 @@ def train_parse_args(args):
action="store_true",
help="Run a hyperparameter optimisation using Optuna prior to training the model",
)
parser.add_argument(
"--no_data_augmentation",
action="store_true",
help="Ignore the data augmentation transforms defined in config file",
)
parser.add_argument(
"--log_data_augmentation",
action="store_true",
help="Log data augmentation transforms linked to datamodule as MLflow artifacts",
)
return parser.parse_args(args)


Expand Down
45 changes: 45 additions & 0 deletions notebooks/notebook_data_augm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# %%
import yaml # type: ignore

from crabs.detection_tracking.datamodules import CrabsDataModule
from crabs.detection_tracking.visualization import plot_sample

# %%%%%%%%%%%%%%%%%%%
# Input data
IMG_DIR = "/home/sminano/swc/project_crabs/data/sep2023-full/frames"
ANNOT_FILE = "/home/sminano/swc/project_crabs/data/sep2023-full/annotations/VIA_JSON_combined_coco_gen.json"
CONFIG = "/home/sminano/swc/project_crabs/crabs-exploration/crabs/detection_tracking/config/faster_rcnn.yaml"
SPLIT_SEED = 42

# %%%%%%%%%%%%%%%%%%%%
# Read config as dict
with open(CONFIG, "r") as f:
config_dict = yaml.safe_load(f)

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Create datamodule for the input data
dm = CrabsDataModule(
list_img_dirs=[IMG_DIR],
list_annotation_files=[ANNOT_FILE],
config=config_dict,
split_seed=SPLIT_SEED,
)
# %%%%%%%%%%%%%%%%%%%%%%%%
# Setup for train / test
dm.prepare_data()
dm.setup("fit")


# %%%%%%%%%%%%%%%%%%%%%%%%%%%
# after this: dm.train_dataset should have transforms, (but not dm.test_dataset)
print(dm.train_transform)
print(dm.val_transform)
print(dm.test_transform)

# %%%%%%%%%%%%%%%%%%%%%%%%%
# visualize
train_dataset = dm.train_dataset
train_sample = train_dataset[0]
plot_sample([train_sample])

# %%
Loading