Optionally Reduce LR on plateau #25

Draft · wants to merge 64 commits into base: master

Conversation

johnlockejrr

Learning Rate Warmup and Optimization Implementation

Overview

Added learning rate warmup functionality to improve training stability, especially when using pretrained weights. The implementation uses TensorFlow's native learning rate scheduling for better performance.

Changes Made

1. Configuration Updates (runs/train_no_patches_448x448.json)

Added new configuration parameters for warmup:

```json
{
    "warmup_enabled": true,
    "warmup_epochs": 5,
    "warmup_start_lr": 1e-6
}
```
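
For context, a minimal sketch of how these keys could be read in `train.py`; the repository's actual config loading may differ, and the fallback values here are only illustrative:

```python
import json

# Illustrative loader: read the warmup parameters alongside the existing ones,
# falling back to "warmup disabled" when the keys are absent from older configs.
with open("runs/train_no_patches_448x448.json") as f:
    config = json.load(f)

warmup_enabled = config.get("warmup_enabled", False)
warmup_epochs = config.get("warmup_epochs", 5)
warmup_start_lr = config.get("warmup_start_lr", 1e-6)
learning_rate = config.get("learning_rate", 5e-5)  # target LR after warmup
```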

2. Training Script Updates (train.py)

A. Optimizer and Learning Rate Schedule

  • Replaced fixed learning rate with dynamic scheduling
  • Implemented warmup using tf.keras.optimizers.schedules.PolynomialDecay
  • Maintained compatibility with existing ReduceLROnPlateau and EarlyStopping
```python
if warmup_enabled:
    lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=warmup_start_lr,
        decay_steps=warmup_epochs * steps_per_epoch,
        end_learning_rate=learning_rate,
        power=1.0  # Linear decay
    )
    optimizer = Adam(learning_rate=lr_schedule)
else:
    optimizer = Adam(learning_rate=learning_rate)
```
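
The snippet above assumes `steps_per_epoch` is already available. It is simply the number of optimizer updates per epoch; the values below are hypothetical, chosen only so the arithmetic matches the 1194 steps/epoch visible in the training log further down:

```python
import math

# Hypothetical example values; in train.py they come from the dataset and the config.
n_train_samples = 9552
batch_size = 8

# One step = one optimizer update, so an epoch has ceil(samples / batch_size) steps.
steps_per_epoch = math.ceil(n_train_samples / batch_size)  # -> 1194
```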

B. Learning Rate Behavior

  • Initial learning rate: 1e-6 (configurable via warmup_start_lr)
  • Target learning rate: 5e-5 (configurable via learning_rate)
  • Linear increase over 5 epochs (configurable via warmup_epochs)
  • After warmup, learning rate remains at target value until ReduceLROnPlateau triggers
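
A quick sanity check of that behaviour, using the default values listed above (the `steps_per_epoch` value is hypothetical, as before):

```python
import tensorflow as tf

warmup_start_lr, learning_rate = 1e-6, 5e-5
warmup_epochs, steps_per_epoch = 5, 1194

lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=warmup_start_lr,
    decay_steps=warmup_epochs * steps_per_epoch,
    end_learning_rate=learning_rate,
    power=1.0,  # linear ramp
)

# The schedule is indexed by global step, not by epoch.
for epoch in (0, 1, 3, 5, 10):
    step = epoch * steps_per_epoch
    print(f"epoch {epoch:2d}: lr = {lr_schedule(step).numpy():.2e}")
# Rises linearly from 1e-6 to 5e-5 over the first 5 epochs, then holds at 5e-5
# (PolynomialDecay clamps the step at decay_steps unless cycle=True).
```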

Benefits

  1. Improved training stability during initial epochs
  2. Better handling of pretrained weights
  3. Efficient implementation using TensorFlow's native scheduling
  4. Configurable through JSON configuration file
  5. Maintains compatibility with existing callbacks (ReduceLROnPlateau, EarlyStopping)

Usage

To enable warmup:

  1. Set warmup_enabled: true in the configuration file
  2. Adjust warmup_epochs and warmup_start_lr as needed
  3. The warmup will automatically integrate with existing learning rate reduction and early stopping

To disable warmup:

  • Set warmup_enabled: false or remove the warmup parameters from the configuration file

avoid ensembling if no model weights met the threshold f1 score in the case of classification
…n an output directory with the same file name
vahidrezanezhad and others added 29 commits June 12, 2024 17:40
Changed unsafe basename extraction:
`file_name = i.split('.')[0]` to `file_name = os.path.splitext(i)[0]`
and
`filename = n[i].split('.')[0]` to `filename = os.path.splitext(n[i])[0]`
because `Vat.sam.2_206.jpg` -> `Vat` instead of `Vat.sam.2_206`.
This safely keeps the full basename without the extension.
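
A minimal illustration of why the change matters:

```python
import os

name = "Vat.sam.2_206.jpg"

print(name.split('.')[0])         # 'Vat'            (truncates at the first dot)
print(os.path.splitext(name)[0])  # 'Vat.sam.2_206'  (strips only the final extension)
```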
# Training Script Improvements

## Learning Rate Management Fixes

### 1. ReduceLROnPlateau Implementation
- Fixed the learning rate reduction mechanism by replacing the manual epoch loop with a single `model.fit()` call
- This ensures proper tracking of validation metrics across epochs
- Configured with:
  ```python
  reduce_lr = ReduceLROnPlateau(
      monitor='val_loss',
      factor=0.2,        # More aggressive reduction
      patience=3,        # Quick response to plateaus
      min_lr=1e-6,       # Minimum learning rate
      min_delta=1e-5,    # Minimum change to be considered improvement
      verbose=1
  )
  ```

### 2. Warmup Implementation
- Added learning rate warmup using TensorFlow's native scheduling
- Gradually increases learning rate from 1e-6 to target (2e-5) over 5 epochs
- Helps stabilize initial training phase
- Implemented using `PolynomialDecay` schedule:
  ```python
  lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
      initial_learning_rate=warmup_start_lr,
      decay_steps=warmup_epochs * steps_per_epoch,
      end_learning_rate=learning_rate,
      power=1.0  # Linear decay
  )
  ```

### 3. Early Stopping
- Added early stopping to prevent overfitting
- Configured with:
  ```python
  early_stopping = EarlyStopping(
      monitor='val_loss',
      patience=6,
      restore_best_weights=True,
      verbose=1
  )
  ```
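
As noted above, all of this is wired into a single `model.fit()` call instead of a manual epoch loop. A sketch of that wiring (the dataset and callback variable names are illustrative; the per-epoch checkpoint callback is described in the next section):

```python
# Sketch only: the actual call in train.py may pass additional arguments.
callbacks = [reduce_lr, early_stopping]  # plus the per-epoch checkpoint callback

history = model.fit(
    train_dataset,                 # training data (name assumed)
    validation_data=val_dataset,   # validation data (name assumed)
    epochs=n_epochs,
    callbacks=callbacks,
)
```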

## Model Saving Improvements

### 1. Epoch-based Model Saving
- Implemented custom `ModelCheckpointWithConfig` to save both model and config
- Saves after each epoch with corresponding config.json
- Maintains compatibility with original script's saving behavior
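
The callback itself is not shown in this description, so the following is only a plausible shape for `ModelCheckpointWithConfig` given the behaviour described above (save the model plus the run's config.json after every epoch); the real implementation in `train.py` may differ:

```python
import json
import os
import tensorflow as tf

class ModelCheckpointWithConfig(tf.keras.callbacks.Callback):
    """Sketch: save the model and the run's config after every epoch."""

    def __init__(self, output_dir, config):
        super().__init__()
        self.output_dir = output_dir  # e.g. the run directory
        self.config = config          # dict loaded from the JSON config

    def on_epoch_end(self, epoch, logs=None):
        # SavedModel directory per epoch, matching the "model_N" paths in the log below.
        model_dir = os.path.join(self.output_dir, f"model_{epoch + 1}")
        self.model.save(model_dir)
        with open(os.path.join(model_dir, "config.json"), "w") as f:
            json.dump(self.config, f, indent=4)
```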

### 2. Best Model Saving
- Saves the best model at training end
- If early stopping triggers: saves the best model from training
- If no early stopping: saves the final model
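
A sketch of that end-of-training save (the output path is illustrative). With `restore_best_weights=True`, EarlyStopping has already restored the best weights by the time training returns, so a plain save captures the best model; without early stopping it simply saves the final weights:

```python
# Sketch: by this point EarlyStopping (restore_best_weights=True) has already
# put the best weights back into `model` if it triggered.
model.save(os.path.join(output_dir, "model_best"))  # output_dir is illustrative
```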

## Configuration
All parameters are configurable through the JSON config file:
```json
{
    "reduce_lr_enabled": true,
    "reduce_lr_monitor": "val_loss",
    "reduce_lr_factor": 0.2,
    "reduce_lr_patience": 3,
    "reduce_lr_min_lr": 1e-6,
    "reduce_lr_min_delta": 1e-5,
    "early_stopping_enabled": true,
    "early_stopping_monitor": "val_loss",
    "early_stopping_patience": 6,
    "early_stopping_restore_best_weights": true,
    "warmup_enabled": true,
    "warmup_epochs": 5,
    "warmup_start_lr": 1e-6
}
```
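
A sketch of how these keys might map onto the callbacks (key names as above; the `config` dict and its loading code are assumed, not shown in this PR):

```python
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping

# `config` is assumed to be the dict loaded from the JSON file above.
callbacks = []

if config.get("reduce_lr_enabled", False):
    callbacks.append(ReduceLROnPlateau(
        monitor=config.get("reduce_lr_monitor", "val_loss"),
        factor=config.get("reduce_lr_factor", 0.2),
        patience=config.get("reduce_lr_patience", 3),
        min_lr=config.get("reduce_lr_min_lr", 1e-6),
        min_delta=config.get("reduce_lr_min_delta", 1e-5),
        verbose=1,
    ))

if config.get("early_stopping_enabled", False):
    callbacks.append(EarlyStopping(
        monitor=config.get("early_stopping_monitor", "val_loss"),
        patience=config.get("early_stopping_patience", 6),
        restore_best_weights=config.get("early_stopping_restore_best_weights", True),
        verbose=1,
    ))
```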

## Benefits
1. More stable training with proper learning rate management
2. Better handling of training plateaus
3. Automatic saving of best model
4. Maintained compatibility with existing config saving
5. Improved training monitoring and control
@johnlockejrr (Author)

With warmup, ReduceLROnPlateau, and early stopping enabled, the training process looks smoother and much more stable than with the default settings:

```
==================================================================================================
Total params: 38,211,247
Trainable params: 38,154,089
Non-trainable params: 57,158
__________________________________________________________________________________________________
Epoch 1/50
1194/1194 [==============================] - ETA: 0s - loss: 0.7081 - accuracy: 0.7516
Epoch 1: saving model to runs/sam_41_mss_npt_448x448/model_1
INFO - tensorflow - Assets written to: runs/sam_41_mss_npt_448x448/model_1/assets
1194/1194 [==============================] - 396s 302ms/step - loss: 0.7081 - accuracy: 0.7516 - val_loss: 0.5615 - val_accuracy: 0.8679 - lr: 4.7968e-06
Epoch 2/50
1194/1194 [==============================] - ETA: 0s - loss: 0.5898 - accuracy: 0.8990
Epoch 2: saving model to runs/sam_41_mss_npt_448x448/model_2
INFO - tensorflow - Assets written to: runs/sam_41_mss_npt_448x448/model_2/assets
1194/1194 [==============================] - 357s 299ms/step - loss: 0.5898 - accuracy: 0.8990 - val_loss: 0.5170 - val_accuracy: 0.8843 - lr: 8.5968e-06
Epoch 3/50
1194/1194 [==============================] - ETA: 0s - loss: 0.5065 - accuracy: 0.9154
Epoch 3: saving model to runs/sam_41_mss_npt_448x448/model_3
INFO - tensorflow - Assets written to: runs/sam_41_mss_npt_448x448/model_3/assets
1194/1194 [==============================] - 356s 298ms/step - loss: 0.5065 - accuracy: 0.9154 - val_loss: 0.4268 - val_accuracy: 0.8995 - lr: 1.2397e-05
Epoch 4/50
1194/1194 [==============================] - ETA: 0s - loss: 0.4244 - accuracy: 0.9236
Epoch 4: saving model to runs/sam_41_mss_npt_448x448/model_4
INFO - tensorflow - Assets written to: runs/sam_41_mss_npt_448x448/model_4/assets
1194/1194 [==============================] - 357s 299ms/step - loss: 0.4244 - accuracy: 0.9236 - val_loss: 0.3657 - val_accuracy: 0.9017 - lr: 1.6197e-05
Epoch 5/50
1194/1194 [==============================] - ETA: 0s - loss: 0.3661 - accuracy: 0.9287
Epoch 5: saving model to runs/sam_41_mss_npt_448x448/model_5
INFO - tensorflow - Assets written to: runs/sam_41_mss_npt_448x448/model_5/assets
1194/1194 [==============================] - 358s 299ms/step - loss: 0.3661 - accuracy: 0.9287 - val_loss: 0.3451 - val_accuracy: 0.9161 - lr: 1.9997e-05
Epoch 6/50
1194/1194 [==============================] - ETA: 0s - loss: 0.3299 - accuracy: 0.9324
Epoch 6: saving model to runs/sam_41_mss_npt_448x448/model_6
INFO - tensorflow - Assets written to: runs/sam_41_mss_npt_448x448/model_6/assets
1194/1194 [==============================] - 356s 298ms/step - loss: 0.3299 - accuracy: 0.9324 - val_loss: 0.3256 - val_accuracy: 0.9052 - lr: 2.0000e-05
Epoch 7/50
1194/1194 [==============================] - ETA: 0s - loss: 0.3069 - accuracy: 0.9355
Epoch 7: saving model to runs/sam_41_mss_npt_448x448/model_7
INFO - tensorflow - Assets written to: runs/sam_41_mss_npt_448x448/model_7/assets
1194/1194 [==============================] - 357s 299ms/step - loss: 0.3069 - accuracy: 0.9355 - val_loss: 0.2609 - val_accuracy: 0.9106 - lr: 2.0000e-05
Epoch 8/50
1194/1194 [==============================] - ETA: 0s - loss: 0.2906 - accuracy: 0.9380
Epoch 8: saving model to runs/sam_41_mss_npt_448x448/model_8
INFO - tensorflow - Assets written to: runs/sam_41_mss_npt_448x448/model_8/assets
1194/1194 [==============================] - 356s 298ms/step - loss: 0.2906 - accuracy: 0.9380 - val_loss: 0.2852 - val_accuracy: 0.9165 - lr: 2.0000e-05
Epoch 9/50
1194/1194 [==============================] - ETA: 0s - loss: 0.2774 - accuracy: 0.9403
Epoch 9: saving model to runs/sam_41_mss_npt_448x448/model_9
INFO - tensorflow - Assets written to: runs/sam_41_mss_npt_448x448/model_9/assets
1194/1194 [==============================] - 358s 300ms/step - loss: 0.2774 - accuracy: 0.9403 - val_loss: 0.2608 - val_accuracy: 0.9229 - lr: 2.0000e-05
Epoch 10/50
1194/1194 [==============================] - ETA: 0s - loss: 0.2656 - accuracy: 0.9422
Epoch 10: saving model to runs/sam_41_mss_npt_448x448/model_10
INFO - tensorflow - Assets written to: runs/sam_41_mss_npt_448x448/model_10/assets
1194/1194 [==============================] - 357s 299ms/step - loss: 0.2656 - accuracy: 0.9422 - val_loss: 0.2218 - val_accuracy: 0.9341 - lr: 2.0000e-05
Epoch 11/50
 254/1194 [=====>........................] - ETA: 4:26 - loss: 0.2586 - accuracy: 0.9439
```

We'll see at training's end.
