Skip to content

Commit

Permalink
Update Validator to use model argument (ultralytics#4480)
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher authored Aug 21, 2023
1 parent 615ddc9 commit b2f279f
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 14 deletions.
3 changes: 2 additions & 1 deletion ultralytics/cfg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def cfg2dict(cfg):
Convert a configuration object to a dictionary, whether it is a file path, a string, or a SimpleNamespace object.
Args:
cfg (str | Path | SimpleNamespace): Configuration object to be converted to a dictionary.
cfg (str | Path | dict | SimpleNamespace): Configuration object to be converted to a dictionary.
Returns:
cfg (dict): Configuration object in dictionary format.
Expand Down Expand Up @@ -110,6 +110,7 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
# Merge overrides
if overrides:
overrides = cfg2dict(overrides)
overrides.pop('save_dir', None) # special override keys to ignore
check_dict_alignment(cfg, overrides)
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)

Expand Down
9 changes: 4 additions & 5 deletions ultralytics/engine/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from ultralytics.cfg import get_cfg
from ultralytics.data.utils import check_cls_dataset, check_det_dataset
from ultralytics.nn.autobackend import AutoBackend
from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
from ultralytics.utils import LOGGER, RANK, SETTINGS, TQDM_BAR_FORMAT, callbacks, colorstr, emojis
from ultralytics.utils.checks import check_imgsz
from ultralytics.utils.files import increment_path
from ultralytics.utils.ops import Profile
Expand All @@ -43,9 +43,9 @@ class BaseValidator:
A base class for creating validators.
Attributes:
args (SimpleNamespace): Configuration for the validator.
dataloader (DataLoader): Dataloader to use for validation.
pbar (tqdm): Progress bar to update during validation.
args (SimpleNamespace): Configuration for the validator.
model (nn.Module): Model to validate.
data (dict): Data dictionary.
device (torch.device): Device to use for validation.
Expand Down Expand Up @@ -76,9 +76,9 @@ def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callba
args (SimpleNamespace): Configuration for the validator.
_callbacks (dict): Dictionary to store various callback functions.
"""
self.args = get_cfg(overrides=args)
self.dataloader = dataloader
self.pbar = pbar
self.args = args or get_cfg(DEFAULT_CFG)
self.model = None
self.data = None
self.device = None
Expand Down Expand Up @@ -126,8 +126,7 @@ def __call__(self, trainer=None, model=None):
else:
callbacks.add_integration_callbacks(self)
self.run_callbacks('on_val_start')
assert model is not None, 'Either trainer or model is needed for validation'
model = AutoBackend(model,
model = AutoBackend(model or self.args.model,
device=select_device(self.args.device, self.args.batch),
dnn=self.args.dnn,
data=self.args.data,
Expand Down
9 changes: 5 additions & 4 deletions ultralytics/models/rtdetr/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
__all__ = 'RTDETRValidator', # tuple or list


# TODO: Temporarily, RT-DETR does not need padding.
# TODO: Temporarily RT-DETR does not need padding.
class RTDETRDataset(YOLODataset):

def __init__(self, *args, data=None, **kwargs):
Expand Down Expand Up @@ -47,7 +47,7 @@ def load_image(self, i):
return self.ims[i], self.im_hw0[i], self.im_hw[i]

def build_transforms(self, hyp=None):
"""Temporarily, only for evaluation."""
"""Temporary, only for evaluation."""
if self.augment:
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
Expand Down Expand Up @@ -76,12 +76,13 @@ class RTDETRValidator(DetectionValidator):
args = dict(model='rtdetr-l.pt', data='coco8.yaml')
validator = RTDETRValidator(args=args)
validator(model=args['model'])
validator()
```
"""

def build_dataset(self, img_path, mode='val', batch=None):
"""Build YOLO Dataset
"""
Build an RTDETR Dataset.
Args:
img_path (str): Path to the folder containing images.
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/models/yolo/classify/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ClassificationValidator(BaseValidator):
args = dict(model='yolov8n-cls.pt', data='imagenet10')
validator = ClassificationValidator(args=args)
validator(model=args['model'])
validator()
```
"""

Expand Down
2 changes: 1 addition & 1 deletion ultralytics/models/yolo/detect/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class DetectionValidator(BaseValidator):
args = dict(model='yolov8n.pt', data='coco8.yaml')
validator = DetectionValidator(args=args)
validator(model=args['model'])
validator()
```
"""

Expand Down
2 changes: 1 addition & 1 deletion ultralytics/models/yolo/pose/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class PoseValidator(DetectionValidator):
args = dict(model='yolov8n-pose.pt', data='coco8-pose.yaml')
validator = PoseValidator(args=args)
validator(model=args['model'])
validator()
```
"""

Expand Down
2 changes: 1 addition & 1 deletion ultralytics/models/yolo/segment/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class SegmentationValidator(DetectionValidator):
args = dict(model='yolov8n-seg.pt', data='coco8-seg.yaml')
validator = SegmentationValidator(args=args)
validator(model=args['model'])
validator()
```
"""

Expand Down

0 comments on commit b2f279f

Please sign in to comment.