Skip to content

Commit

Permalink
Fix LoadStreams final frame bug (ultralytics#4387)
Browse files Browse the repository at this point in the history
Co-authored-by: Nadim Bou Alwan <[email protected]>
  • Loading branch information
glenn-jocher and nadinator authored Aug 16, 2023
1 parent 17e6b9c commit fb1ae9b
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 67 deletions.
2 changes: 1 addition & 1 deletion docs/guides/kfold-cross-validation.md
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ fold_lbl_distrb.to_csv(save_path / "kfold_label_distribution.csv")
results = {}
for k in range(ksplit):
dataset_yaml = ds_yamls[k]
results = model.train(data=dataset_yaml, *args, **kwargs) # Include any training arguments
model.train(data=dataset_yaml, *args, **kwargs) # Include any training arguments
results[k] = model.metrics # save output metrics for further analysis
```

Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ norecursedirs =
build
addopts =
--doctest-modules
--durations=25
--durations=30
--color=yes

[coverage:run]
Expand Down
25 changes: 25 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import shutil
from pathlib import Path

import pytest

from ultralytics.utils import ROOT

TMP = (ROOT / '../tests/tmp').resolve() # temp directory for test files


def pytest_addoption(parser):
parser.addoption('--runslow', action='store_true', default=False, help='run slow tests')
Expand All @@ -19,3 +26,21 @@ def pytest_collection_modifyitems(config, items):
for item in items:
if 'slow' in item.keywords:
item.add_marker(skip_slow)


def pytest_sessionstart(session):
"""
Called after the 'Session' object has been created and before performing test collection.
"""
shutil.rmtree(TMP, ignore_errors=True) # delete any existing tests/tmp directory
TMP.mkdir(parents=True, exist_ok=True) # create a new empty directory


def pytest_terminal_summary(terminalreporter, exitstatus, config):
# Remove files
for file in ['bus.jpg', 'decelera_landscape_min.mov']:
Path(file).unlink(missing_ok=True)

# Remove directories
for directory in ['.pytest_cache/', TMP]:
shutil.rmtree(directory, ignore_errors=True)
17 changes: 1 addition & 16 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from ultralytics.utils import ONLINE, ROOT, SETTINGS
from ultralytics.utils import ROOT, SETTINGS

WEIGHT_DIR = Path(SETTINGS['weights_dir'])
TASK_ARGS = [
Expand All @@ -30,7 +30,6 @@ def test_special_modes():
run('yolo checks')
run('yolo version')
run('yolo settings reset')
run('yolo copy-cfg')
run('yolo cfg')


Expand All @@ -49,28 +48,14 @@ def test_predict(task, model, data):
run(f"yolo predict model={WEIGHT_DIR / model}.pt source={ROOT / 'assets'} imgsz=32 save save_crop save_txt")


@pytest.mark.skipif(not ONLINE, reason='environment is offline')
@pytest.mark.parametrize('task,model,data', TASK_ARGS)
def test_predict_online(task, model, data):
mode = 'track' if task in ('detect', 'segment', 'pose') else 'predict' # mode for video inference
model = WEIGHT_DIR / model
run(f'yolo predict model={model}.pt source=https://ultralytics.com/images/bus.jpg imgsz=32')
run(f'yolo {mode} model={model}.pt source=https://ultralytics.com/assets/decelera_landscape_min.mov imgsz=96')

# Run Python YouTube tracking because CLI is broken. TODO: fix CLI YouTube
# run(f'yolo {mode} model={model}.pt source=https://youtu.be/G17sBkb38XQ imgsz=32 tracker=bytetrack.yaml')


@pytest.mark.parametrize('model,format', EXPORT_ARGS)
def test_export(model, format):
run(f'yolo export model={WEIGHT_DIR / model}.pt format={format} imgsz=32')


# Test SAM, RTDETR Models
def test_rtdetr(task='detect', model='yolov8n-rtdetr.yaml', data='coco8.yaml'):
# Warning: MUST use imgsz=640
run(f'yolo train {task} model={model} data={data} imgsz=640 epochs=1 cache=disk')
run(f'yolo val {task} model={model} data={data} imgsz=640')
run(f"yolo predict {task} model={model} source={ROOT / 'assets/bus.jpg'} imgsz=640 save save_crop save_txt")


Expand Down
76 changes: 49 additions & 27 deletions tests/test_python.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,27 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license

import shutil
from copy import copy
from pathlib import Path

import cv2
import numpy as np
import pytest
import torch
from PIL import Image
from torchvision.transforms import ToTensor

from ultralytics import RTDETR, YOLO
from ultralytics.data.build import load_inference_source
from ultralytics.utils import LINUX, MACOS, ONLINE, ROOT, SETTINGS
from ultralytics.utils import DEFAULT_CFG, LINUX, ONLINE, ROOT, SETTINGS
from ultralytics.utils.downloads import download
from ultralytics.utils.torch_utils import TORCH_1_9

WEIGHTS_DIR = Path(SETTINGS['weights_dir'])
MODEL = WEIGHTS_DIR / 'path with spaces' / 'yolov8n.pt' # test spaces in path
CFG = 'yolov8n.yaml'
SOURCE = ROOT / 'assets/bus.jpg'
TMP = (ROOT / '../tests/tmp').resolve() # temp directory for test files
SOURCE_GREYSCALE = Path(f'{SOURCE.parent / SOURCE.stem}_greyscale.jpg')
SOURCE_RGBA = Path(f'{SOURCE.parent / SOURCE.stem}_4ch.png')

# Convert SOURCE to greyscale and 4-ch
im = Image.open(SOURCE)
im.convert('L').save(SOURCE_GREYSCALE) # greyscale
im.convert('RGBA').save(SOURCE_RGBA) # 4-ch PNG with alpha


def test_model_forward():
Expand Down Expand Up @@ -84,30 +80,39 @@ def test_predict_img():


def test_predict_grey_and_4ch():
# Convert SOURCE to greyscale and 4-ch
im = Image.open(SOURCE)
source_greyscale = Path(f'{SOURCE.parent / SOURCE.stem}_greyscale.jpg')
source_rgba = Path(f'{SOURCE.parent / SOURCE.stem}_4ch.png')
im.convert('L').save(source_greyscale) # greyscale
im.convert('RGBA').save(source_rgba) # 4-ch PNG with alpha

# Inference
model = YOLO(MODEL)
for f in SOURCE_RGBA, SOURCE_GREYSCALE:
for f in source_rgba, source_greyscale:
for source in Image.open(f), cv2.imread(str(f)), f:
model(source, save=True, verbose=True, imgsz=32)

# Cleanup
source_greyscale.unlink()
source_rgba.unlink()


@pytest.mark.skipif(not ONLINE, reason='environment is offline')
def test_track_stream():
# Test YouTube streaming inference (short 10 frame video) with non-default ByteTrack tracker
# imgsz=160 required for tracking for higher confidence and better matches
model = YOLO(MODEL)
model.track('https://youtu.be/G17sBkb38XQ', imgsz=96, tracker='bytetrack.yaml')
model.predict('https://youtu.be/G17sBkb38XQ', imgsz=96)
model.track('https://ultralytics.com/assets/decelera_portrait_min.mov', imgsz=160, tracker='bytetrack.yaml')
model.track('https://ultralytics.com/assets/decelera_portrait_min.mov', imgsz=160, tracker='botsort.yaml')


def test_val():
model = YOLO(MODEL)
model.val(data='coco8.yaml', imgsz=32)


def test_amp():
if torch.cuda.is_available():
from ultralytics.utils.checks import check_amp
model = YOLO(MODEL).model.cuda()
assert check_amp(model)


def test_train_scratch():
model = YOLO(CFG)
model.train(data='coco8.yaml', epochs=1, imgsz=32, cache='disk', batch=-1) # test disk caching with AutoBatch
Expand All @@ -133,10 +138,9 @@ def test_export_onnx():


def test_export_openvino():
if not MACOS:
model = YOLO(MODEL)
f = model.export(format='openvino')
YOLO(f)(SOURCE) # exported model inference
model = YOLO(MODEL)
f = model.export(format='openvino')
YOLO(f)(SOURCE) # exported model inference


def test_export_coreml(): # sourcery skip: move-assign
Expand Down Expand Up @@ -173,7 +177,7 @@ def test_all_model_yamls():
for m in (ROOT / 'cfg' / 'models').rglob('*.yaml'):
if 'rtdetr' in m.name:
if TORCH_1_9: # torch<=1.8 issue - TypeError: __init__() got an unexpected keyword argument 'batch_first'
RTDETR(m.name)
RTDETR(m.name)(SOURCE, imgsz=640)
else:
YOLO(m.name)

Expand Down Expand Up @@ -225,17 +229,14 @@ def test_results():
print(getattr(r, k))


@pytest.mark.skipif(not ONLINE, reason='environment is offline')
def test_data_utils():
# Test functions in ultralytics/data/utils.py
from ultralytics.data.utils import HUBDatasetStats, autosplit, zip_directory
from ultralytics.utils.downloads import download

# from ultralytics.utils.files import WorkingDirectory
# with WorkingDirectory(ROOT.parent / 'tests'):

shutil.rmtree(TMP, ignore_errors=True)
TMP.mkdir(parents=True)

download('https://github.com/ultralytics/hub/raw/master/example_datasets/coco8.zip', unzip=False)
shutil.move('coco8.zip', TMP)
stats = HUBDatasetStats(TMP / 'coco8.zip', task='detect')
Expand All @@ -244,4 +245,25 @@ def test_data_utils():

autosplit(TMP / 'coco8')
zip_directory(TMP / 'coco8/images/val') # zip
shutil.rmtree(TMP)


@pytest.mark.skipif(not ONLINE, reason='environment is offline')
def test_data_converter():
# Test dataset converters
from ultralytics.data.converter import convert_coco

file = 'instances_val2017.json'
download(f'https://github.com/ultralytics/yolov5/releases/download/v1.0/{file}')
shutil.move(file, TMP)
convert_coco(labels_dir=TMP, use_segments=True, use_keypoints=False, cls91to80=True)


def test_events():
# Test event sending
from ultralytics.hub.utils import Events

events = Events()
events.enabled = True
cfg = copy(DEFAULT_CFG) # does not require deepcopy
cfg.mode = 'test'
events(cfg)
2 changes: 1 addition & 1 deletion ultralytics/cfg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,5 +442,5 @@ def copy_default_cfg():


if __name__ == '__main__':
# Example Usage: entrypoint(debug='yolo predict model=yolov8n.pt')
# Example: entrypoint(debug='yolo predict model=yolov8n.pt')
entrypoint(debug='')
9 changes: 5 additions & 4 deletions ultralytics/data/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ def convert_coco(labels_dir='../coco/annotations/', use_segments=False, use_keyp
use_keypoints (bool, optional): Whether to include keypoint annotations in the output.
cls91to80 (bool, optional): Whether to map 91 COCO class IDs to the corresponding 80 COCO class IDs.
Raises:
FileNotFoundError: If the labels_dir path does not exist.
Example:
```python
from ultralytics.data.converter import convert_coco
Example Usage:
convert_coco(labels_dir='../coco/annotations/', use_segments=True, use_keypoints=True, cls91to80=True)
convert_coco('../datasets/coco/annotations/', use_segments=True, use_keypoints=False, cls91to80=True)
```
Output:
Generates output files in the specified output directory.
Expand Down
9 changes: 4 additions & 5 deletions ultralytics/data/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,19 +79,18 @@ def __init__(self, sources='file.streams', imgsz=640, vid_stride=1):
def update(self, i, cap, stream):
"""Read stream `i` frames in daemon thread."""
n, f = 0, self.frames[i] # frame number, frame array
while self.running and cap.isOpened() and n < f:
while self.running and cap.isOpened() and n < (f - 1):
# Only read a new frame if the buffer is empty
if not self.imgs[i]:
n += 1
cap.grab() # .read() = .grab() followed by .retrieve()
if n % self.vid_stride == 0:
success, im = cap.retrieve()
if success:
self.imgs[i].append(im) # add image to buffer
else:
if not success:
im = np.zeros(self.shape[i], dtype=np.uint8)
LOGGER.warning('WARNING ⚠️ Video stream unresponsive, please check your IP camera connection.')
self.imgs[i].append(np.zeros(self.shape[i]))
cap.open(stream) # re-open stream if signal was lost
self.imgs[i].append(im) # add image to buffer
else:
time.sleep(0.01) # wait until the buffer is empty

Expand Down
1 change: 1 addition & 0 deletions ultralytics/engine/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,7 @@ def export_ncnn(self, prefix=colorstr('ncnn:')):
yaml_save(f / 'metadata.yaml', self.metadata) # add metadata.yaml
return str(f), None

@try_export
def export_coreml(self, prefix=colorstr('CoreML:')):
"""YOLOv8 CoreML export."""
mlmodel = self.args.format.lower() == 'mlmodel' # legacy *.mlmodel export format requested
Expand Down
9 changes: 0 additions & 9 deletions ultralytics/nn/modules/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,15 +175,6 @@ def get_equivalent_kernel_bias(self):
kernelid, biasid = self._fuse_bn_tensor(self.bn)
return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

def _avg_to_3x3_tensor(self, avgp):
channels = self.c1
groups = self.g
kernel_size = avgp.kernel_size
input_dim = channels // groups
k = torch.zeros((channels, input_dim, kernel_size, kernel_size))
k[np.arange(channels), np.tile(np.arange(input_dim), groups), :, :] = 1.0 / kernel_size ** 2
return k

def _pad_1x1_to_3x3_tensor(self, kernel1x1):
if kernel1x1 is None:
return 0
Expand Down
12 changes: 9 additions & 3 deletions ultralytics/utils/checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,11 +437,17 @@ def check_amp(model):
Args:
model (nn.Module): A YOLOv8 model instance.
Example:
```python
from ultralytics import YOLO
from ultralytics.utils.checks import check_amp
model = YOLO('yolov8n.pt').model.cuda()
check_amp(model)
```
Returns:
(bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
Raises:
AssertionError: If the AMP checks fail, indicating anomalies with the AMP functionality on the system.
"""
device = next(model.parameters()).device # get model device
if device.type in ('cpu', 'mps'):
Expand Down

0 comments on commit fb1ae9b

Please sign in to comment.