diff --git a/crabs/tracker/evaluate_tracker.py b/crabs/tracker/evaluate_tracker.py
index e40bfe5d..088658ef 100644
--- a/crabs/tracker/evaluate_tracker.py
+++ b/crabs/tracker/evaluate_tracker.py
@@ -2,11 +2,15 @@
 
 import csv
 import logging
+from pathlib import Path
 from typing import Any, Optional
 
 import numpy as np
 
-from crabs.tracker.utils.tracking import extract_bounding_box_info
+from crabs.tracker.utils.tracking import (
+    extract_bounding_box_info,
+    save_tracking_mota_metrics,
+)
 
 
 class TrackerEvaluate:
@@ -14,9 +18,10 @@ class TrackerEvaluate:
 
     def __init__(
         self,
-        gt_dir: str,
-        predicted_boxes_id: list[np.ndarray],
+        gt_dir: str,  # annotations_file
+        predicted_boxes_dict: dict,
         iou_threshold: float,
+        tracking_output_dir: Path,
     ):
         """Initialize the TrackerEvaluate class.
@@ -27,47 +32,25 @@ def __init__(
         ----------
         gt_dir : str
             Directory path of the ground truth CSV file.
-        predicted_boxes_id : list[np.ndarray]
-            List of numpy arrays containing predicted bounding boxes and IDs.
+        predicted_boxes_dict : dict
+            Dictionary mapping frame indices to bounding box arrays
+            (under "tracked_boxes"), ids (under "ids") and detection scores
+            (under "scores"). The bounding box arrays have shape (n, 4),
+            where n is the number of boxes in the frame and the 4 columns
+            are (xmin, ymin, xmax, ymax).
         iou_threshold : float
             Intersection over Union (IoU) threshold for evaluating
            tracking performance.
+        tracking_output_dir : Path
+            Path to the directory where the tracking output will be saved.
 
         """
         self.gt_dir = gt_dir
-        self.predicted_boxes_id = predicted_boxes_id
+        self.predicted_boxes_dict = predicted_boxes_dict
         self.iou_threshold = iou_threshold
+        self.tracking_output_dir = tracking_output_dir
         self.last_known_predicted_ids: dict = {}
 
-    def get_predicted_data(self) -> dict[int, dict[str, Any]]:
-        """Format predicted bounding box and ID as dictionary.
-
-        Dictionary keys are frame numbers.
-
-        Returns
-        -------
-        dict[int, dict[str, Any]]:
-            A dictionary where the key is the frame number and the value is
-            another dictionary containing:
-            - 'bbox': A numpy array with shape (N, 4) containing coordinates
-            of the bounding boxes [x, y, x + width, y + height] for every
-            object in the frame.
-            - 'id': A numpy array containing the IDs of the tracked objects.
-
-        """
-        predicted_dict: dict[int, dict[str, Any]] = {}
-
-        for frame_number, frame_data in enumerate(self.predicted_boxes_id):
-            if frame_data.size == 0:
-                continue
-
-            bboxes = frame_data[:, :4]
-            ids = frame_data[:, 4]
-
-            predicted_dict[frame_number] = {"bbox": bboxes, "id": ids}
-
-        return predicted_dict
-
     def get_ground_truth_data(self) -> dict[int, dict[str, Any]]:
         """Format ground truth bounding box data as dict with key frame number.
@@ -82,6 +65,8 @@ def get_ground_truth_data(self) -> dict[int, dict[str, Any]]:
             - 'id': The ground truth ID
 
         """
+        # TODO: refactor with pandas
+
         with open(self.gt_dir) as csvfile:
             csvreader = csv.reader(csvfile)
             next(csvreader)  # Skip the header row
@@ -91,7 +76,10 @@ def get_ground_truth_data(self) -> dict[int, dict[str, Any]]:
 
         # Format as a dictionary with key = frame number
         ground_truth_dict: dict = {}
+
+        # loop through annotations
         for data in ground_truth_data:
+            # Get frame, bbox, id
             frame_number = data["frame_number"]
             bbox = np.array(
                 [
@@ -104,9 +92,11 @@ def get_ground_truth_data(self) -> dict[int, dict[str, Any]]:
             )
             track_id = int(float(data["id"]))
 
+            # If frame does not exist in dict: initialise
             if frame_number not in ground_truth_dict:
                 ground_truth_dict[frame_number] = {"bbox": [], "id": []}
 
+            # Append bbox and id to the dictionary
             ground_truth_dict[frame_number]["bbox"].append(bbox)
             ground_truth_dict[frame_number]["id"].append(track_id)
 
@@ -264,13 +254,13 @@ def count_identity_switches(  # noqa: C901
 
         return switch_counter
 
-    def evaluate_mota(
+    def compute_mota_one_frame(
         self,
         gt_data: dict[str, np.ndarray],
         pred_data: dict[str, np.ndarray],
         iou_threshold: float,
         gt_to_tracked_id_previous_frame: Optional[dict[int, int]],
-    ) -> tuple[float, dict[int, int]]:
+    ) -> tuple[float, int, int, int, int, int, dict[int, int]]:
         """Evaluate MOTA (Multiple Object Tracking Accuracy).
 
         Parameters
@@ -301,11 +291,12 @@ def evaluate_mota(
         """
         total_gt = len(gt_data["bbox"])
         false_positive = 0
+        true_positive = 0
         indices_of_matched_gt_boxes = set()
         gt_to_tracked_id_current_frame = {}
 
-        pred_boxes = pred_data["bbox"]
-        pred_ids = pred_data["id"]
+        pred_boxes = pred_data["tracked_boxes"]
+        pred_ids = pred_data["ids"]
 
         gt_boxes = gt_data["bbox"]
         gt_ids = gt_data["id"]
@@ -325,6 +316,7 @@ def evaluate_mota(
                     index_gt_not_match = j
 
             if index_gt_best_match is not None:
+                true_positive += 1
                 # Successfully found a matching ground truth box for the
                 # tracked box.
                 indices_of_matched_gt_boxes.add(index_gt_best_match)
@@ -347,8 +339,15 @@ def evaluate_mota(
         mota = (
             1 - (missed_detections + false_positive + num_switches) / total_gt
         )
-
-        return mota, gt_to_tracked_id_current_frame
+        return (
+            mota,
+            true_positive,
+            missed_detections,
+            false_positive,
+            num_switches,
+            total_gt,
+            gt_to_tracked_id_current_frame,
+        )
 
     def evaluate_tracking(
         self,
@@ -364,7 +363,7 @@ def evaluate_tracking(
             frame, organized by frame number.
         predicted_dict : dict
             Dictionary containing predicted bounding boxes and IDs for each
-            frame, organized by frame number.
+            frame, organized by frame _index_.
 
         Returns
         -------
@@ -375,27 +374,55 @@ def evaluate_tracking(
         """
         mota_values = []
         prev_frame_id_map: Optional[dict] = None
+        results: dict[str, Any] = {
+            "Frame Number": [],
+            "Total Ground Truth": [],
+            "True Positives": [],
+            "Missed Detections": [],
+            "False Positives": [],
+            "Number of Switches": [],
+            "MOTA": [],
+        }
 
         for frame_number in sorted(ground_truth_dict.keys()):
             gt_data_frame = ground_truth_dict[frame_number]
 
             if frame_number < len(predicted_dict):
                 pred_data_frame = predicted_dict[frame_number]
-                mota, prev_frame_id_map = self.evaluate_mota(
+
+                (
+                    mota,
+                    true_positives,
+                    missed_detections,
+                    false_positives,
+                    num_switches,
+                    total_gt,
+                    prev_frame_id_map,
+                ) = self.compute_mota_one_frame(
                     gt_data_frame,
                     pred_data_frame,
                     self.iou_threshold,
                     prev_frame_id_map,
                 )
                 mota_values.append(mota)
+                results["Frame Number"].append(frame_number)
+                results["Total Ground Truth"].append(total_gt)
+                results["True Positives"].append(true_positives)
+                results["Missed Detections"].append(missed_detections)
+                results["False Positives"].append(false_positives)
+                results["Number of Switches"].append(num_switches)
+                results["MOTA"].append(mota)
+
+        save_tracking_mota_metrics(self.tracking_output_dir, results)
 
         return mota_values
 
     def run_evaluation(self) -> None:
         """Run evaluation of tracking based on tracking ground truth."""
-        predicted_dict = self.get_predicted_data()
         ground_truth_dict = self.get_ground_truth_data()
-        mota_values = self.evaluate_tracking(ground_truth_dict, predicted_dict)
+        mota_values = self.evaluate_tracking(
+            ground_truth_dict, self.predicted_boxes_dict
+        )
         overall_mota = np.mean(mota_values)
         logging.info("Overall MOTA: %f" % overall_mota)  # noqa: UP031
diff --git a/crabs/tracker/track_video.py b/crabs/tracker/track_video.py
index 7971ef3e..81ad2cef 100644
--- a/crabs/tracker/track_video.py
+++ b/crabs/tracker/track_video.py
@@ -2,8 +2,8 @@
 
 import argparse
 import logging
-import os
 import sys
+from datetime import datetime
 from pathlib import Path
 
 import cv2
@@ -13,222 +13,305 @@
 import yaml  # type: ignore
 
 from crabs.detector.models import FasterRCNN
+from crabs.detector.utils.evaluate import (
+    get_config_from_ckpt,
+    get_mlflow_parameters_from_ckpt,
+)
 from crabs.tracker.evaluate_tracker import TrackerEvaluate
 from crabs.tracker.sort import Sort
 from crabs.tracker.utils.io import (
-    close_csv_file,
-    prep_csv_writer,
-    prep_video_writer,
-    release_video,
-    save_required_output,
+    generate_tracked_video,
+    open_video,
+    parse_video_frame_reading_error_and_log,
+    write_all_video_frames_as_images,
+    write_tracked_detections_to_csv,
+)
+from crabs.tracker.utils.tracking import (
+    format_and_filter_bbox_predictions_for_sort,
 )
-from crabs.tracker.utils.tracking import prep_sort
 
 
 class Tracking:
-    """Interface for tracking crabs on a video using a trained detector.
+    """Interface for detecting and tracking crabs on a video.
 
     Parameters
     ----------
-    args : argparse.Namespace)
-        Command-line arguments containing configuration settings.
-
-    Attributes
-    ----------
     args : argparse.Namespace
-        The command-line arguments provided.
-    video_path : str
-        The path to the input video.
-    sort_tracker : Sort
-        An instance of the sorting algorithm used for tracking.
+        Command-line arguments containing configuration settings.
""" def __init__(self, args: argparse.Namespace) -> None: """Initialise the tracking interface with the given arguments.""" + # CLI inputs and config file self.args = args self.config_file = args.config_file - self.video_path = args.video_path - self.trained_model_path = self.args.trained_model_path - self.device = "cuda" if self.args.accelerator == "gpu" else "cpu" - - self.setup() - self.prep_outputs() + self.load_config_yaml() - self.sort_tracker = Sort( - max_age=self.config["max_age"], - min_hits=self.config["min_hits"], - iou_threshold=self.config["iou_threshold"], + # trained model data + self.trained_model_path = args.trained_model_path + trained_model_params = get_mlflow_parameters_from_ckpt( + self.trained_model_path + ) + # to log later in MLflow: + self.trained_model_run_name = trained_model_params["run_name"] + self.trained_model_expt_name = trained_model_params[ + "cli_args/experiment_name" + ] + self.trained_model_config = get_config_from_ckpt( + config_file=None, + trained_model_path=self.trained_model_path, ) - def setup(self): - """Load tracking config, trained model and input video path.""" + # input video data + self.input_video_path = args.video_path + self.input_video_file_root = f"{Path(self.input_video_path).stem}" + + # tracking output directory root name + self.tracking_output_dir_root = args.output_dir + self.frame_name_format_str = "frame_{frame_idx:08d}.png" + + # hardware + self.accelerator = "cuda" if args.accelerator == "gpu" else "cpu" + + # Prepare outputs: + # output directory, csv, and if required video and frames + self.prep_outputs() + + def load_config_yaml(self): + """Load yaml file that contains config parameters.""" with open(self.config_file) as f: self.config = yaml.safe_load(f) - # Get trained model - self.trained_model = FasterRCNN.load_from_checkpoint( - self.trained_model_path - ) - self.trained_model.eval() - self.trained_model.to(self.device) + def prep_outputs(self): + """Prepare output directory and file paths. - # Load the input video - self.video = cv2.VideoCapture(self.video_path) - if not self.video.isOpened(): - raise Exception("Error opening video file") - self.video_file_root = f"{Path(self.video_path).stem}" + This method: + - creates a timestamped directory to store the tracking output. + - sets the name of the output csv file for the tracked bounding boxes. + - sets up the output video path if required. + - sets up the frames subdirectory path if required. 
+ """ + # Create output directory + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + self.tracking_output_dir = Path( + self.tracking_output_dir_root + f"_{timestamp}" + ) + self.tracking_output_dir.mkdir(parents=True, exist_ok=True) - def prep_outputs(self): - """Prepare csv writer and if required, video writer.""" - ( - self.csv_writer, - self.csv_file, - self.tracking_output_dir, - ) = prep_csv_writer(self.args.output_dir, self.video_file_root) + # Set name of output csv file + self.csv_file_path = str( + self.tracking_output_dir + / f"{self.input_video_file_root}_tracks.csv" + ) + # Set up output video path if required if self.args.save_video: - frame_width = int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH)) - frame_height = int(self.video.get(cv2.CAP_PROP_FRAME_HEIGHT)) - cap_fps = self.video.get(cv2.CAP_PROP_FPS) + self.output_video_path = str( + self.tracking_output_dir + / f"{self.input_video_file_root}_tracks.mp4" + ) - self.video_output = prep_video_writer( - self.tracking_output_dir, - self.video_file_root, - frame_width, - frame_height, - cap_fps, + # Set up frames subdirectory path if required + if self.args.save_frames: + self.frames_subdir = ( + self.tracking_output_dir + / f"{self.input_video_file_root}_frames" ) - else: - self.video_output = None + self.frames_subdir.mkdir(parents=True, exist_ok=True) - def get_prediction(self, frame: np.ndarray) -> torch.Tensor: - """Get prediction from the trained model for a given frame. + def prep_detector_and_tracker(self): + """Prepare the trained detector and the tracker for inference.""" + # TODO: use Lightning's Trainer? - Parameters - ---------- - frame : np.ndarray - The input frame for which prediction is to be obtained. - - Returns - ------- - torch.Tensor: - The prediction tensor from the trained model. + # Load trained model + self.trained_model = FasterRCNN.load_from_checkpoint( + self.trained_model_path, + config=self.trained_model_config, # config of trained model! + ) + self.trained_model.eval() + self.trained_model.to(self.accelerator) - """ - transform = transforms.Compose( + # Define transforms to apply to input frames + self.inference_transforms = transforms.Compose( [ transforms.ToImage(), transforms.ToDtype(torch.float32, scale=True), ] ) - img = transform(frame).to(self.device) - img = img.unsqueeze(0) - with torch.no_grad(): - prediction = self.trained_model(img) - return prediction - def update_tracking(self, prediction: dict) -> np.ndarray: - """Update the tracking system with the latest prediction. + # Initialise SORT tracker + self.sort_tracker = Sort( + max_age=self.config["max_age"], + min_hits=self.config["min_hits"], + iou_threshold=self.config["iou_threshold"], + ) + + def run_tracking(self, prediction_dict: dict) -> np.ndarray: + """Update the tracker with the latest prediction. Parameters ---------- - prediction : dict - Dictionary containing predicted bounding boxes, scores, and labels. + prediction_dict : dict + Dictionary with data of the predicted bounding boxes. + The keys are: "boxes", "scores", and "labels". The labels + refer to the class of the object detected, and not its ID. Returns ------- np.ndarray: - tracked bounding boxes after updating the tracking system. + Array of tracked bounding boxes with object IDs added as the last + column. The shape of the array is (n, 5), where n is the number of + tracked boxes. The columns correspond to the values (xmin, ymin, + xmax, ymax, id). 
""" - pred_sort = prep_sort(prediction, self.config["score_threshold"]) - tracked_boxes_id_per_frame = self.sort_tracker.update(pred_sort) - self.tracked_bbox_id.append(tracked_boxes_id_per_frame) + # format predictions for SORT + prediction_tensor = format_and_filter_bbox_predictions_for_sort( + prediction_dict, self.config["score_threshold"] + ) + + # update tracked bboxes and append + tracked_boxes_id_per_frame = self.sort_tracker.update( + prediction_tensor.cpu() # move to CPU for SORT + ) return tracked_boxes_id_per_frame - def run_tracking(self): - """Run object detection + tracking on the video frames.""" - # If we pass ground truth: check the path exist - if self.args.annotations_file and not os.path.exists( - self.args.annotations_file - ): - logging.info( - f"Ground truth file {self.args.annotations_file} " - "does not exist." - "Exiting..." - ) - return + def run_detection(self, frame: np.ndarray) -> dict: + """Run detection on a single frame. - # initialisation - frame_idx = 0 - self.tracked_bbox_id = [] - - # Loop through frames of the video in batches - while self.video.isOpened(): - # Break if beyond end frame (mostly for debugging) - if ( - self.args.max_frames_to_read - and frame_idx + 1 > self.args.max_frames_to_read - ): - break + Returns + ------- + dict: + Dictionary with data of the predicted bounding boxes. + The keys are "boxes", "scores", and "labels". The labels + refer to the class of the object detected, and not its ID. + The data is stored as torch tensors. - # get total n frames - total_frames = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT)) + """ + # Apply transforms to frame and place tensor on devide + image_tensor = self.inference_transforms(frame).to(self.accelerator) - # read frame - ret, frame = self.video.read() - if not ret and (frame_idx == total_frames): - logging.info(f"All {total_frames} frames processed") - break - elif not ret: - logging.info( - f"Cannot read frame {frame_idx+1}/{total_frames}. " - "Exiting..." + # Add batch dimension + image_tensor = image_tensor.unsqueeze(0) + + # Run detection + with torch.no_grad(): + # use [0] to select the one image in the batch + detections_dict = self.trained_model(image_tensor)[0] + + return detections_dict + + def core_detection_and_tracking(self): + """Run detection and tracking loop through all video frames. + + Returns a dictionary with tracked bounding boxes per frame, and + with scores for each detection. + + Returns + ------- + dict: + A nested dictionary that maps frame indices (0-based) to a + dictionary with the following keys: + - "tracked_boxes", which contains the tracked bounding boxes as a + numpy array of shape (n, 5), where n is the number of tracked + boxes, and the 5 columns correspond to the values (xmin, ymin, + xmax, ymax, id). 
+ - "scores", which contains the scores for each bounding box, + as a numpu array of shape (nboxes,) + + """ + # Initialise dict to store tracked bboxes + tracked_detections_all_frames = {} + + # Open input video + input_video_object = open_video(self.input_video_path) + total_n_frames = int(input_video_object.get(cv2.CAP_PROP_FRAME_COUNT)) + + # Loop over frames + frame_idx = 0 + while input_video_object.isOpened(): + # Read frame + ret, frame = input_video_object.read() + if not ret: + parse_video_frame_reading_error_and_log( + frame_idx, total_n_frames ) break - # predict bounding boxes - prediction = self.get_prediction(frame) - pred_scores = prediction[0]["scores"].detach().cpu().numpy() + # Run detection per frame + detections_dict = self.run_detection(frame) - # run tracking - tracked_boxes_id_per_frame = self.update_tracking(prediction) - save_required_output( - self.video_file_root, - self.args.save_frames, - self.tracking_output_dir, - self.csv_writer, - self.args.save_video, - self.video_output, - tracked_boxes_id_per_frame, - frame, - frame_idx + 1, - pred_scores, - ) + # Update tracking + tracked_boxes_array = self.run_tracking(detections_dict) - # update frame number + # Add data to dict; key is frame index (0-based) for input clip + tracked_detections_all_frames[frame_idx] = { + "tracked_boxes": tracked_boxes_array[:, :-1], + "ids": tracked_boxes_array[:, -1], # IDs are the last column + "scores": detections_dict["scores"], + } + + # Update frame index frame_idx += 1 + # Release video object + input_video_object.release() + + return tracked_detections_all_frames + + def detect_and_track_video(self) -> None: + """Run detection and tracking on input video.""" + # Prepare detector and tracker + # - Load trained model + # - Define transforms + # - Initialise SORT tracker + self.prep_detector_and_tracker() + + # Run detection and tracking over all frames in video + tracked_bboxes_dict = self.core_detection_and_tracking() + + # Write list of tracked bounding boxes to csv + write_tracked_detections_to_csv( + self.csv_file_path, + tracked_bboxes_dict, + frame_name_regexp=self.frame_name_format_str, + ) + + # Generate tracked video if required + # (it loops again thru frames) + if self.args.save_video: + generate_tracked_video( + self.input_video_path, + self.output_video_path, + tracked_bboxes_dict, + ) + logging.info(f"Tracked video saved to {self.output_video_path}") + + # Write frames if required + # (it loops again thru frames) + if self.args.save_frames: + write_all_video_frames_as_images( + self.input_video_path, + self.frames_subdir, + self.frame_name_format_str, + ) + logging.info( + "Input frames saved to " + f"{self.tracking_output_dir / self.frames_subdir}" + ) + + # Evaluate tracker if ground truth is passed if self.args.annotations_file: evaluation = TrackerEvaluate( self.args.annotations_file, - self.tracked_bbox_id, + tracked_bboxes_dict, self.config["iou_threshold"], + self.tracking_output_dir, ) evaluation.run_evaluation() - # Close input video - self.video.release() - - # Close outputs - if self.args.save_video: - release_video(self.video_output) - - if self.args.save_frames: - close_csv_file(self.csv_file) - def main(args) -> None: """Run detection+tracking inference on video. 
@@ -244,7 +327,8 @@ def main(args) -> None:
 
     """
     inference = Tracking(args)
-    inference.run_tracking()
+
+    inference.detect_and_track_video()
 
 
 def tracking_parse_args(args):
@@ -263,13 +347,13 @@ def tracking_parse_args(args):
         help="Location of the video to be tracked.",
     )
     parser.add_argument(
-        "--annotations_file",
+        "--config_file",
         type=str,
-        default=None,
+        default=str(Path(__file__).parent / "config" / "tracking_config.yaml"),
         help=(
-            "Location of JSON file containing ground truth annotations "
-            "(optional). "
-            "If passed, the evaluation metrics for the tracker are computed."
+            "Location of YAML config to control tracking. "
+            "Default: "
+            "crabs-exploration/crabs/tracking/config/tracking_config.yaml. "
         ),
     )
     parser.add_argument(
@@ -279,19 +363,14 @@ def tracking_parse_args(args):
         help=(
             "Root name of the directory to save the tracking output. "
             "The name of the output directory is appended with a timestamp. "
+            "The tracking output consists of a .csv file named "
+            "<input-video-name>_tracks.csv with the tracked bounding boxes. "
+            "Optionally, it can include a video file named "
+            "<input-video-name>_tracks.mp4, and all frames from the video "
+            "under a <input-video-name>_frames subdirectory. "
             "Default: ./tracking_output_<timestamp>. "
         ),
     )
-    parser.add_argument(
-        "--config_file",
-        type=str,
-        default=str(Path(__file__).parent / "config" / "tracking_config.yaml"),
-        help=(
-            "Location of YAML config to control tracking. "
-            "Default: "
-            "crabs-exploration/crabs/tracking/config/tracking_config.yaml. "
-        ),
-    )
     parser.add_argument(
         "--save_video",
         action="store_true",
@@ -310,6 +389,16 @@ def tracking_parse_args(args):
             "support their visualisation and correction using the VIA tool. "
         ),
     )
+    parser.add_argument(
+        "--annotations_file",
+        type=str,
+        default=None,
+        help=(
+            "Location of JSON file containing ground truth annotations "
+            "(optional). "
+            "If passed, the evaluation metrics for the tracker are computed."
+        ),
+    )
     parser.add_argument(
         "--accelerator",
         type=str,
diff --git a/crabs/tracker/utils/io.py b/crabs/tracker/utils/io.py
index 16c82cda..d03cda65 100644
--- a/crabs/tracker/utils/io.py
+++ b/crabs/tracker/utils/io.py
@@ -1,45 +1,57 @@
 """Utility functions for handling input and output operations."""
 
 import csv
-import os
-from datetime import datetime
+import logging
 from pathlib import Path
 
 import cv2
 import numpy as np
 
 from crabs.detector.utils.visualization import draw_bbox
-from crabs.tracker.utils.tracking import (
-    save_output_frame,
-    write_tracked_bbox_to_csv,
-)
 
 
-def prep_csv_writer(output_dir: str, video_file_root: str):
-    """Prepare csv writer to output tracking results.
+def open_video(video_path: str) -> cv2.VideoCapture:
+    """Open video file."""
+    video_object = cv2.VideoCapture(video_path)
+    if not video_object.isOpened():
+        raise Exception("Error opening video file")
+    return video_object
 
-    Parameters
-    ----------
-    output_dir : str
-        The output folder where the output will be stored.
-    video_file_root : str
-        The root name of the video file.
 
-    Returns
-    -------
-    Tuple
-        A tuple containing the CSV writer, the CSV file object, and the
-        tracking output directory path.
+def get_video_parameters(video_path: str) -> dict:
+    """Get total number of frames, frame width and height, and fps of video."""
+    # Open video
+    video_object = open_video(video_path)
 
-    """
-    # Create a timestamped directory for the tracking output
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    tracking_output_dir = Path(output_dir + f"_{timestamp}")
-    tracking_output_dir.mkdir(parents=True, exist_ok=True)
+    # Get video parameters
+    video_parameters = {}
+    video_parameters["total_frames"] = int(
+        video_object.get(cv2.CAP_PROP_FRAME_COUNT)
+    )
+    video_parameters["frame_width"] = int(
+        video_object.get(cv2.CAP_PROP_FRAME_WIDTH)
+    )
+    video_parameters["frame_height"] = int(
+        video_object.get(cv2.CAP_PROP_FRAME_HEIGHT)
+    )
+    video_parameters["fps"] = video_object.get(cv2.CAP_PROP_FPS)
 
+    # Release video object
+    video_object.release()
+
+    return video_parameters
+
+
+def write_tracked_detections_to_csv(
+    csv_file_path: str,
+    tracked_bboxes_dict: dict,
+    frame_name_regexp: str = "frame_{frame_idx:08d}.png",
+    all_frames_size: int = 8888,
+):
+    """Write tracked detections to a csv file."""
     # Initialise csv file
     csv_file = open(  # noqa: SIM115
-        f"{str(tracking_output_dir)}/{video_file_root}_tracks.csv",
+        csv_file_path,
         "w",
     )
     csv_writer = csv.writer(csv_file)
@@ -58,128 +70,181 @@ def prep_csv_writer(output_dir: str, video_file_root: str):
         )
     )
 
-    return csv_writer, csv_file, tracking_output_dir
+    # write detections
+    # loop through frames
+    for frame_idx in tracked_bboxes_dict:
+        # loop through all boxes in frame
+        for bbox, id, pred_score in zip(
+            tracked_bboxes_dict[frame_idx]["tracked_boxes"],
+            tracked_bboxes_dict[frame_idx]["ids"],
+            tracked_bboxes_dict[frame_idx]["scores"],
+        ):
+            # extract shape
+            xmin, ymin, xmax, ymax = bbox
+            width_box = int(xmax - xmin)
+            height_box = int(ymax - ymin)
+
+            # Add to csv
+            csv_writer.writerow(
+                (
+                    frame_name_regexp.format(frame_idx=frame_idx),
+                    all_frames_size,  # frame size
+                    '{{"clip":{}}}'.format("123"),
+                    1,
+                    0,
+                    f'{{"name":"rect","x":{xmin},"y":{ymin},"width":{width_box},"height":{height_box}}}',
+                    f'{{"track":"{int(id)}", "confidence":"{pred_score}"}}',
+                )
+            )
+
+    # close csv file to flush the detections to disk
+    csv_file.close()
 
 
-def prep_video_writer(
-    output_dir: str,
-    video_file_root: str,
-    frame_width: int,
-    frame_height: int,
-    cap_fps: float,
-) -> cv2.VideoWriter:
-    """Prepare video writer to output processed video.
+def write_frame_to_output_video(
+    frame: np.ndarray,
+    tracked_bboxes_one_frame: dict,
+    output_video_object: cv2.VideoWriter,
+) -> None:
+    """Write frame with tracked bounding boxes to output video."""
+    frame_copy = frame.copy()  # draw on a copy to keep the input frame intact
+    for bbox, id in zip(
+        tracked_bboxes_one_frame["tracked_boxes"],
+        tracked_bboxes_one_frame["ids"],
+    ):
+        xmin, ymin, xmax, ymax = bbox
+
+        draw_bbox(
+            frame_copy,
+            (xmin, ymin),
+            (xmax, ymax),
+            (0, 0, 255),
+            f"id : {int(id)}",
+        )
+    output_video_object.write(frame_copy)
 
-    Parameters
-    ----------
-    output_dir : str
-        The output folder where the output will be stored.
-    video_file_root :str
-        The root name of the video file.
-    frame_width : int
-        The width of the video frames.
-    frame_height : int
-        The height of the video frames.
-    cap_fps : float
-        The frames per second of the video.
-
-    Returns
-    -------
-    cv2.VideoWriter
-        The video writer object for writing video frames.
- """ - output_file = os.path.join( - output_dir, - f"{video_file_root}_tracks.mp4", - ) +def parse_video_frame_reading_error_and_log(frame_idx: int, total_frames: int): + """Parse error message for reading a video frame.""" + if frame_idx == total_frames: + logging.info(f"All {total_frames} frames processed") + else: + logging.info( + f"Error reading frame index " f"{frame_idx}/{total_frames}." + ) + + +def setup_video_writer_from_input_video( + reference_video_path: str, output_video_path: str +) -> cv2.VideoWriter: + """Set up video writer with the same parameters as reference video.""" + input_video_params = get_video_parameters(reference_video_path) output_codec = cv2.VideoWriter_fourcc("m", "p", "4", "v") - video_output = cv2.VideoWriter( - output_file, output_codec, cap_fps, (frame_width, frame_height) + output_video_writer = cv2.VideoWriter( + output_video_path, + output_codec, + input_video_params["fps"], + ( + input_video_params["frame_width"], + input_video_params["frame_height"], + ), ) + return output_video_writer - return video_output +def generate_tracked_video( + input_video_path: str, output_video_path: str, tracked_bboxes: dict +): + """Generate tracked video.""" + # Open input video + input_video_object = open_video(input_video_path) -def save_required_output( - video_file_root: Path, - save_frames: bool, - tracking_output_dir: Path, - csv_writer: cv2.VideoWriter, - save_video: bool, - video_output: cv2.VideoWriter, - tracked_boxes: list[list[float]], - frame: np.ndarray, - frame_number: int, - pred_scores: np.ndarray, -) -> None: - """Handle the output based on argument options. - - Parameters - ---------- - video_file_root : Path - The root name of the video file. - save_frames : bool - Flag to save frames. - tracking_output_dir : Path - Directory to save tracking output. - csv_writer : Any - CSV writer object. - save_video : bool - Flag to save video. - video_output : cv2.VideoWriter - Video writer object for writing video frames. - tracked_boxes : list[list[float]] - List of tracked bounding boxes. - frame : np.ndarray - The current frame. - frame_number : int - The frame number. 
-    pred_scores : np.ndarray
-        The prediction score from detector
+    # Set up output video writer following input video parameters
+    output_video_writer = setup_video_writer_from_input_video(
+        input_video_path, output_video_path
+    )
 
-    """
-    frame_name = f"frame_{frame_number:08d}.png"
+    # Loop over frames
+    frame_idx = 0
+    while input_video_object.isOpened():
+        # Read frame
+        ret, frame = input_video_object.read()
+        if not ret:
+            parse_video_frame_reading_error_and_log(
+                frame_idx,
+                int(input_video_object.get(cv2.CAP_PROP_FRAME_COUNT)),
+            )
+            break
 
-    for bbox, pred_score in zip(tracked_boxes, pred_scores):
-        write_tracked_bbox_to_csv(
-            np.array(bbox), frame, frame_name, csv_writer, pred_score
+        # Write frame to output video
+        write_frame_to_output_video(
+            frame,
+            tracked_bboxes[frame_idx],
+            output_video_writer,
         )
 
-    if save_frames:
-        # create subdirectory of frames
-        frames_subdir = tracking_output_dir / f"{video_file_root}_frames"
-        frames_subdir.mkdir(parents=True, exist_ok=True)
+        frame_idx += 1
 
-        # save frame (without bounding boxes)
-        save_output_frame(
-            frame_name,
-            frames_subdir,
-            frame,
-            frame_number,
-        )
+    # Release video objects
+    input_video_object.release()
+    output_video_writer.release()
+    cv2.destroyAllWindows()
 
-    if save_video:
-        frame_copy = frame.copy()
-        for bbox in tracked_boxes:
-            xmin, ymin, xmax, ymax, id = bbox
-            draw_bbox(
-                frame_copy,
-                (xmin, ymin),
-                (xmax, ymax),
-                (0, 0, 255),
-                f"id : {int(id)}",
-            )
-        video_output.write(frame_copy)
 
+def write_frame_as_image(frame: np.ndarray, frame_path: str):
+    """Write frame as image file."""
+    img_saved = cv2.imwrite(
+        frame_path,
+        frame,
+    )
+    if not img_saved:
+        logging.info(f"Error saving {frame_path}.")
+
+
+def write_all_video_frames_as_images(
+    input_video_path: str,
+    frames_subdir: Path,
+    frame_name_format_str: str = "frame_{frame_idx:08d}.png",
+):
+    """Save frames of input video as image files.
+
+    Parameters
+    ----------
+    input_video_path : str
+        The path to the input video.
+    frames_subdir : Path
+        The directory to save frames.
+    frame_name_format_str : str
+        The format to follow for the frame filenames.
+        E.g. "frame_{frame_idx:08d}.png"
"frame_{frame_idx:08d}.png" -def close_csv_file(csv_file) -> None: - """Close the CSV file if it's open.""" - if csv_file: - csv_file.close() + """ + # Open input video + input_video_object = open_video(input_video_path) + + # Loop over frames + frame_idx = 0 + while input_video_object.isOpened(): + # Read frame + ret, frame = input_video_object.read() + if not ret: + parse_video_frame_reading_error_and_log( + frame_idx, + int(input_video_object.get(cv2.CAP_PROP_FRAME_COUNT)), + ) + break + + # Write frame to file + write_frame_as_image( + frame, + str( + frames_subdir + / frame_name_format_str.format(frame_idx=frame_idx) + ), + ) + # Update frame index + frame_idx += 1 -def release_video(video_output) -> None: - """Release the video file if it's open.""" - if video_output: - video_output.release() + # Release video objects + input_video_object.release() + cv2.destroyAllWindows() diff --git a/crabs/tracker/utils/tracking.py b/crabs/tracker/utils/tracking.py index ca83ea94..514e65ad 100644 --- a/crabs/tracker/utils/tracking.py +++ b/crabs/tracker/utils/tracking.py @@ -1,12 +1,47 @@ """Utility functions for tracking.""" import json -import logging from pathlib import Path from typing import Any -import cv2 -import numpy as np +import pandas as pd +import torch + + +def format_and_filter_bbox_predictions_for_sort( + prediction_dict: dict, score_threshold: float +) -> torch.Tensor: + """Put predictions in format expected by SORT. + + Lower confidence predictions are filtered out. + + Parameters + ---------- + prediction_dict : dict + Dictionary containing predicted bounding boxes, scores, + and labels. + + score_threshold : float + The threshold score for filtering out low-confidence predictions. + + Returns + ------- + torch.Tensor: + A torch tensor of shape (N, 5) representing N detection bounding boxes + in format [xmin, ymin, xmax, ymax, score]. + + """ + # Format as a tensor with scores as last column + predictions_tensor = torch.hstack( + ( + prediction_dict["boxes"], + prediction_dict["scores"].unsqueeze(dim=1), + ) + ) + + # Filter rows in tensor based on last column + # if pred_score > score_threshold: + return predictions_tensor[predictions_tensor[:, -1] > score_threshold] def extract_bounding_box_info(row: list[str]) -> dict[str, Any]: @@ -45,107 +80,11 @@ def extract_bounding_box_info(row: list[str]) -> dict[str, Any]: } -def write_tracked_bbox_to_csv( - bbox: np.ndarray, - frame: np.ndarray, - frame_name: str, - csv_writer: Any, - pred_score: np.ndarray, -) -> None: - """Write bounding box annotation to a CSV file. - - Parameters - ---------- - bbox : np.ndarray - A numpy array containing the bounding box coordinates - (xmin, ymin, xmax, ymax, id). - frame : np.ndarray - The frame to which the bounding box belongs. - frame_name : str - The name of the frame. - csv_writer : Any - The CSV writer object to write the annotation. - pred_score : np.ndarray - The prediction score from detector. 
- - """ - # Bounding box geometry - xmin, ymin, xmax, ymax, id = bbox - width_box = int(xmax - xmin) - height_box = int(ymax - ymin) - - # Add to csv - csv_writer.writerow( - ( - frame_name, - frame.size, - '{{"clip":{}}}'.format("123"), - 1, - 0, - f'{{"name":"rect","x":{xmin},"y":{ymin},"width":{width_box},"height":{height_box}}}', - f'{{"track":"{int(id)}", "confidence":"{pred_score}"}}', - ) - ) - - -def save_output_frame( - frame_name: str, +def save_tracking_mota_metrics( tracking_output_dir: Path, - frame: np.ndarray, - frame_number: int, + track_results: dict[str, Any], ) -> None: - """Save tracked bounding boxes as frames. - - Parameters - ---------- - frame_name : str - The name of the image file to save frame in. - tracking_output_dir : Path - The directory where tracked frames and CSV file will be saved. - frame : np.ndarray - The frame image. - frame_number : int - The frame number. - - Returns - ------- - None - - """ - # Save frame as PNG - frame_path = tracking_output_dir / frame_name - img_saved = cv2.imwrite(str(frame_path), frame) - if not img_saved: - logging.error( - f"Didn't save {frame_name}, frame {frame_number}, Skipping." - ) - - -def prep_sort(prediction: dict, score_threshold: float) -> np.ndarray: - """Put predictions in format expected by SORT. - - Parameters - ---------- - prediction : dict - The dictionary containing predicted bounding boxes, scores, and labels. - - score_threshold : float - The threshold score for filtering out low-confidence predictions. - - Returns - ------- - np.ndarray: - An array containing sorted bounding boxes of detected objects. - - """ - pred_boxes = prediction[0]["boxes"].detach().cpu().numpy() - pred_scores = prediction[0]["scores"].detach().cpu().numpy() - pred_labels = prediction[0]["labels"].detach().cpu().numpy() - - pred_sort = [] - for box, score, _label in zip(pred_boxes, pred_scores, pred_labels): - if score > score_threshold: - bbox = np.concatenate((box, [score])) - pred_sort.append(bbox) - - return np.asarray(pred_sort) + """Save tracking metrics to a CSV file.""" + track_df = pd.DataFrame(track_results) + output_filename = f"{tracking_output_dir}/tracking_metrics_output.csv" + track_df.to_csv(output_filename, index=False) diff --git a/tests/test_unit/test_evaluate_tracker.py b/tests/test_unit/test_evaluate_tracker.py index a08ec8d0..aa6162e6 100644 --- a/tests/test_unit/test_evaluate_tracker.py +++ b/tests/test_unit/test_evaluate_tracker.py @@ -7,22 +7,30 @@ @pytest.fixture -def evaluation(): - test_csv_file = Path(__file__).parents[1] / "data" / "gt_test.csv" +def tracker_evaluate_interface(): + annotations_file_csv = Path(__file__).parents[1] / "data" / "gt_test.csv" return TrackerEvaluate( - test_csv_file, predicted_boxes_id=[], iou_threshold=0.1 + annotations_file_csv, + predicted_boxes_dict={}, + iou_threshold=0.1, + tracking_output_dir="/path/output", ) -def test_get_ground_truth_data(evaluation): - ground_truth_dict = evaluation.get_ground_truth_data() +def test_get_ground_truth_data_structure(tracker_evaluate_interface): + """Test the loaded ground truth data has the expected structure.""" + # Get ground truth data dict + ground_truth_dict = tracker_evaluate_interface.get_ground_truth_data() + # check type assert isinstance(ground_truth_dict, dict) + # check it is a nested dictionary assert all( isinstance(frame_data, dict) for frame_data in ground_truth_dict.values() ) + # check data types for values in nested dictionary for frame_number, data in ground_truth_dict.items(): assert isinstance(frame_number, 
         assert isinstance(data["bbox"], np.ndarray)
@@ -30,7 +38,9 @@ def test_get_ground_truth_data(evaluation):
         assert data["bbox"].shape[1] == 4
 
 
-def test_ground_truth_data_from_csv(evaluation):
+def test_ground_truth_data_values(tracker_evaluate_interface):
+    """Test ground truth data holds expected values."""
+    # Define expected ground truth data
     expected_data = {
         11: {
             "bbox": np.array(
@@ -50,25 +60,33 @@ def test_ground_truth_data_from_csv(evaluation):
         },
     }
 
-    ground_truth_dict = evaluation.get_ground_truth_data()
+    # Get ground truth data dict
+    ground_truth_dict = tracker_evaluate_interface.get_ground_truth_data()
 
-    for frame_number, expected_frame_data in expected_data.items():
-        assert frame_number in ground_truth_dict
+    # Check if ground truth data matches expected values
+    for expected_frame_number, expected_frame_data in expected_data.items():
+        # check expected key is present
+        assert expected_frame_number in ground_truth_dict
 
-        assert len(ground_truth_dict[frame_number]["bbox"]) == len(
+        # check n of bounding boxes per frame matches the expected value
+        assert len(ground_truth_dict[expected_frame_number]["bbox"]) == len(
            expected_frame_data["bbox"]
        )
+
+        # check bbox arrays match the expected values
         for bbox, expected_bbox in zip(
-            ground_truth_dict[frame_number]["bbox"],
+            ground_truth_dict[expected_frame_number]["bbox"],
             expected_frame_data["bbox"],
         ):
             assert np.allclose(
                 bbox, expected_bbox
-            ), f"Frame {frame_number}, bbox mismatch"
+            ), f"Frame {expected_frame_number}, bbox mismatch"
 
+        # check id arrays match the expected values
         assert np.array_equal(
-            ground_truth_dict[frame_number]["id"], expected_frame_data["id"]
-        ), f"Frame {frame_number}, id mismatch"
+            ground_truth_dict[expected_frame_number]["id"],
+            expected_frame_data["id"],
+        ), f"Frame {expected_frame_number}, id mismatch"
 
 
 @pytest.mark.parametrize(
@@ -220,11 +238,19 @@ def test_ground_truth_data_from_csv(evaluation):
     ],
 )
 def test_count_identity_switches(
-    evaluation, prev_frame_id_map, current_frame_id_map, expected_output
+    tracker_evaluate_interface,
+    prev_frame_id_map,
+    current_frame_id_map,
+    expected_output,
 ):
-    evaluation.last_known_predicted_ids = {1: 11, 2: 12, 3: 13, 4: 14}
+    tracker_evaluate_interface.last_known_predicted_ids = {
+        1: 11,
+        2: 12,
+        3: 13,
+        4: 14,
+    }
     assert (
-        evaluation.count_identity_switches(
+        tracker_evaluate_interface.count_identity_switches(
             prev_frame_id_map, current_frame_id_map
         )
         == expected_output
@@ -240,18 +266,18 @@ def test_count_identity_switches(
         ([0, 0, 10, 10], [5, 15, 15, 25], 0.0),
     ],
 )
-def test_calculate_iou(box1, box2, expected_iou, evaluation):
+def test_calculate_iou(box1, box2, expected_iou, tracker_evaluate_interface):
     box1 = np.array(box1)
     box2 = np.array(box2)
 
-    iou = evaluation.calculate_iou(box1, box2)
+    iou = tracker_evaluate_interface.calculate_iou(box1, box2)
 
     # Check if IoU matches expected value
     assert iou == pytest.approx(expected_iou, abs=1e-2)
 
 
 @pytest.mark.parametrize(
-    "gt_data, pred_data, prev_frame_id_map, expected_mota",
+    "gt_data, pred_data, prev_frame_id_map, expected_output",
     [
         # perfect tracking
         (
             {
                 "bbox": np.array(
                     [
                         [10.0, 10.0, 20.0, 20.0],
                         [30.0, 30.0, 40.0, 40.0],
                         [50.0, 50.0, 60.0, 60.0],
                     ]
                 ),
                 "id": np.array([1, 2, 3]),
             },
             {
-                "bbox": np.array(
+                "tracked_boxes": np.array(
                     [
                         [10.0, 10.0, 20.0, 20.0],
                         [30.0, 30.0, 40.0, 40.0],
                         [50.0, 50.0, 60.0, 60.0],
                     ]
                 ),
-                "id": np.array([11, 12, 13]),
+                "ids": np.array([11, 12, 13]),
             },
             {1: 11, 2: 12, 3: 13},
-            1.0,
+            [1.0, 3, 0, 0, 0],
         ),
         (
             {
@@ -290,17 +316,17 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([1, 2, 3]),
             },
             {
-                "bbox": np.array(
+                "tracked_boxes": np.array(
                     [
                         [10.0, 10.0, 20.0, 20.0],
                         [30.0, 30.0, 40.0, 40.0],
                         [50.0, 50.0, 60.0, 60.0],
                     ]
                 ),
-                "id": np.array([11, 12, 13]),
+                "ids": np.array([11, 12, 13]),
             },
             {1: 11, 12: 2, 3: np.nan},
-            1.0,
+            [1.0, 3, 0, 0, 0],
         ),
         # ID switch
         (
@@ -315,17 +341,17 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([1, 2, 3]),
             },
             {
-                "bbox": np.array(
+                "tracked_boxes": np.array(
                     [
                         [10.0, 10.0, 20.0, 20.0],
                         [30.0, 30.0, 40.0, 40.0],
                         [50.0, 50.0, 60.0, 60.0],
                     ]
                 ),
-                "id": np.array([11, 12, 14]),
+                "ids": np.array([11, 12, 14]),
             },
             {1: 11, 2: 12, 3: 13},
-            2 / 3,
+            [2 / 3, 3, 0, 0, 1],
         ),
         # missed detection
         (
@@ -340,13 +366,13 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([1, 2, 4]),
             },
             {
-                "bbox": np.array(
+                "tracked_boxes": np.array(
                     [[10.0, 10.0, 20.0, 20.0], [30.0, 30.0, 40.0, 40.0]]
                 ),
-                "id": np.array([11, 12]),
+                "ids": np.array([11, 12]),
             },
             {1: 11, 2: 12, 3: 13},
-            2 / 3,
+            [2 / 3, 2, 1, 0, 0],
         ),
         # false positive
         (
@@ -361,7 +387,7 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([1, 2, 3]),
             },
             {
-                "bbox": np.array(
+                "tracked_boxes": np.array(
                     [
                         [10.0, 10.0, 20.0, 20.0],
                         [30.0, 30.0, 40.0, 40.0],
@@ -369,10 +395,10 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                         [70.0, 70.0, 80.0, 80.0],
                     ]
                 ),
-                "id": np.array([11, 12, 13, 14]),
+                "ids": np.array([11, 12, 13, 14]),
             },
             {1: 11, 2: 12, 3: 13},
-            2 / 3,
+            [2 / 3, 3, 0, 1, 0],
         ),
         # low IOU and ID switch
         (
@@ -387,17 +413,17 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([1, 2, 3]),
             },
             {
-                "bbox": np.array(
+                "tracked_boxes": np.array(
                     [
                         [10.0, 10.0, 20.0, 20.0],
                         [30.0, 30.0, 30.0, 30.0],
                         [50.0, 50.0, 60.0, 60.0],
                     ]
                 ),
-                "id": np.array([11, 12, 14]),
+                "ids": np.array([11, 12, 14]),
             },
             {1: 11, 2: 12, 3: 13},
-            0,
+            [0, 2, 1, 1, 1],
         ),
         # low IOU and ID switch on same box
        (
@@ -412,17 +438,17 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([1, 2, 3]),
             },
             {
-                "bbox": np.array(
+                "tracked_boxes": np.array(
                     [
                         [10.0, 10.0, 20.0, 20.0],
                         [30.0, 30.0, 30.0, 30.0],
                         [50.0, 50.0, 60.0, 60.0],
                     ]
                 ),
-                "id": np.array([11, 14, 13]),
+                "ids": np.array([11, 14, 13]),
             },
             {1: 11, 2: 12, 3: 13},
-            1 / 3,
+            [1 / 3, 2, 1, 1, 0],
        ),
         # current tracked id = prev tracked id, but prev_gt_id != current gt id
         (
@@ -437,17 +463,17 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([1, 2, 4]),
             },
             {
-                "bbox": np.array(
+                "tracked_boxes": np.array(
                     [
                         [10.0, 10.0, 20.0, 20.0],
                         [30.0, 30.0, 40.0, 40.0],
                         [50.0, 50.0, 60.0, 60.0],
                     ]
                 ),
-                "id": np.array([11, 12, 13]),
+                "ids": np.array([11, 12, 13]),
             },
             {1: 11, 2: 12, 3: 13},
-            2 / 3,
+            [2 / 3, 3, 0, 0, 1],
         ),
         # ID swapped
         (
@@ -462,31 +488,44 @@ def test_calculate_iou(box1, box2, expected_iou, evaluation):
                 "id": np.array([1, 2, 3]),
             },
             {
-                "bbox": np.array(
+                "tracked_boxes": np.array(
                     [
                         [10.0, 10.0, 20.0, 20.0],
                         [30.0, 30.0, 40.0, 40.0],
                         [50.0, 50.0, 60.0, 60.0],
                     ]
                 ),
-                "id": np.array([11, 13, 12]),
+                "ids": np.array([11, 13, 12]),
             },
             {1: 11, 2: 12, 3: 13},
-            1 / 3,
+            [1 / 3, 3, 0, 0, 2],
         ),
     ],
 )
-def test_evaluate_mota(
+def test_compute_mota_one_frame(
     gt_data,
     pred_data,
     prev_frame_id_map,
-    expected_mota,
-    evaluation,
+    expected_output,
+    tracker_evaluate_interface,
 ):
-    mota, _ = evaluation.evaluate_mota(
+    (
+        mota,
+        true_positives,
+        missed_detections,
+        false_positives,
+        num_switches,
+        total_gt,
+        _,
+    ) = tracker_evaluate_interface.compute_mota_one_frame(
         gt_data,
         pred_data,
         0.1,  # iou_threshold
         prev_frame_id_map,
     )
-    assert mota == pytest.approx(expected_mota)
+    assert mota == pytest.approx(expected_output[0])
+    assert true_positives == expected_output[1]
+    assert missed_detections == expected_output[2]
+    assert false_positives == expected_output[3]
+    assert num_switches == expected_output[4]
+    assert total_gt == (true_positives + missed_detections)
diff --git a/tests/test_unit/test_track_video.py b/tests/test_unit/test_track_video.py
index 7d7ffa89..3614d6bf 100644
--- a/tests/test_unit/test_track_video.py
+++ b/tests/test_unit/test_track_video.py
@@ -10,16 +10,17 @@
 
 @pytest.fixture
 def mock_args():
-    temp_dir = tempfile.mkdtemp()
+    tmp_dir = tempfile.mkdtemp()
 
     return Namespace(
         config_file="/path/to/config.yaml",
         video_path="/path/to/video.mp4",
-        trained_model_path="/path/to/model.ckpt",
-        output_dir=temp_dir,
+        trained_model_path="path/to/model.ckpt",
+        output_dir=tmp_dir,
         accelerator="gpu",
         annotations_file=None,
         save_video=None,
+        save_frames=None,
     )
 
 
@@ -28,33 +29,57 @@ def mock_args():
     new_callable=mock_open,
     read_data="max_age: 10\nmin_hits: 3\niou_threshold: 0.1",
 )
-@patch("yaml.safe_load")
 @patch("cv2.VideoCapture")
-@patch("crabs.tracker.track_video.FasterRCNN.load_from_checkpoint")
-@patch("crabs.tracker.track_video.Sort")
-def test_tracking_setup(
-    mock_sort,
-    mock_load_from_checkpoint,
-    mock_videocapture,
+@patch("crabs.tracker.utils.io.get_video_parameters")
+@patch("crabs.tracker.track_video.get_config_from_ckpt")
+@patch("crabs.tracker.track_video.get_mlflow_parameters_from_ckpt")
+# we patch where the function is looked up, see
+# https://docs.python.org/3/library/unittest.mock.html#where-to-patch
+@patch("yaml.safe_load")
+def test_tracking_constructor(
     mock_yaml_load,
+    mock_get_mlflow_parameters_from_ckpt,
+    mock_get_config_from_ckpt,
+    mock_get_video_parameters,
+    mock_videocapture,
     mock_open,
     mock_args,
 ):
+    # mock reading tracking config from file
     mock_yaml_load.return_value = {
         "max_age": 10,
         "min_hits": 3,
         "iou_threshold": 0.1,
     }
 
-    mock_model = MagicMock()
-    mock_load_from_checkpoint.return_value = mock_model
+    # mock getting mlflow parameters from checkpoint
+    mock_get_mlflow_parameters_from_ckpt.return_value = {
+        "run_name": "trained_model_run_name",
+        "cli_args/experiment_name": "trained_model_expt_name",
+    }
+
+    # mock getting trained model's config
+    mock_get_config_from_ckpt.return_value = {}
+
+    # mock getting video parameters
+    mock_get_video_parameters.return_value = {
+        "total_frames": 614,
+        "frame_width": 1920,
+        "frame_height": 1080,
+        "fps": 60,
+    }
 
+    # mock input video as if opened correctly
     mock_video_capture = MagicMock()
     mock_video_capture.isOpened.return_value = True
     mock_videocapture.return_value = mock_video_capture
 
+    # instantiate tracking interface
     tracker = Tracking(mock_args)
 
+    # check output dir is created correctly
+    # TODO: add asserts for other attributes assigned in constructor
     assert tracker.args.output_dir == mock_args.output_dir
 
+    # delete output dir
     Path(mock_args.output_dir).rmdir()
diff --git a/tests/test_unit/test_tracking_io.py b/tests/test_unit/test_tracking_io.py
new file mode 100644
index 00000000..4031c6d1
--- /dev/null
+++ b/tests/test_unit/test_tracking_io.py
@@ -0,0 +1,92 @@
+import csv
+
+import numpy as np
+
+from crabs.tracker.utils.io import write_tracked_detections_to_csv
+
+
+def test_write_tracked_detections_to_csv(tmp_path):
+    # Create test data
+    csv_file_path = tmp_path / "test_output.csv"
+
+    # Create dictionary with tracked bounding boxes for 2 frames
+    tracked_bboxes_dict = {}
+
+    # frame_idx = 0
+    tracked_bboxes_dict[0] = {
+        "tracked_boxes": np.array([[10, 20, 30, 40], [50, 60, 70, 80]]),
+        "ids": np.array([1, 2]),
+        "scores": np.array([0.9, 0.8]),
+    }
+
+    # frame_idx = 1
+    tracked_bboxes_dict[1] = {
+        "tracked_boxes": np.array([[15, 25, 35, 45]]),
+        "ids": np.array([1]),
+        "scores": np.array([0.85]),
+    }
+
+    frame_name_regexp = "frame_{frame_idx:08d}.png"
+    all_frames_size = 8888
+
+    # Call function
+    write_tracked_detections_to_csv(
+        csv_file_path,
+        tracked_bboxes_dict,
+        frame_name_regexp,
+        all_frames_size,
+    )
+
+    # Read csv file
+    with open(csv_file_path, newline="") as csvfile:
+        csv_reader = csv.reader(csvfile)
+        rows = list(csv_reader)
+
+    # Expected header
+    expected_header = [
+        "filename",
+        "file_size",
+        "file_attributes",
+        "region_count",
+        "region_id",
+        "region_shape_attributes",
+        "region_attributes",
+    ]
+
+    # Expected rows
+    expected_rows = [
+        expected_header,
+        [
+            "frame_00000000.png",
+            "8888",
+            '{"clip":123}',
+            "1",
+            "0",
+            '{"name":"rect","x":10,"y":20,"width":20,"height":20}',
+            '{"track":"1", "confidence":"0.9"}',
+        ],
+        [
+            "frame_00000000.png",
+            "8888",
+            '{"clip":123}',
+            "1",
+            "0",
+            '{"name":"rect","x":50,"y":60,"width":20,"height":20}',
+            '{"track":"2", "confidence":"0.8"}',
+        ],
+        [
+            "frame_00000001.png",
+            "8888",
+            '{"clip":123}',
+            "1",
+            "0",
+            '{"name":"rect","x":15,"y":25,"width":20,"height":20}',
+            '{"track":"1", "confidence":"0.85"}',
+        ],
+    ]
+
+    # Assert the header
+    assert rows[0] == expected_header
+
+    # Assert the rows
+    for i, expected_row in enumerate(expected_rows[1:], start=1):
+        assert rows[i] == expected_row
diff --git a/tests/test_unit/test_tracking_utils.py b/tests/test_unit/test_tracking_utils.py
index 3550f134..f066331a 100644
--- a/tests/test_unit/test_tracking_utils.py
+++ b/tests/test_unit/test_tracking_utils.py
@@ -1,12 +1,9 @@
-import csv
-import io
-
-import numpy as np
 import pytest
+import torch
 
 from crabs.tracker.utils.tracking import (
     extract_bounding_box_info,
-    write_tracked_bbox_to_csv,
+    format_and_filter_bbox_predictions_for_sort,
 )
 
 
@@ -35,32 +32,50 @@ def test_extract_bounding_box_info():
     assert result == expected_result
 
 
-@pytest.fixture
-def csv_output():
-    return io.StringIO()
-
-
-@pytest.fixture
-def csv_writer(csv_output):
-    return csv.writer(csv_output)
-
-
-def test_write_tracked_bbox_to_csv(csv_writer, csv_output):
-    bbox = np.array([10, 20, 50, 80, 1])
-    frame = np.zeros((100, 100, 3), dtype=np.uint8)
-    frame_name = "frame_0001.png"
-    pred_score = 0.900
+@pytest.mark.parametrize(
+    "score_threshold, expected_output",
+    [
+        (
+            0.5,
+            torch.tensor(
+                [
+                    [10, 20, 30, 40, 0.9],
+                    [50, 60, 70, 80, 0.85],
+                    [15, 25, 35, 45, 0.8],
+                ]
+            ),
+        ),
+        (
+            0.83,
+            torch.tensor(
+                [
+                    [10, 20, 30, 40, 0.9],
+                    [50, 60, 70, 80, 0.85],
+                ]
+            ),
+        ),
+        (
+            0.95,
+            torch.empty((0, 5)),
+        ),
+    ],
+)
+def test_format_bbox_predictions_for_sort(score_threshold, expected_output):
+    # Define the test data
+    prediction = {
+        "boxes": torch.tensor(
+            [[10, 20, 30, 40], [50, 60, 70, 80], [15, 25, 35, 45]]
+        ),
+        "scores": torch.tensor([0.9, 0.85, 0.8]),
+    }
 
-    write_tracked_bbox_to_csv(bbox, frame, frame_name, csv_writer, pred_score)
+    # Call the function
+    result = format_and_filter_bbox_predictions_for_sort(
+        prediction, score_threshold
+    )
 
-    expected_row = (
-        "frame_0001.png",
-        30000,
-        '"{""clip"":123}"',
-        1,
-        0,
-        '"{""name"":""rect"",""x"":10,""y"":20,""width"":40,""height"":60}"',
-        '"{""track"":""1"", ""confidence"":""0.9""}"',
""confidence"":""0.9""}"', + # Assert the result + ( + torch.testing.assert_close(result, expected_output), + f"Expected {expected_output}, but got {result}", ) - expected_row_str = ",".join(map(str, expected_row)) - assert csv_output.getvalue().strip() == expected_row_str