This repository contains an implementation of image segmentation using PyTorch and the U-Net architecture. The project focuses on human (person vs. background) segmentation, using an EfficientNet-B0 encoder and a custom training pipeline.
The project implements a complete image segmentation pipeline, including:
- Custom dataset handling
- Data augmentation
- U-Net architecture with EfficientNet-B0 encoder
- Combined loss function (Dice + BCE)
- Training and validation loops
- Model checkpointing
The project requires the following Python packages:
```text
torch
opencv-python
numpy
pandas
matplotlib
scikit-learn
tqdm
albumentations
segmentation-models-pytorch
```
Project structure:
```text
PyTorch-Image-Segmentation/
├── Dataset/                                              # Training data
├── Deep_Learning_with_PyTorch_ImageSegmentation.ipynb    # Main implementation
└── helper.py                                             # Helper functions for visualization
```
Installation:
- Clone the repository:
```bash
git clone https://github.com/amangupta143/PyTorch-Image-Segmentation.git
cd PyTorch-Image-Segmentation
```
- Install the required packages:
```bash
pip install segmentation-models-pytorch
pip install -U albumentations
pip install opencv-contrib-python
```
- Download the dataset:
```bash
git clone https://github.com/parth1620/Human-Segmentation-Dataset-master.git
```
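Once downloaded, each photo and its binary mask have to be converted to tensors before they reach the model. The sketch below assumes images are already loaded as RGB uint8 arrays of shape HxWx3 and masks as HxW arrays with values 0 and 255; the class name InMemorySegmentationDataset and its augment argument are hypothetical. The repository's own dataset reads file paths from CSV_FILE instead; that disk I/O is skipped here so the conversion logic runs on its own.

```python
import numpy as np
import torch
from torch.utils.data import Dataset


class InMemorySegmentationDataset(Dataset):
    """Hypothetical dataset holding pre-loaded RGB images and 0/255 masks.

    The project's dataset reads paths from a CSV file; this version keeps
    arrays in memory so the tensor conversion can be shown standalone.
    """

    def __init__(self, images, masks, augment=None):
        assert len(images) == len(masks)
        self.images = images    # list of HxWx3 uint8 arrays (RGB)
        self.masks = masks      # list of HxW uint8 arrays with values 0 or 255
        self.augment = augment  # optional callable(image, mask) -> (image, mask)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image, mask = self.images[idx], self.masks[idx]
        if self.augment is not None:
            image, mask = self.augment(image, mask)
        # HWC -> CHW and scale pixel values to [0, 1]
        image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))).float() / 255.0
        # Add a channel axis and binarize: 255 -> 1.0, 0 -> 0.0
        mask = torch.from_numpy(np.ascontiguousarray(mask)).float().unsqueeze(0) / 255.0
        mask = torch.round(mask)
        return image, mask
```

np.ascontiguousarray guards against arrays with negative strides (for example after a flip), which torch.from_numpy rejects.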
Model architecture:
- Base Architecture: U-Net
- Encoder: EfficientNet-B0 (pretrained on ImageNet)
- Input Channels: 3 (RGB)
- Output Classes: 1 (Binary Segmentation)
- Loss Function: Combination of Dice Loss and Binary Cross-Entropy Loss
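A minimal sketch of the combined loss, assuming it is the plain sum of a soft Dice term and BCE, both computed from raw logits. segmentation-models-pytorch also provides a DiceLoss (in segmentation_models_pytorch.losses) that the project may use instead; the class name DiceBCELoss and its smoothing constant are assumptions, so values will not match that library exactly.

```python
import torch
import torch.nn as nn


class DiceBCELoss(nn.Module):
    """Soft Dice loss + binary cross-entropy, both taking raw logits."""

    def __init__(self, smooth: float = 1.0):
        super().__init__()
        self.smooth = smooth               # avoids division by zero on empty masks
        self.bce = nn.BCEWithLogitsLoss()  # applies the sigmoid internally

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # logits, targets: (N, 1, H, W); targets hold 0.0 / 1.0
        probs = torch.sigmoid(logits)
        dims = (1, 2, 3)  # sum over channel and spatial axes, per sample
        intersection = (probs * targets).sum(dims)
        cardinality = probs.sum(dims) + targets.sum(dims)
        dice = (2.0 * intersection + self.smooth) / (cardinality + self.smooth)
        dice_loss = 1.0 - dice.mean()
        return dice_loss + self.bce(logits, targets)
```

Dice scores overlap with the foreground as a whole, which helps when the person occupies only part of the frame, while BCE supplies a per-pixel gradient; summing the two is a common way to get both.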
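Training configuration:
```python
import torch

CSV_FILE = 'path/to/train.csv'  # placeholder: set this to the dataset's CSV file
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when available

EPOCHS = 35
LEARNING_RATE = 0.003
IMAGE_SIZE = 320                  # images and masks are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 16
ENCODER = 'timm-efficientnet-b0'  # segmentation-models-pytorch encoder name
WEIGHTS = 'imagenet'              # pretrained encoder weights
```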
Training augmentations include:
- Resize to 320x320
- Horizontal Flip (50% probability)
- Vertical Flip (50% probability)
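The main constraint on these augmentations is that every flip must be applied identically to the image and its mask; otherwise the labels no longer line up with the pixels. The NumPy sketch below (the function name random_flip_pair is hypothetical) shows the two flips at 50% probability each and leaves resizing out.

```python
import numpy as np


def random_flip_pair(image, mask, p_h=0.5, p_v=0.5, rng=None):
    """Flip an image and its mask together so they stay pixel-aligned.

    image: HxWxC array, mask: HxW array. Resizing is not handled here.
    """
    rng = np.random.default_rng() if rng is None else rng
    if rng.random() < p_h:  # horizontal flip: reverse the column order
        image, mask = image[:, ::-1], mask[:, ::-1]
    if rng.random() < p_v:  # vertical flip: reverse the row order
        image, mask = image[::-1, :], mask[::-1, :]
    return image, mask
```

With albumentations (imported as A), the same pipeline can be written as A.Compose([A.Resize(IMAGE_SIZE, IMAGE_SIZE), A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5)]) and called as aug(image=image, mask=mask), which applies each spatial transform to both targets.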
The training process includes:
- Custom training and validation functions
- Model checkpointing for best validation loss
- Adam optimizer
- GPU acceleration support
```python
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
```
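The training loop, validation loop, and best-loss checkpointing can be sketched as follows. The function names train_fn, eval_fn, and fit are hypothetical, the model is assumed to return raw logits, and loss_fn stands for any callable taking (logits, masks), such as a Dice + BCE loss. The checkpoint path defaults to bestModel.pt, the filename loaded in the inference snippet.

```python
import torch
import torch.nn as nn


def train_fn(loader, model, optimizer, loss_fn, device):
    """Run one training epoch and return the mean batch loss."""
    model.train()
    total = 0.0
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(images), masks)
        loss.backward()
        optimizer.step()
        total += loss.item()
    return total / len(loader)


@torch.no_grad()
def eval_fn(loader, model, loss_fn, device):
    """Run one validation pass without gradient tracking."""
    model.eval()
    total = 0.0
    for images, masks in loader:
        images, masks = images.to(device), masks.to(device)
        total += loss_fn(model(images), masks).item()
    return total / len(loader)


def fit(model, train_loader, valid_loader, loss_fn, epochs, lr, device, ckpt_path="bestModel.pt"):
    """Train with Adam and save the weights whenever validation loss improves."""
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_valid_loss = float("inf")
    history = []
    for _ in range(epochs):
        train_loss = train_fn(train_loader, model, optimizer, loss_fn, device)
        valid_loss = eval_fn(valid_loader, model, loss_fn, device)
        history.append((train_loss, valid_loss))
        # Checkpoint only when validation loss beats the best seen so far
        if valid_loss < best_valid_loss:
            torch.save(model.state_dict(), ckpt_path)
            best_valid_loss = valid_loss
    return best_valid_loss, history
```

Saving only the state_dict keeps the checkpoint decoupled from the model class's module path; the same architecture is rebuilt and the weights are loaded into it at inference time.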
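Usage:
- Prepare your dataset and update the CSV_FILE path in the configuration.
- Run the training script (or run the cells of Deep_Learning_with_PyTorch_ImageSegmentation.ipynb in order):
```bash
python deep_learning_with_pytorch_imagesegmentation.py
```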
The model can be used for inference as follows:
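```python
import torch

# Load the best checkpoint saved during training
model.load_state_dict(torch.load('bestModel.pt', map_location=DEVICE))
model.eval()  # inference mode: disables dropout, uses running batch-norm statistics

image, mask = validset[idx]  # image: (3, H, W), mask: (1, H, W)
with torch.no_grad():
    logits_mask = model(image.to(DEVICE).unsqueeze(0))  # add a batch dimension
    pred_mask = torch.sigmoid(logits_mask)              # logits -> probabilities
    pred_mask = (pred_mask > 0.5) * 1.0                 # threshold into a binary mask
```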
Dataset originally from: Human-Segmentation-Dataset (https://github.com/parth1620/Human-Segmentation-Dataset-master)
License: MIT
Feel free to use this implementation and modify it according to your needs. Contributions are welcome!