Skip to content

Commit

Permalink
Merge pull request #7 from shamspias/feat/rt-detr
Browse files Browse the repository at this point in the history
Enhance Filtering of Annotations by Model and Class in BaseFormat
  • Loading branch information
shamspias authored Sep 10, 2024
2 parents 4450564 + ee9168e commit 496df74
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 51 deletions.
37 changes: 16 additions & 21 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,36 +1,35 @@
# VideoLabelMagic

VideoLabelMagic is a Streamlit-based application tailored for researchers and developers in the computer vision field.
It simplifies the process of extracting frames from videos, applying object detection using a YOLO model, and annotating
these frames to generate training data.
VideoLabelMagic is a Streamlit-based application tailored for researchers and developers in the computer vision field. It simplifies the process of extracting frames from videos, applying object detection using various models like YOLO, NAS, and RTDETR, and annotating these frames to generate training data.

## Features

- **Video Upload**: Upload video files via the web interface.
- **Model Selection**: Utilize pre-trained YOLO models for object detection.
- **Model Selection**: Utilize pre-trained models for object detection, with support for multiple models such as YOLO, NAS, and RTDETR.
- **Multi-Model Operation**: Configure and run multiple detection models simultaneously to leverage their strengths in diverse scenarios.
- **Frame Rate Control**: Adjust the frame rate for extracting images from the video.
- **Dynamic Class Configuration**: Use YAML files to define and utilize different class configurations for object
detection.
- **Dynamic Class Configuration**: Use YAML files to define and utilize different class configurations for object detection.
- **Output Customization**: Configure output directories for storing extracted frames and annotations.
- **Transformation Options**: Apply transformations such as resizing, converting to grayscale, or rotating frames.
- **Flexible Storage**: Choose between local file system or cloud-based object storage for input/output operations.
- **SAHI Integration**: Use SAHI for sliced predictions, allowing efficient handling of large or complex images.

## Usage

1. **Starting the Application**:
- Launch the application and access it via `http://localhost:8501` on your browser.
- Launch the application and access it via `http://localhost:8501` on your browser.

2. **Uploading and Configuring**:
- Upload a video file or select one from the configured object storage.
- Choose the detection model, class configuration, and specify the output directory and frame rate.
- Select desired transformations for the frames to be processed.
- Upload a video file or select one from the configured object storage.
- Choose one or multiple detection models, class configuration, and specify the output directory and frame rate.
- Select desired transformations for the frames to be processed.

3. **Processing**:
- Click "Extract Frames" to start the frame extraction and annotation process.
- Once processing completes, the outputs can be found in the specified directory or uploaded to cloud storage.
- Click "Extract Frames" to start the frame extraction and annotation process.
- Once processing completes, the outputs can be found in the specified directory or uploaded to cloud storage.

4. **Viewing Results**:
- Access extracted images and annotations directly from the output directory or your cloud storage interface.
- Access extracted images and annotations directly from the output directory or your cloud storage interface.

## Creating Class Configuration Files

Expand All @@ -46,14 +45,11 @@ Here's how to set up your YAML file for dynamic class configuration:
- id: 1
name: car
- id: 2
name: truck
- id: 3
name: tank
name: truck
```
2. **Saving the File**: Save the file with a `.yaml` extension in the `object_class/` directory.
3. **Using in Application**: When running the application, select your new class configuration file from the dropdown
menu.
3. **Using in Application**: When running the application, select your new class configuration file from the dropdown menu.

### Prerequisites

Expand All @@ -80,11 +76,10 @@ Here's how to set up your YAML file for dynamic class configuration:

## Contributing

Contributions to VideoLabelMagic are welcome! Please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on
how to make contributions.
Contributions to VideoLabelMagic are welcome! Please refer to the [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to make contributions.

## License

Distributed under the MIT License. See [LICENSE](LICENSE) for more information.

Powered by [Indikat](https://indikat.tech)
Powered by [Indikat](https://indikat.tech)
40 changes: 33 additions & 7 deletions app/extractor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import cv2
import os
from ultralytics import YOLO
from ultralytics import YOLO, RTDETR, NAS
import yaml
from utils.image_processor import ImageProcessor
from utils.sahi_utils import SahiUtils
Expand All @@ -13,16 +13,17 @@ class VideoFrameExtractor:
"""

def __init__(self, config, video_path, frame_rate, output_dir, model_path, class_config_path, output_format,
transformations, sahi_config=None):
transformations, model_types, sahi_config=None):
self.config = config
self.video_path = video_path # Ensure this is a string representing the path to the video file.
self.frame_rate = frame_rate
self.output_dir = output_dir
self.yolo_model = YOLO(os.path.join('models', model_path))
self.vision_model = self.get_given_model(model_path, model_types)
self.class_config_path = class_config_path
self.output_format = output_format
self.transformations = transformations
self.supported_classes = self.load_classes(self.class_config_path)
self.supported_classes_names = self.load_classes_names(self.class_config_path)
self.supported_classes_ids = self.load_classes_ids(self.class_config_path)
self.image_processor = ImageProcessor(output_size=self.transformations.get('size', (640, 640)))

# Only initialize SahiUtils if SAHI is enabled
Expand All @@ -37,7 +38,18 @@ def __init__(self, config, video_path, frame_rate, output_dir, model_path, class
else:
print(f"VideoFrameExtractor initialized with video path: {self.video_path}")

def load_classes(self, config_path):
def get_given_model(self, model_path, types):
try:
if types == "RTDETR":
return RTDETR(os.path.join('models', model_path))
elif types == "YOLO":
return YOLO(os.path.join('models', model_path))
elif types == "NAS":
return NAS(os.path.join('models', model_path))
except Exception as e:
raise ValueError(f"Model architecture and Model not Matching: {str(e)}")

def load_classes_names(self, config_path):
"""
Load classes from a YAML configuration file.
"""
Expand All @@ -47,6 +59,16 @@ def load_classes(self, config_path):
class_data = yaml.safe_load(file)
return [cls['name'] for cls in class_data['classes']]

def load_classes_ids(self, config_path):
"""
Load classes from a YAML configuration file.
"""
if not os.path.exists(config_path):
raise FileNotFoundError(f"Configuration file not found at {config_path}")
with open(config_path, 'r') as file:
class_data = yaml.safe_load(file)
return [cls['id'] for cls in class_data['classes']]

def extract_frames(self, model_confidence):
cap = cv2.VideoCapture(self.video_path)
if not cap.isOpened():
Expand Down Expand Up @@ -75,13 +97,17 @@ def extract_frames(self, model_confidence):
if self.sahi_utils:
results = self.sahi_utils.perform_sliced_inference(transformed_image)
else:
results = self.yolo_model.predict(transformed_image, conf=model_confidence, verbose=False)
if self.config.debug:
results = self.vision_model.predict(transformed_image, conf=model_confidence, verbose=False)
# will add image show later time
else:
results = self.vision_model.predict(transformed_image, conf=model_confidence, verbose=False)

# print(results)

self.output_format.save_annotations(transformed_image, frame_path, frame_filename,
results,
self.supported_classes)
self.supported_classes_names, self.supported_classes_ids)

frame_count += 1

Expand Down
3 changes: 2 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def continue_ui(self):
'rotate': 'Rotate 90 degrees' in transformation_options
}
self.format_selection = st.selectbox("Choose output format:", list(self.format_options.keys()))
self.model_types = st.selectbox("Choose Model Types:", ("YOLO", "RTDETR", "NAS"))
self.sahi_enabled = st.sidebar.checkbox("Enable SAHI", value=self.config.sahi_enabled)
if self.sahi_enabled:
self.config.sahi_model_type = st.sidebar.selectbox("Model Architecture:", ["yolov8", "yolov9", "yolov10"])
Expand Down Expand Up @@ -116,7 +117,7 @@ def run_extraction(self, video_path, unique_filename):
try:
extractor = VideoFrameExtractor(self.config, video_path, self.frame_rate, specific_output_dir,
self.model_selection, class_config_path, output_format_instance,
self.transformations, self.sahi_config)
self.transformations, self.model_types, self.sahi_config)
extractor.extract_frames(self.model_confidence)
if self.format_selection == "CVAT":
output_format_instance.zip_and_cleanup()
Expand Down
36 changes: 20 additions & 16 deletions formats/base_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,15 @@ def ensure_directories(self):
"""
raise NotImplementedError("Subclasses should implement this method.")

def process_results(self, frame, results: Dict, img_dimensions) -> List[str]:
def process_results(self, results: Dict, img_dimensions, supported_classes) -> List[str]:
"""
Generate formatted strings from detection results suitable for annotations.
Args:
frame: The image frame being processed.
# frame: The image frame being processed.
results: Detection results containing bounding boxes and class IDs.
img_dimensions: Dimensions of the image for normalizing coordinates.
supported_classes: List of supported class
Returns:
List of annotation strings formatted according to specific requirements.
Expand All @@ -64,28 +65,30 @@ def process_results(self, frame, results: Dict, img_dimensions) -> List[str]:
if self.sahi_enabled:
for box in results['boxes']: # Assuming SAHI results are formatted similarly
class_id = int(box['cls'][0])
xmin, ymin, xmax, ymax = box['xyxy'][0]
x_center = ((xmin + xmax) / 2) / img_width
y_center = ((ymin + ymax) / 2) / img_height
width = (xmax - xmin) / img_width
height = (ymax - ymin) / img_height
annotations.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
if class_id in supported_classes: # Check if class_id is in the list of supported classes
xmin, ymin, xmax, ymax = box['xyxy'][0]
x_center = ((xmin + xmax) / 2) / img_width
y_center = ((ymin + ymax) / 2) / img_height
width = (xmax - xmin) / img_width
height = (ymax - ymin) / img_height
annotations.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
else:
for result in results:
if hasattr(result, 'boxes') and result.boxes is not None:
for box in result.boxes:
class_id = int(box.cls[0])
xmin, ymin, xmax, ymax = box.xyxy[0]
x_center = ((xmin + xmax) / 2) / img_width
y_center = ((ymin + ymax) / 2) / img_height
width = (xmax - xmin) / img_width
height = (ymax - ymin) / img_height
annotations.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")
if class_id in supported_classes: # Check if class_id is in the list of supported classes
xmin, ymin, xmax, ymax = box.xyxy[0]
x_center = ((xmin + xmax) / 2) / img_width
y_center = ((ymin + ymax) / 2) / img_height
width = (xmax - xmin) / img_width
height = (ymax - ymin) / img_height
annotations.append(f"{class_id} {x_center:.6f} {y_center:.6f} {width:.6f} {height:.6f}")

return annotations

def save_annotations(self, frame, frame_path: str, frame_filename: str, results: Dict,
supported_classes: List[str]):
supported_classes_names: List[str], supported_classes_ids: List[str]):
"""
Abstract method for saving annotations. To be implemented by subclasses to define
the logic for saving the annotations.
Expand All @@ -95,7 +98,8 @@ def save_annotations(self, frame, frame_path: str, frame_filename: str, results:
frame_path (str): The path where the frame is located.
frame_filename (str): The name of the frame file.
results (Dict): A dictionary of results from the detection model or sliced inference.
supported_classes (List[str]): List of supported class labels for the annotations.
supported_classes_names (List[str]): List of supported class labels names for the annotations.
supported_classes_ids (List[str]): List of supported class labels ids for the annotations.
Raises:
NotImplementedError: If the method is not implemented in the subclass.
Expand Down
7 changes: 4 additions & 3 deletions formats/cvat_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@ def __init__(self, output_dir: str, sahi_enabled: bool = False):
self.image_dir = os.path.join(self.data_dir, 'obj_train_data')
os.makedirs(self.image_dir, exist_ok=True)

def save_annotations(self, frame, frame_path: str, frame_filename: str, results, supported_classes: List[str]):
def save_annotations(self, frame, frame_path: str, frame_filename: str, results, supported_classes_names: List[str],
supported_classes_ids: List[str]):
"""
Saves annotations and frames in a format compatible with CVAT.
"""
img_dimensions = frame.shape[:2]
annotations = self.process_results(frame, results, img_dimensions)
annotations = self.process_results(results, img_dimensions, supported_classes_ids)
frame_filename_png = frame_filename.replace('.jpg', '.png')
image_path = os.path.join(self.image_dir, frame_filename_png)
cv2.imwrite(image_path, frame)
self.write_annotations(frame_filename_png, annotations)
self.create_metadata_files(supported_classes)
self.create_metadata_files(supported_classes_names)

def write_annotations(self, frame_filename: str, annotations: List[str]):
"""
Expand Down
7 changes: 4 additions & 3 deletions formats/roboflow_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,12 @@ def write_annotations(self, frame_filename: str, annotations: List[str]):
for annotation in annotations:
file.write(annotation + "\n")

def save_annotations(self, frame, frame_path, frame_filename, results, supported_classes):
def save_annotations(self, frame, frame_path, frame_filename, results, supported_classes_names: List[str],
supported_classes_ids: List[str]):
img_dimensions = frame.shape[:2]
annotations = self.process_results(frame, results, img_dimensions)
annotations = self.process_results(results, img_dimensions, supported_classes_ids)
self.write_annotations(frame_filename, annotations)
self.create_data_yaml(supported_classes)
self.create_data_yaml(supported_classes_names)

def create_data_yaml(self, supported_classes):
"""
Expand Down
7 changes: 7 additions & 0 deletions object_class/test_class_coco_1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
classes:
- id: 0
name: person
- id: 2
name: car
- id: 7
name: truck

0 comments on commit 496df74

Please sign in to comment.