Skip to content

Commit

Permalink
Add setup_trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
sfmig committed Nov 5, 2024
1 parent abaeb46 commit 18ad17f
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion crabs/tracker/track_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path

import cv2
import lightning
import numpy as np
import torch
import torchvision.transforms.v2 as transforms
Expand All @@ -16,6 +17,7 @@
from crabs.detector.utils.detection import (
log_mlflow_metadata_as_info,
set_mlflow_run_name,
setup_mlflow_logger,
)
from crabs.tracker.evaluate_tracker import TrackerEvaluate
from crabs.tracker.sort import Sort
Expand Down Expand Up @@ -73,6 +75,33 @@ def __init__(self, args: argparse.Namespace) -> None:
# Log MLflow information to screen
log_mlflow_metadata_as_info(self)

def setup_trainer(self):
"""Set up trainer object with logging for testing."""
# Setup logger
mlf_logger = setup_mlflow_logger(
experiment_name=self.experiment_name,
run_name=self.run_name,
mlflow_folder=self.mlflow_folder,
cli_args=self.args,
)

# Add trained model section to MLflow hyperparameters
mlf_logger.log_hyperparams(
{
"trained_model/experiment_name": self.trained_model_expt_name,
"trained_model/run_name": self.trained_model_run_name,
"trained_model/ckpt_file": Path(self.trained_model_path).name,
}
)

# Add other unlogged information from init?

# Return trainer linked to logger
return lightning.Trainer(
accelerator=self.accelerator, # lightning accelerators
logger=mlf_logger,
)

def setup(self):
"""Load tracking config, trained model and input video path."""
with open(self.config_file) as f:
Expand Down Expand Up @@ -334,10 +363,11 @@ def tracking_parse_args(args):
parser.add_argument(
"--experiment_name",
type=str,
default="Inference",
help=(
"Name of the experiment in MLflow, under which the current run "
"will be logged. "
"By default: <trained_model_mlflow_experiment_name>_evaluation."
"By default: Inference."
),
)
parser.add_argument(
Expand Down

0 comments on commit 18ad17f

Please sign in to comment.