diff --git a/app/config.py b/app/config.py
index 31062f3..d8a4d5d 100644
--- a/app/config.py
+++ b/app/config.py
@@ -5,6 +5,7 @@
 
 class Config(BaseSettings):
     streamlit_title: Optional[str] = "VideoLabelMagic"
+    debug: Optional[bool] = False
     models_directory: Optional[str] = "models/"
     output_directory: Optional[str] = "outputs/"
     object_class_directory: Optional[str] = "object_class/"
diff --git a/app/extractor.py b/app/extractor.py
index b6d1551..25b0002 100644
--- a/app/extractor.py
+++ b/app/extractor.py
@@ -27,7 +27,7 @@ def __init__(self, config, video_path, frame_rate, output_dir, model_path, class
 
         # Only initialize SahiUtils if SAHI is enabled
         if sahi_config:
-            self.sahi_utils = SahiUtils(os.path.join('models', model_path), **sahi_config)
+            self.sahi_utils = SahiUtils(self.config.debug, os.path.join('models', model_path), **sahi_config)
         else:
             self.sahi_utils = None
 
diff --git a/example.env b/example.env
index ce763d6..b7fc799 100644
--- a/example.env
+++ b/example.env
@@ -1,5 +1,6 @@
 # Streamlit Configuration
 STREAMLIT_TITLE=VideoLabelMagic
+DEBUG=True
 
 # Directories
 MODELS_DIRECTORY=models/
diff --git a/utils/sahi_utils.py b/utils/sahi_utils.py
index d05f6e1..675c23e 100644
--- a/utils/sahi_utils.py
+++ b/utils/sahi_utils.py
@@ -1,16 +1,25 @@
+import cv2
+import uuid
+import time
+import os
+from sahi import AutoDetectionModel
 from sahi.predict import get_sliced_prediction
 from sahi.utils.cv import read_image_as_pil
-from sahi import AutoDetectionModel
 import numpy as np
+import matplotlib.pyplot as plt
+from PIL import Image
 
 
 class SahiUtils:
-    def __init__(self, model_path, model_type='yolov8', device='cpu', slice_size=(256, 256), overlap_ratio=(0.2, 0.2)):
+    def __init__(self, debug, model_path, model_type='yolov8', device='cpu', slice_size=(256, 256),
+                 overlap_ratio=(0.2, 0.2)):
+        self.debug = debug
         self.device = device  # Can be 'cpu' or 'cuda:0' for GPU
         self.model_type = model_type
         self.model = self.load_model(model_path)
         self.slice_size = slice_size
         self.overlap_ratio = overlap_ratio
+        self.debug_annotated_directory = str(uuid.uuid4())
 
     def load_model(self, model_path):
         """Loads a detection model based on the specified type and path."""
@@ -22,6 +31,23 @@ def load_model(self, model_path):
         )
         return detection_model
 
+    def show_image(self, image, title="Image"):
+        """Displays a NumPy image using matplotlib."""
+        # Convert BGR to RGB for correct color
+        plt.imshow(image if len(image.shape) == 2 else cv2.cvtColor(
+            image,
+            cv2.COLOR_BGR2RGB
+        ))
+        plt.title(title)
+        plt.axis('off')  # Hide axes
+        plt.show()
+
+    def show_annotated_image(self, image_path):
+        img = Image.open(image_path)
+        plt.imshow(img)
+        plt.axis('off')
+        plt.show()
+
     def perform_sliced_inference(self, image):
         """Performs object detection on an image using sliced prediction."""
         pil_image = read_image_as_pil(image)
@@ -34,6 +60,24 @@ def perform_sliced_inference(self, image):
             overlap_width_ratio=self.overlap_ratio[1],
             verbose=False
         )
+        if self.debug:
+            random_value = str(uuid.uuid4())
+            # Start exporting the image
+            results.export_visuals(export_dir=f"temp/{self.debug_annotated_directory}/", file_name=random_value)
+
+            # Wait until the file is created
+            file_path = f"temp/{self.debug_annotated_directory}/{random_value}.png"
+            timeout = 10  # Set a timeout limit of 10 seconds or more if necessary
+            start_time = time.time()
+
+            while not os.path.exists(file_path):
+                if time.time() - start_time > timeout:
+                    raise TimeoutError(f"File creation exceeded {timeout} seconds.")
+                time.sleep(0.1)  # Wait for 100 milliseconds before checking again
+
+            # Once the file exists, display it
+            self.show_annotated_image(file_path)
+
         return self.format_predictions(results)
 
     def format_predictions(self, prediction_result):