Skip to content

Commit

Permalink
Merge pull request #6 from shamspias/feat/sahi
Browse files Browse the repository at this point in the history
Fix sahi implementation
  • Loading branch information
shamspias authored Sep 9, 2024
2 parents 8277bd3 + 101bea2 commit 4450564
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 3 deletions.
1 change: 1 addition & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

class Config(BaseSettings):
streamlit_title: Optional[str] = "VideoLabelMagic"
debug: Optional[bool] = False
models_directory: Optional[str] = "models/"
output_directory: Optional[str] = "outputs/"
object_class_directory: Optional[str] = "object_class/"
Expand Down
2 changes: 1 addition & 1 deletion app/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self, config, video_path, frame_rate, output_dir, model_path, class

# Only initialize SahiUtils if SAHI is enabled
if sahi_config:
self.sahi_utils = SahiUtils(os.path.join('models', model_path), **sahi_config)
self.sahi_utils = SahiUtils(self.config.debug, os.path.join('models', model_path), **sahi_config)
else:
self.sahi_utils = None

Expand Down
1 change: 1 addition & 0 deletions example.env
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Streamlit Configuration
STREAMLIT_TITLE=VideoLabelMagic
DEBUG=True

# Directories
MODELS_DIRECTORY=models/
Expand Down
48 changes: 46 additions & 2 deletions utils/sahi_utils.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,25 @@
import cv2
import uuid
import time
import os
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction
from sahi.utils.cv import read_image_as_pil
from sahi import AutoDetectionModel
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image


class SahiUtils:
def __init__(self, model_path, model_type='yolov8', device='cpu', slice_size=(256, 256), overlap_ratio=(0.2, 0.2)):
def __init__(self, debug, model_path, model_type='yolov8', device='cpu', slice_size=(256, 256),
             overlap_ratio=(0.2, 0.2)):
    """Store the inference settings and load the detection model.

    Args:
        debug: Flag kept as ``self.debug``; ``perform_sliced_inference`` checks it
            before exporting and displaying annotated visuals.
        model_path: Path handed to ``load_model``.
        model_type: Kept as ``self.model_type`` (presumably read by ``load_model``;
            confirm against its body).
        device: Kept as ``self.device``; can be 'cpu' or 'cuda:0' for GPU.
        slice_size: Two-element slice dimensions (element order not visible here;
            confirm against ``perform_sliced_inference``).
        overlap_ratio: Two-element overlap ratios; index 1 is used as the width
            overlap ratio during sliced prediction.
    """
    self.debug = debug

    # Device and model type are assigned before the model is loaded, as in the
    # original ordering, so ``load_model`` can rely on them.
    self.device = device
    self.model_type = model_type
    self.model = self.load_model(model_path)

    self.slice_size, self.overlap_ratio = slice_size, overlap_ratio

    # Each instance gets its own randomly named sub-directory for debug exports.
    instance_id = uuid.uuid4()
    self.debug_annotated_directory = str(instance_id)

def load_model(self, model_path):
"""Loads a detection model based on the specified type and path."""
Expand All @@ -22,6 +31,23 @@ def load_model(self, model_path):
)
return detection_model

def show_image(self, image, title="Image"):
    """Render a NumPy image in a matplotlib window with the axes hidden.

    Args:
        image: Array to display. A two-dimensional array is shown unchanged; any
            other array is converted with ``cv2.COLOR_BGR2RGB`` first (assumes
            the input uses OpenCV's BGR channel order).
        title: Text placed above the image.
    """
    is_single_channel = len(image.shape) == 2
    if is_single_channel:
        displayable = image
    else:
        # matplotlib expects RGB, so swap the channel order before drawing.
        displayable = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    plt.imshow(displayable)
    plt.title(title)
    plt.axis('off')  # Hide axes
    plt.show()

def show_annotated_image(self, image_path):
    """Open the image file at ``image_path`` and display it with matplotlib.

    Args:
        image_path: Path of the image file to open with ``PIL.Image.open``.
    """
    # Image.open is lazy and keeps the file handle open; the context manager
    # guarantees it is released once the figure has been shown, instead of
    # relying on garbage collection.
    with Image.open(image_path) as img:
        plt.imshow(img)
        plt.axis('off')  # Hide axes
        plt.show()

def perform_sliced_inference(self, image):
"""Performs object detection on an image using sliced prediction."""
pil_image = read_image_as_pil(image)
Expand All @@ -34,6 +60,24 @@ def perform_sliced_inference(self, image):
overlap_width_ratio=self.overlap_ratio[1],
verbose=False
)
if self.debug:
random_value = str(uuid.uuid4())
# Start exporting the image
results.export_visuals(export_dir=f"temp/{self.debug_annotated_directory}/", file_name=random_value)

# Wait until the file is created
file_path = f"temp/{self.debug_annotated_directory}/{random_value}.png"
timeout = 10 # Set a timeout limit of 10 seconds or more if necessary
start_time = time.time()

while not os.path.exists(file_path):
if time.time() - start_time > timeout:
raise TimeoutError(f"File creation exceeded {timeout} seconds.")
time.sleep(0.1) # Wait for 100 milliseconds before checking again

# Once the file exists, display it
self.show_annotated_image(file_path)

return self.format_predictions(results)

def format_predictions(self, prediction_result):
Expand Down

0 comments on commit 4450564

Please sign in to comment.