Skip to content

Commit

Permalink
Version 5.7.1
Browse files Browse the repository at this point in the history
添加了 `conf_thres` 和 `iou_thres` 的设置方法,在初始化识别方法时可以添加。
  • Loading branch information
Ender-William committed Mar 29, 2023
1 parent 1429233 commit 0c0a4ea
Show file tree
Hide file tree
Showing 12 changed files with 145 additions and 13 deletions.
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions .idea/YoloDetectAPI.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions .idea/inspectionProfiles/Project_Default.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 4 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 14 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@



# Introduction

YoloV5 作为 YoloV4 之后的改进型,在算法上做出了优化,检测的性能得到了一定的提升。其特点之一就是权重文件非常的小,可以在一些配置更低的移动设备上运行,且提高速度的同时准确度更高。本次使用的是最新推出的 YoloV5 Version7 版本。
Expand Down Expand Up @@ -55,7 +58,7 @@ import torch

if __name__ == '__main__':
cap = cv2.VideoCapture(0)
a = yolo_detectAPI.DetectAPI(weights='last.pt') # 你要使用的模型的路径
a = yolo_detectAPI.DetectAPI(weights='last.pt', conf_thres=0.5, iou_thres=0.5) # 你要使用的模型的路径
with torch.no_grad():
while True:
rec, img = cap.read()
Expand All @@ -65,7 +68,7 @@ if __name__ == '__main__':
for cls, (x1, y1, x2, y2), conf in result[0][1]:
print(names[cls], x1, y1, x2, y2, conf) # 识别物体种类、左上角x坐标、左上角y轴坐标、右下角x轴坐标、右下角y轴坐标,置信度

cv2.imshow("vedio", img)
cv2.imshow("video", img)

if cv2.waitKey(1) == ord('q'):
break
Expand Down Expand Up @@ -97,3 +100,12 @@ if __name__ == '__main__':
https://github.com/ultralytics/yolov5/releases/tag/v7.0
https://blog.csdn.net/weixin_51331359/article/details/126012620
https://blog.csdn.net/CharmsLUO/article/details/123422822

# Update Version 5.7.1 2023-03-29
添加了 `conf_thres``iou_thres` 的设置方法,在初始化识别方法时可以添加。
```python
yolo_detectAPI.DetectAPI(weights='last.pt', conf_thres=0.5, iou_thres=0.5)
```
`iou_thres` 过大容易出现一个目标多个检测框;

`iou_thres` 过小容易出现检测结果少的问题。
11 changes: 7 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,19 @@

setup(
name='yolo_detectAPI',
version='5.7',
version='5.7.1',
description='Detect API',
long_description='This is a API for yolov5 version7 detect.py',
long_description='This is a API for yolov5 version7 detect.py, new-version 5.7.1 allow user set <conf_thres> and '
'<iou_thres>',
license='GPL Licence',
author='Da Kuang',
author_email='[email protected]',
py_modeles = '__init__.py',
packages=find_packages(),
pakages=['yolo_detectAPI'],
include_package_data=True,
python_requires='>=3.8',
readme = 'README.md',
python_requires='>=3.7',
url = 'http://blogs.kd-mercury.xyz/',
install_requires=['matplotlib>=3.2.2', 'numpy>=1.18.5', 'opencv-python>=4.1.1',
'Pillow>=7.1.2', 'PyYAML>=5.3.1', 'requests>=2.23.0', 'scipy>=1.4.1',
Expand All @@ -22,10 +24,11 @@
'ipython>=8.3.0', 'psutil>=5.9.4'],
data_files=['export.py'],
classifiers=[
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"License :: OSI Approved :: GNU General Public License (GPL)",
"Development Status :: 4 - Beta"
"Development Status :: 3 - Alpha"
],
scripts=[],
)
Binary file added whl/yolo_detectAPI-5.7.1-py3-none-any.whl
Binary file not shown.
28 changes: 21 additions & 7 deletions yolo_detectAPI/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#-*- coding:utf-8 -*-
# -*- coding:utf-8 -*-
import os
import random
import sys
Expand Down Expand Up @@ -53,8 +53,20 @@ def __init__(self, weights='weights/last.pt',


class DetectAPI:
def __init__(self, weights, imgsz=640):
def __init__(self, weights, imgsz=640, conf_thres=None, iou_thres=None):
"""
Init Detect API
Args:
weights: model
imgsz: default 640
conf_thres: 用于物体的识别率,object置信度阈值 默认0.25,大于此准确率才会显示识别结果
iou_thres: 用于去重,做nms的iou阈值 默认0.45,数值越小去重程度越高
"""
self.opt = YoloOpt(weights=weights, imgsz=imgsz)
if conf_thres is not None:
self.opt.conf_thres = conf_thres
if iou_thres is not None:
self.opt.iou_thres = iou_thres
weights = self.opt.weights
imgsz = self.opt.imgsz

Expand All @@ -73,7 +85,7 @@ def __init__(self, weights, imgsz=640):

# 不使用半精度
if self.half:
self.model.half() # switch to FP16
self.model.half() # switch to FP16

# read names and colors
self.names = self.model.module.names if hasattr(self.model, 'module') else self.model.names
Expand All @@ -88,7 +100,7 @@ def detect(self, source):
# 直接从 source 加载数据
dataset = LoadImages(source)
# 源程序通过路径加载数据,现在 source 就是加载好的数据,因此 LoadImages 就要重写
bs = 1 # set batch size
bs = 1 # set batch size

# 保存的路径
vid_path, vid_writer = [None] * bs, [None] * bs
Expand All @@ -113,14 +125,15 @@ def detect(self, source):

# NMS
with dt[2]:
pred = non_max_suppression(pred, self.opt.conf_thres, self.opt.iou_thres, self.opt.classes, self.opt.agnostic_nms, max_det=2)
pred = non_max_suppression(pred, self.opt.conf_thres, self.opt.iou_thres, self.opt.classes,
self.opt.agnostic_nms, max_det=2)

# Process predictions
# 处理每一张图片
det = pred[0] # API 一次只处理一张图片,因此不需要 for 循环
im0 = im0s.copy() # copy 一个原图片的副本图片
result_txt = [] # 储存检测结果,每新检测出一个物品,长度就加一。
# 每一个元素是列表形式,储存着 类别,坐标,置信度
# 每一个元素是列表形式,储存着 类别,坐标,置信度
# 设置图片上绘制框的粗细,类别名称
annotator = Annotator(im0, line_width=3, example=str(self.names))
if len(det):
Expand All @@ -137,6 +150,7 @@ def detect(self, source):
result.append((im0, result_txt)) # 对于每张图片,返回画完框的图片,以及该图片的标签列表。
return result, self.names


if __name__ == '__main__':
cap = cv2.VideoCapture(0)
a = DetectAPI(weights='weights/last.pt')
Expand All @@ -155,4 +169,4 @@ def detect(self, source):
cv2.imshow("vedio", img)

if cv2.waitKey(1) == ord('q'):
break
break
34 changes: 34 additions & 0 deletions yolo_detectAPI/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
from setuptools import setup, find_packages

setup(
name='yolo_detectAPI',
version='5.7.1',
description='Detect API',
long_description='This is a API for yolov5 version7 detect.py, new-version 5.7.1 allow user set <conf_thres> and '
'<iou_thres>',
license='GPL Licence',
author='Da Kuang',
author_email='[email protected]',
py_modeles = '__init__.py',
packages=find_packages(),
pakages=['yolo_detectAPI'],
include_package_data=True,
readme = 'README.md',
python_requires='>=3.7',
url = 'http://blogs.kd-mercury.xyz/',
install_requires=['matplotlib>=3.2.2', 'numpy>=1.18.5', 'opencv-python>=4.1.1',
'Pillow>=7.1.2', 'PyYAML>=5.3.1', 'requests>=2.23.0', 'scipy>=1.4.1',
'thop>=0.1.1', 'torch>=1.7.0', 'torchvision>=0.8.1', 'tqdm>=4.64.0',
'tensorboard>=2.4.1', 'pandas>=1.1.4', 'seaborn>=0.11.0',
'ipython>=8.3.0', 'psutil>=5.9.4'],
data_files=['export.py'],
classifiers=[
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"License :: OSI Approved :: GNU General Public License (GPL)",
"Development Status :: 3 - Alpha"
],
scripts=[],
)

0 comments on commit 0c0a4ea

Please sign in to comment.