forked from ultralytics/ultralytics
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ultralytics 8.0.77
Ray[Tune] for hyperparameter optimization (ultra…
…lytics#2014) Co-authored-by: JF Chen <[email protected]> Co-authored-by: Ayush Chaurasia <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
4916014
commit 5065ca3
Showing
12 changed files
with
205 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
try: | ||
import ray | ||
from ray import tune | ||
from ray.air import session | ||
except (ImportError, AssertionError): | ||
tune = None | ||
|
||
|
||
def on_fit_epoch_end(trainer): | ||
if ray.tune.is_session_enabled(): | ||
metrics = trainer.metrics | ||
metrics['epoch'] = trainer.epoch | ||
session.report(metrics) | ||
|
||
|
||
callbacks = { | ||
'on_fit_epoch_end': on_fit_epoch_end, } if tune else {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
# Ultralytics YOLO 🚀, GPL-3.0 license | ||
|
||
from ultralytics.yolo.utils.torch_utils import get_flops, get_num_params | ||
|
||
try: | ||
import wandb as wb | ||
|
||
assert hasattr(wb, '__version__') | ||
except (ImportError, AssertionError): | ||
wb = None | ||
|
||
|
||
def on_pretrain_routine_start(trainer): | ||
wb.init(project=trainer.args.project or 'YOLOv8', name=trainer.args.name, config=vars( | ||
trainer.args)) if not wb.run else wb.run | ||
|
||
|
||
def on_fit_epoch_end(trainer): | ||
wb.run.log(trainer.metrics, step=trainer.epoch + 1) | ||
if trainer.epoch == 0: | ||
model_info = { | ||
'model/parameters': get_num_params(trainer.model), | ||
'model/GFLOPs': round(get_flops(trainer.model), 3), | ||
'model/speed(ms)': round(trainer.validator.speed['inference'], 3)} | ||
wb.run.log(model_info, step=trainer.epoch + 1) | ||
|
||
|
||
def on_train_epoch_end(trainer): | ||
wb.run.log(trainer.label_loss_items(trainer.tloss, prefix='train'), step=trainer.epoch + 1) | ||
wb.run.log(trainer.lr, step=trainer.epoch + 1) | ||
if trainer.epoch == 1: | ||
wb.run.log({f.stem: wb.Image(str(f)) | ||
for f in trainer.save_dir.glob('train_batch*.jpg')}, | ||
step=trainer.epoch + 1) | ||
|
||
|
||
def on_train_end(trainer): | ||
art = wb.Artifact(type='model', name=f'run_{wb.run.id}_model') | ||
if trainer.best.exists(): | ||
art.add_file(trainer.best) | ||
wb.run.log_artifact(art) | ||
|
||
|
||
callbacks = { | ||
'on_pretrain_routine_start': on_pretrain_routine_start, | ||
'on_train_epoch_end': on_train_epoch_end, | ||
'on_fit_epoch_end': on_fit_epoch_end, | ||
'on_train_end': on_train_end} if wb else {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from ultralytics.yolo.utils import LOGGER | ||
|
||
try: | ||
from ray import tune | ||
from ray.air import RunConfig, session # noqa | ||
from ray.air.integrations.wandb import WandbLoggerCallback # noqa | ||
from ray.tune.schedulers import ASHAScheduler # noqa | ||
from ray.tune.schedulers import AsyncHyperBandScheduler as AHB # noqa | ||
|
||
except ImportError: | ||
LOGGER.info("Tuning hyperparameters requires ray/tune. Install using `pip install 'ray[tune]'`") | ||
tune = None | ||
|
||
default_space = { | ||
# 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'RMSProp']), | ||
'lr0': tune.uniform(1e-5, 1e-1), | ||
'lrf': tune.uniform(0.01, 1.0), # final OneCycleLR learning rate (lr0 * lrf) | ||
'momentum': tune.uniform(0.6, 0.98), # SGD momentum/Adam beta1 | ||
'weight_decay': tune.uniform(0.0, 0.001), # optimizer weight decay 5e-4 | ||
'warmup_epochs': tune.uniform(0.0, 5.0), # warmup epochs (fractions ok) | ||
'warmup_momentum': tune.uniform(0.0, 0.95), # warmup initial momentum | ||
'box': tune.uniform(0.02, 0.2), # box loss gain | ||
'cls': tune.uniform(0.2, 4.0), # cls loss gain (scale with pixels) | ||
'fl_gamma': tune.uniform(0.0, 2.0), # focal loss gamma (efficientDet default gamma=1.5) | ||
'hsv_h': tune.uniform(0.0, 0.1), # image HSV-Hue augmentation (fraction) | ||
'hsv_s': tune.uniform(0.0, 0.9), # image HSV-Saturation augmentation (fraction) | ||
'hsv_v': tune.uniform(0.0, 0.9), # image HSV-Value augmentation (fraction) | ||
'degrees': tune.uniform(0.0, 45.0), # image rotation (+/- deg) | ||
'translate': tune.uniform(0.0, 0.9), # image translation (+/- fraction) | ||
'scale': tune.uniform(0.0, 0.9), # image scale (+/- gain) | ||
'shear': tune.uniform(0.0, 10.0), # image shear (+/- deg) | ||
'perspective': tune.uniform(0.0, 0.001), # image perspective (+/- fraction), range 0-0.001 | ||
'flipud': tune.uniform(0.0, 1.0), # image flip up-down (probability) | ||
'fliplr': tune.uniform(0.0, 1.0), # image flip left-right (probability) | ||
'mosaic': tune.uniform(0.0, 1.0), # image mixup (probability) | ||
'mixup': tune.uniform(0.0, 1.0), # image mixup (probability) | ||
'copy_paste': tune.uniform(0.0, 1.0)} # segment copy-paste (probability) | ||
|
||
task_metric_map = { | ||
'detect': 'metrics/mAP50-95(B)', | ||
'segment': 'metrics/mAP50-95(M)', | ||
'classify': 'top1_acc', | ||
'pose': None} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters