Skip to content

Commit

Permalink
ultralytics 8.0.108 add Meituan YOLOv6 models (ultralytics#2811)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael Currie <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hicham Talaoubrid <[email protected]>
Co-authored-by: Zlobin Vladimir <[email protected]>
Co-authored-by: Szymon Mikler <[email protected]>
  • Loading branch information
6 people authored May 24, 2023
1 parent 07b57c0 commit ffc0e8c
Show file tree
Hide file tree
Showing 18 changed files with 233 additions and 45 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ model = YOLO("yolov8n.pt") # load a pretrained model (recommended for training)
model.train(data="coco128.yaml", epochs=3) # train the model
metrics = model.val() # evaluate model performance on the validation set
results = model("https://ultralytics.com/images/bus.jpg") # predict on an image
success = model.export(format="onnx") # export the model to ONNX format
path = model.export(format="onnx") # export the model to ONNX format
```

[Models](https://github.com/ultralytics/ultralytics/tree/main/ultralytics/models) download automatically from the latest Ultralytics [release](https://github.com/ultralytics/assets/releases). See YOLOv8 [Python Docs](https://docs.ultralytics.com/usage/python) for more examples.
Expand Down
2 changes: 1 addition & 1 deletion docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ADD https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt /u

# Install pip packages
RUN python3 -m pip install --upgrade pip wheel
RUN pip install --no-cache -e . albumentations comet tensorboard thop
RUN pip install --no-cache -e . albumentations comet tensorboard thop pycocotools

# Set environment variables
ENV OMP_NUM_THREADS=1
Expand Down
1 change: 1 addition & 0 deletions docs/models/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ In this documentation, we provide information on four major models:

1. [YOLOv3](./yolov3.md): The third iteration of the YOLO model family, known for its efficient real-time object detection capabilities.
2. [YOLOv5](./yolov5.md): An improved version of the YOLO architecture, offering better performance and speed tradeoffs compared to previous versions.
3. [YOLOv6](./yolov6.md): Released by [Meituan](https://about.meituan.com/) in 2022 and is in use in many of the company's autonomous delivery robots.
3. [YOLOv8](./yolov8.md): The latest version of the YOLO family, featuring enhanced capabilities such as instance segmentation, pose/keypoints estimation, and classification.
4. [Segment Anything Model (SAM)](./sam.md): Meta's Segment Anything Model (SAM).
5. [Realtime Detection Transformers (RT-DETR)](./rtdetr.md): Baidu's RT-DETR model.
Expand Down
2 changes: 1 addition & 1 deletion docs/models/sam.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ For more information about the Segment Anything Model and the SA-1B dataset, ple
SAM can be used for a variety of downstream tasks involving object and image distributions beyond its training data. Examples include edge detection, object proposal generation, instance segmentation, and preliminary text-to-mask prediction. By employing prompt engineering, SAM can adapt to new tasks and data distributions in a zero-shot manner, making it a versatile and powerful tool for image segmentation tasks.

```python
from ultralytics.vit import SAM
from ultralytics import SAM

model = SAM('sam_b.pt')
model.info() # display model information
Expand Down
81 changes: 81 additions & 0 deletions docs/models/yolov6.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
---
comments: true
description: Discover Meituan YOLOv6, a robust real-time object detector. Learn how to utilize pre-trained models with Ultralytics Python API for a variety of tasks.
---

# Meituan YOLOv6

## Overview

[Meituan](https://about.meituan.com/) YOLOv6 is a cutting-edge object detector that offers remarkable balance between speed and accuracy, making it a popular choice for real-time applications. This model introduces several notable enhancements on its architecture and training scheme, including the implementation of a Bi-directional Concatenation (BiC) module, an anchor-aided training (AAT) strategy, and an improved backbone and neck design for state-of-the-art accuracy on the COCO dataset.

![Meituan YOLOv6](https://user-images.githubusercontent.com/26833433/240750495-4da954ce-8b3b-41c4-8afd-ddb74361d3c2.png)
![Model example image](https://user-images.githubusercontent.com/26833433/240750557-3e9ec4f0-0598-49a8-83ea-f33c91eb6d68.png)
**Overview of YOLOv6.** Model architecture diagram showing the redesigned network components and training strategies that have led to significant performance improvements. (a) The neck of YOLOv6 (N and S are shown). Note for M/L, RepBlocks is replaced with CSPStackRep. (b) The
structure of a BiC module. (c) A SimCSPSPPF block. ([source](https://arxiv.org/pdf/2301.05586.pdf)).

### Key Features

- **Bi-directional Concatenation (BiC) Module:** YOLOv6 introduces a BiC module in the neck of the detector, enhancing localization signals and delivering performance gains with negligible speed degradation.
- **Anchor-Aided Training (AAT) Strategy:** This model proposes AAT to enjoy the benefits of both anchor-based and anchor-free paradigms without compromising inference efficiency.
- **Enhanced Backbone and Neck Design:** By deepening YOLOv6 to include another stage in the backbone and neck, this model achieves state-of-the-art performance on the COCO dataset at high-resolution input.
- **Self-Distillation Strategy:** A new self-distillation strategy is implemented to boost the performance of smaller models of YOLOv6, enhancing the auxiliary regression branch during training and removing it at inference to avoid a marked speed decline.

## Pre-trained Models

YOLOv6 provides various pre-trained models with different scales:

- YOLOv6-N: 37.5% AP on COCO val2017 at 1187 FPS with NVIDIA Tesla T4 GPU.
- YOLOv6-S: 45.0% AP at 484 FPS.
- YOLOv6-M: 50.0% AP at 226 FPS.
- YOLOv6-L: 52.8% AP at 116 FPS.
- YOLOv6-L6: State-of-the-art accuracy in real-time.

YOLOv6 also provides quantized models for different precisions and models optimized for mobile platforms.

## Usage

### Python API

```python
from ultralytics import YOLO

model = YOLO("yolov6n.yaml") # build new model from scratch
model.info() # display model information
model.predict("path/to/image.jpg") # predict
```

### Supported Tasks

| Model Type | Pre-trained Weights | Tasks Supported |
|------------|---------------------|------------------|
| YOLOv6-N | `yolov6-n.pt` | Object Detection |
| YOLOv6-S | `yolov6-s.pt` | Object Detection |
| YOLOv6-M | `yolov6-m.pt` | Object Detection |
| YOLOv6-L | `yolov6-l.pt` | Object Detection |
| YOLOv6-L6 | `yolov6-l6.pt` | Object Detection |

## Supported Modes

| Mode | Supported |
|------------|--------------------|
| Inference | :heavy_check_mark: |
| Validation | :heavy_check_mark: |
| Training | :heavy_check_mark: |

## Citations and Acknowledgements

We would like to acknowledge the authors for their significant contributions in the field of real-time object detection:

```bibtex
@misc{li2023yolov6,
title={YOLOv6 v3.0: A Full-Scale Reloading},
author={Chuyi Li and Lulu Li and Yifei Geng and Hongliang Jiang and Meng Cheng and Bo Zhang and Zaidan Ke and Xiaoming Xu and Xiangxiang Chu},
year={2023},
eprint={2301.05586},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```

The original YOLOv6 paper can be found on [arXiv](https://arxiv.org/abs/2301.05586). The authors have made their work publicly available, and the codebase can be accessed on [GitHub](https://github.com/meituan/YOLOv6). We appreciate their efforts in advancing the field and making their work accessible to the broader community.
16 changes: 8 additions & 8 deletions docs/usage/hyperparameter_tuning.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ To install the required packages, run:

The `tune()` method in YOLOv8 provides an easy-to-use interface for hyperparameter tuning with Ray Tune. It accepts several arguments that allow you to customize the tuning process. Below is a detailed explanation of each parameter:

| Parameter | Type | Description | Default Value |
|-----------------|----------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------|
| `data` | str | The dataset configuration file (in YAML format) to run the tuner on. This file should specify the training and validation data paths, as well as other dataset-specific settings. | |
| `space` | dict, optional | A dictionary defining the hyperparameter search space for Ray Tune. Each key corresponds to a hyperparameter name, and the value specifies the range of values to explore during tuning. If not provided, YOLOv8 uses a default search space with various hyperparameters. | |
| `grace_period` | int, optional | The grace period in epochs for the [ASHA scheduler](https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#asha-tune-schedulers-asha) in Ray Tune. The scheduler will not terminate any trial before this number of epochs, allowing the model to have some minimum training before making a decision on early stopping. | 10 |
| `gpu_per_trial` | int, optional | The number of GPUs to allocate per trial during tuning. This helps manage GPU usage, particularly in multi-GPU environments. If not provided, the tuner will use all available GPUs. | None |
| `max_samples` | int, optional | The maximum number of trials to run during tuning. This parameter helps control the total number of hyperparameter combinations tested, ensuring the tuning process does not run indefinitely. | 10 |
| `train_args` | dict, optional | A dictionary of additional arguments to pass to the `train()` method during tuning. These arguments can include settings like the number of training epochs, batch size, and other training-specific configurations. | {} |
| Parameter | Type | Description | Default Value |
|-----------------|----------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------|
| `data` | str | The dataset configuration file (in YAML format) to run the tuner on. This file should specify the training and validation data paths, as well as other dataset-specific settings. | |
| `space` | dict, optional | A dictionary defining the hyperparameter search space for Ray Tune. Each key corresponds to a hyperparameter name, and the value specifies the range of values to explore during tuning. If not provided, YOLOv8 uses a default search space with various hyperparameters. | |
| `grace_period` | int, optional | The grace period in epochs for the [ASHA scheduler]https://docs.ray.io/en/latest/tune/api/schedulers.html) in Ray Tune. The scheduler will not terminate any trial before this number of epochs, allowing the model to have some minimum training before making a decision on early stopping. | 10 |
| `gpu_per_trial` | int, optional | The number of GPUs to allocate per trial during tuning. This helps manage GPU usage, particularly in multi-GPU environments. If not provided, the tuner will use all available GPUs. | None |
| `max_samples` | int, optional | The maximum number of trials to run during tuning. This parameter helps control the total number of hyperparameter combinations tested, ensuring the tuning process does not run indefinitely. | 10 |
| `train_args` | dict, optional | A dictionary of additional arguments to pass to the `train()` method during tuning. These arguments can include settings like the number of training epochs, batch size, and other training-specific configurations. | {} |

By customizing these parameters, you can fine-tune the hyperparameter optimization process to suit your specific needs and available computational resources.

Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ nav:
- models/index.md
- YOLOv3: models/yolov3.md
- YOLOv5: models/yolov5.md
- YOLOv6: models/yolov6.md
- YOLOv8: models/yolov8.md
- SAM (Segment Anything Model): models/sam.md
- RT-DETR (Realtime Detection Transformer): models/rtdetr.md
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

__version__ = '8.0.107'
__version__ = '8.0.108'

from ultralytics.hub import start
from ultralytics.vit.rtdetr import RTDETR
Expand Down
51 changes: 51 additions & 0 deletions ultralytics/models/v6/yolov6.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv6 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
act: nn.ReLU()
nc: 80 # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov6n.yaml' will call yolov8.yaml with scale 'n'
# [depth, width, max_channels]
n: [ 0.33, 0.25, 1024 ]
s: [ 0.33, 0.50, 1024 ]
m: [ 0.67, 0.75, 768 ]
l: [ 1.00, 1.00, 512 ]
x: [ 1.00, 1.25, 512 ]

# YOLOv6-3.0s backbone
backbone:
# [from, repeats, module, args]
- [ -1, 1, Conv, [ 64, 3, 2 ] ] # 0-P1/2
- [ -1, 1, Conv, [ 128, 3, 2 ] ] # 1-P2/4
- [ -1, 6, Conv, [ 128, 3, 1 ] ]
- [ -1, 1, Conv, [ 256, 3, 2 ] ] # 3-P3/8
- [ -1, 12, Conv, [ 256, 3, 1 ] ]
- [ -1, 1, Conv, [ 512, 3, 2 ] ] # 5-P4/16
- [ -1, 18, Conv, [ 512, 3, 1 ] ]
- [ -1, 1, Conv, [ 1024, 3, 2 ] ] # 7-P5/32
- [ -1, 9, Conv, [ 1024, 3, 1 ] ]
- [ -1, 1, SPPF, [ 1024, 5 ] ] # 9

# YOLOv6-3.0s head
head:
- [ -1, 1, nn.ConvTranspose2d, [ 256, 2, 2, 0 ] ]
- [ [ -1, 6 ], 1, Concat, [ 1 ] ] # cat backbone P4
- [ -1, 1, Conv, [ 256, 3, 1 ] ]
- [ -1, 9, Conv, [ 256, 3, 1 ] ] # 13

- [ -1, 1, nn.ConvTranspose2d, [ 128, 2, 2, 0 ] ]
- [ [ -1, 4 ], 1, Concat, [ 1 ] ] # cat backbone P3
- [ -1, 1, Conv, [ 128, 3, 1 ] ]
- [ -1, 9, Conv, [ 128, 3, 1 ] ] # 17

- [ -1, 1, Conv, [ 128, 3, 2 ] ]
- [ [ -1, 12 ], 1, Concat, [ 1 ] ] # cat head P4
- [ -1, 1, Conv, [ 256, 3, 1 ] ]
- [ -1, 9, Conv, [ 256, 3, 1 ] ] # 21

- [ -1, 1, Conv, [ 256, 3, 2 ] ]
- [ [ -1, 9 ], 1, Concat, [ 1 ] ] # cat head P5
- [ -1, 1, Conv, [ 512, 3, 1 ] ]
- [ -1, 9, Conv, [ 512, 3, 1 ] ] # 25

- [ [ 17, 21, 25 ], 1, Detect, [ nc ] ] # Detect(P3, P4, P5)
11 changes: 11 additions & 0 deletions ultralytics/vit/sam/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from ultralytics.yolo.cfg import get_cfg

from ...yolo.utils.torch_utils import model_info
from .build import build_sam
from .predict import Predictor

Expand Down Expand Up @@ -33,3 +34,13 @@ def train(self, **kwargs):
def val(self, **kwargs):
"""Run validation given dataset."""
raise NotImplementedError("SAM models don't support validation")

def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose)
5 changes: 4 additions & 1 deletion ultralytics/yolo/cfg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,13 @@ def get_cfg(cfg: Union[str, Path, Dict, SimpleNamespace] = DEFAULT_CFG_DICT, ove
check_cfg_mismatch(cfg, overrides)
cfg = {**cfg, **overrides} # merge cfg and overrides dicts (prefer overrides)

# Special handling for numeric project/names
# Special handling for numeric project/name
for k in 'project', 'name':
if k in cfg and isinstance(cfg[k], (int, float)):
cfg[k] = str(cfg[k])
if cfg.get('name') == 'model': # assign model to 'name' arg
cfg['name'] = cfg.get('model', '').split('.')[0]
LOGGER.warning(f"WARNING ⚠️ 'name=model' automatically updated to 'name={cfg['name']}'.")

# Type and Value checks
for k, v in cfg.items():
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/yolo/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def check_source(source):
is_file = Path(source).suffix[1:] in (IMG_FORMATS + VID_FORMATS)
is_url = source.lower().startswith(('https://', 'http://', 'rtsp://', 'rtmp://'))
webcam = source.isnumeric() or source.endswith('.streams') or (is_url and not is_file)
screenshot = source.lower().startswith('screen')
screenshot = source.lower() == 'screen'
if is_url and is_file:
source = check_file(source) # download
elif isinstance(source, tuple(LOADERS)):
Expand Down
8 changes: 4 additions & 4 deletions ultralytics/yolo/engine/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,12 +331,12 @@ def export(self, **kwargs):
overrides = self.overrides.copy()
overrides.update(kwargs)
overrides['mode'] = 'export'
if overrides.get('imgsz') is None:
overrides['imgsz'] = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if overrides.get('batch') is None:
overrides['batch'] = 1 # default to 1 if not modified
args = get_cfg(cfg=DEFAULT_CFG, overrides=overrides)
args.task = self.task
if args.imgsz == DEFAULT_CFG.imgsz:
args.imgsz = self.model.args['imgsz'] # use trained imgsz unless custom value is passed
if args.batch == DEFAULT_CFG.batch:
args.batch = 1 # default to 1 if not modified
return Exporter(overrides=args, _callbacks=self.callbacks)(model=self.model)

def train(self, **kwargs):
Expand Down
7 changes: 6 additions & 1 deletion ultralytics/yolo/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,12 +684,17 @@ def amp_allclose(m, im):
im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if ONLINE else np.ones((640, 640, 3))
prefix = colorstr('AMP: ')
LOGGER.info(f'{prefix}running Automatic Mixed Precision (AMP) checks with YOLOv8n...')
warning_msg = "Setting 'amp=True'. If you experience zero-mAP or NaN losses you can disable AMP with amp=False."
try:
from ultralytics import YOLO
assert amp_allclose(YOLO('yolov8n.pt'), im)
LOGGER.info(f'{prefix}checks passed ✅')
except ConnectionError:
LOGGER.warning(f"{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. Setting 'amp=True'.")
LOGGER.warning(f'{prefix}checks skipped ⚠️, offline and unable to download YOLOv8n. {warning_msg}')
except (AttributeError, ModuleNotFoundError):
LOGGER.warning(
f'{prefix}checks skipped ⚠️. Unable to load YOLOv8n due to possible Ultralytics package modifications. {warning_msg}'
)
except AssertionError:
LOGGER.warning(f'{prefix}checks failed ❌. Anomalies were detected with AMP on your system that may lead to '
f'NaN losses or zero-mAP results, so AMP will be disabled during training.')
Expand Down
9 changes: 6 additions & 3 deletions ultralytics/yolo/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,12 +372,15 @@ def is_online() -> bool:
"""
import socket

for server in '1.1.1.1', '8.8.8.8', '223.5.5.5': # Cloudflare, Google, AliDNS:
for host in '1.1.1.1', '8.8.8.8', '223.5.5.5': # Cloudflare, Google, AliDNS:
try:
socket.create_connection((server, 53), timeout=2) # connect to (server, port=53)
return True
test_connection = socket.create_connection(address=(host, 53), timeout=2)
except (socket.timeout, socket.gaierror, OSError):
continue
else:
# If the connection was successful, close it to avoid a ResourceWarning
test_connection.close()
return True
return False


Expand Down
Loading

0 comments on commit ffc0e8c

Please sign in to comment.