解决方案:结合 YOLO 和 TensorFlow 做目标检测和图像分类

开发 人工智能
在本文中,我将向你解释什么是目标检测和图像分类,如何训练模型,最后,我将同时使用目标检测和图像分类模型来分类和检测狗的品种。

目标检测和图像分类是两个不同的任务,每个任务都有其特定的用途。在本文中,我将向你解释什么是目标检测和图像分类,如何训练模型,最后,我将同时使用目标检测和图像分类模型来分类和检测狗的品种。

目标检测 + 图像分类

目标检测

目标检测是一项基本的计算机视觉任务,用于检测和定位物体。简而言之,目标检测模型接受图像作为输入,并输出坐标和标签。

目标检测示例

如果你需要知道物体的坐标,那么你需要使用目标检测模型,因为图像分类模型不会输出坐标,它们只返回标签。

图像分类

图像分类只输出标签。你将图像作为输入提供给模型,它将标签作为输出返回。它更适合分类相同类型的物体。例如,如果你想分类海洋动物,你需要训练一个图像分类模型。

海洋动物分类

为什么不只使用目标检测模型?

你可能已经注意到,目标检测模型同时提供坐标和标签作为输出,你可能会对自己说,为什么不直接使用目标检测模型来处理所有事情呢?毕竟,它们理论上同时给出坐标和标签,所以不需要分类模型。你可能一开始会这样想,但有一些不同的因素你可能没有意识到:

目标检测模型非常适合识别和定位场景中的各种物体。但当涉及到区分几乎相同的物体时,图像分类模型通常表现更好(一般来说,并非总是如此)。

你并不总能找到合适的数据集,创建数据集可能既耗时又无聊。如果你决定创建自己的数据集,创建图像分类数据集比创建目标检测数据集要容易得多。

目标检测 + 图像分类

工作流程

首先,我们将使用yolov8目标检测模型检测物体,然后从检测到的物体中,我们将尝试使用图像分类模型对它们进行分类。请注意,图像分类模型将仅对检测到的物体进行操作,而不是对整个图像进行操作。

1. 用于检测狗的目标检测模型

首先,可以参考文章《定制YOLOv8模型 - 检测棋盘棋子》训练一个YOLO目标检测模型。现在,我将使用一个预训练的YOLOv8模型,因为它包含狗类,我将直接使用这个预训练模型。我将使用YOLO模型进行检测,如果它检测到狗,我将使用图像分类模型继续处理。一般来说,最好为特定任务使用特定数据集训练模型。

2. 用于狗品种的图像分类模型

我将使用TensorFlow来训练图像分类模型。训练模型可能需要时间,具体取决于数据集和参数。有关狗品种分类模型的详细代码可以查看链接:https://www.kaggle.com/code/merfarukgnaydn/dog-species-classification

使用TensorFlow Keras训练图像分类模型

3. 结合目标检测 + 图像分类模型

正如我之前向你解释的那样,这个过程非常简单。首先,目标检测模型执行,然后是图像分类模型。下面是代码和相关注释:


# 导入库
import cv2
import numpy as np
from ultralytics import YOLO
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt

# 加载YOLO检测模型
yolo_model = YOLO("yolov8s.pt")  # 替换为你的YOLO模型路径

# 加载分类模型,你可以运行notebook并保存模型并使用它(查看步骤2)
classification_model = load_model('dog_classification_model.h5')

# 分类标签
species_list = ['afghan_hound', 'african_hunting_dog', 'airedale', 'basenji', 'basset', 'beagle', 
                'bedlington_terrier', 'bernese_mountain_dog', 'black-and-tan_coonhound', 
                'blenheim_spaniel', 'bloodhound', 'bluetick', 'border_collie', 'border_terrier', 
                'borzoi', 'boston_bull', 'bouvier_des_flandres', 'brabancon_griffon', 'bull_mastiff', 
                'cairn', 'cardigan', 'chesapeake_bay_retriever', 'chow', 'clumber', 'cocker_spaniel', 
                'collie', 'curly-coated_retriever', 'dhole', 'dingo', 'doberman', 'english_foxhound', 
                'english_setter', 'entlebucher', 'flat-coated_retriever', 'german_shepherd', 
                'german_short-haired_pointer', 'golden_retriever', 'gordon_setter', 'great_dane', 
                'great_pyrenees', 'groenendael', 'ibizan_hound', 'irish_setter', 'irish_terrier', 
                'irish_water_spaniel', 'irish_wolfhound', 'japanese_spaniel', 'keeshond', 
                'kerry_blue_terrier', 'komondor', 'kuvasz', 'labrador_retriever', 'leonberg', 
                'lhasa', 'malamute', 'malinois', 'maltese_dog', 'mexican_hairless', 'miniature_pinscher', 
                'miniature_schnauzer', 'newfoundland', 'norfolk_terrier', 'norwegian_elkhound', 
                'norwich_terrier', 'old_english_sheepdog', 'otterhound', 'papillon', 'pekinese', 
                'pembroke', 'pomeranian', 'pug', 'redbone', 'rhodesian_ridgeback', 'rottweiler', 
                'saint_bernard', 'saluki', 'samoyed', 'schipperke', 'scotch_terrier', 
                'scottish_deerhound', 'sealyham_terrier', 'shetland_sheepdog', 'standard_poodle', 
                'standard_schnauzer', 'sussex_spaniel', 'tibetan_mastiff', 'tibetan_terrier', 
                'toy_terrier', 'vizsla', 'weimaraner', 'whippet', 'wire-haired_fox_terrier', 
                'yorkshire_terrier']
    
# 执行推理
def classify_region(image, model, target_size=(180, 180)):  # 尺寸必须与分类模型的输入匹配
    input_image = preprocess_image(image, target_size)
    predictions = model.predict(input_image)
    predicted_index = np.argmax(predictions[0])
    predicted_label = species_list[predicted_index]
    return predicted_label

# 加载图像
image_path = r"test-images/dog12.jpg"  # 图像路径
image = cv2.imread(image_path)

# YOLO推理 --> 目标检测模型
results = yolo_model(image)
detections = results[0].boxes  # 获取检测结果

# 检查YOLO的标签是否为"dog"并处理边界框
for detection in detections:
    x1, y1, x2, y2 = map(int, detection.xyxy[0].tolist())  # 获取边界框坐标
    conf = float(detection.conf[0])  # 获取置信度
    cls_label = yolo_model.names[int(detection.cls[0])]  # 直接从YOLO获取标签名称

    # 检查标签是否为"dog"
    if cls_label == "dog":

        """
        提取用于分类的感兴趣区域(ROI)。
        记住,图像分类模型只会对检测到的对象进行处理,
        而不是对整个图像进行处理。
        """
        roi = image[y1:y2, x1:x2]

        # 如果ROI足够大,则对其进行分类
        if roi.shape[0] > 0 and roi.shape[1] > 0:
            # 图像分类模型
            label = classify_region(roi, classification_model)

            bbox_height = y2 - y1
            font_scale = bbox_height / 200.0  # 比例因子,可根据需要调整
            font_thickness = max(1, int(bbox_height / 100))  # 确保厚度至少为1

            # 绘制边界框和标签
            cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 4)
            cv2.putText(image, label, (x1+100, y1-20), cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 255), font_thickness)
            print(f"检测到的狗品种: {label}")

cv2.imwrite("dog2-result.jpg", image)
# 显示结果图像
plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
plt.axis("off")
plt.show()
  • 1.
  • 2.
  • 3.
  • 4.
  • 5.
  • 6.
  • 7.
  • 8.
  • 9.
  • 10.
  • 11.
  • 12.
  • 13.
  • 14.
  • 15.
  • 16.
  • 17.
  • 18.
  • 19.
  • 20.
  • 21.
  • 22.
  • 23.
  • 24.
  • 25.
  • 26.
  • 27.
  • 28.
  • 29.
  • 30.
  • 31.
  • 32.
  • 33.
  • 34.
  • 35.
  • 36.
  • 37.
  • 38.
  • 39.
  • 40.
  • 41.
  • 42.
  • 43.
  • 44.
  • 45.
  • 46.
  • 47.
  • 48.
  • 49.
  • 50.
  • 51.
  • 52.
  • 53.
  • 54.
  • 55.
  • 56.
  • 57.
  • 58.
  • 59.
  • 60.
  • 61.
  • 62.
  • 63.
  • 64.
  • 65.
  • 66.
  • 67.
  • 68.
  • 69.
  • 70.
  • 71.
  • 72.
  • 73.
  • 74.
  • 75.
  • 76.
  • 77.
  • 78.
  • 79.
  • 80.
  • 81.
  • 82.
  • 83.
  • 84.
  • 85.
  • 86.
  • 87.
责任编辑:赵宁宁 来源: 小白玩转Python
相关推荐

2022-06-29 09:00:00

前端图像分类模型SQL

2022-12-13 10:13:09

智能驾驶

2018-04-09 10:20:32

深度学习

2020-10-27 09:37:43

PyTorchTensorFlow机器学习

2024-10-09 17:02:34

2009-12-23 21:06:47

统一通信多媒体联络中心平台华为

2011-02-23 17:13:19

FileZilla

2023-09-13 10:52:56

2024-03-07 12:31:07

2009-07-17 09:17:41

IT运维SiteView游龙科技

2013-03-01 14:38:01

2025-01-02 10:30:00

无人机目标检测AI

2025-02-18 08:00:00

C++YOLO目标检测

2025-01-22 11:10:34

2024-11-29 16:10:31

2017-10-14 21:24:33

TensorFlow目标检测模型

2019-03-13 08:43:32

边缘计算物联网IoT

2024-09-02 09:48:08

API编排GraphQL

2022-06-28 07:28:43

远程医疗IT监测

2023-01-06 19:02:23

应用技术
点赞
收藏

51CTO技术栈公众号