使用回调函数训练YOLO模型

开发
在本文章中,我将向你展示一些示例,演示在训练YOLO模型时如何使用回调函数。在本例中,我将使用YOLOv8,但请注意,这可以扩展到其他一些YOLO模型,比如YOLO-NAS。

大多数人可能熟悉如何训练计算机视觉模型,比如流行的YOLO模型,甚至知道如何使用这些模型进行预测。但你知道我们可以通过回调函数为这些模型增加一些灵活性,以便在模型训练和模型推断中使用吗?大多数最先进的(SOTA)YOLO模型,如YOLOv8和YOLO-NAS,都实现了回调函数,我们可以调整这些函数以有效地利用我们的计算机视觉模型的训练和推断。

考虑以下情景。假设你是一名计算机视觉工程师,与团队中的许多工程师一起工作。你正在使用自定义数据集训练自定义的计算机视觉模型(也许是YOLO),以实现一些业务逻辑。你负责实现训练和推断逻辑。除此之外,你还需要报告模型的训练进度、训练模型的准确性等。作为一名工程师,你决定在很多个epoch上训练你的模型,这可能需要几天的时间,具体取决于一些因素,比如数据集的数量、服务器资源等。你需要密切关注模型的训练进度,因为由于诸如服务器资源问题等原因,模型可能在一段时间后停止训练,导致训练崩溃。你可能也希望在模型训练完成后收到自动警报,比如在训练结束后收到带有验证指标的电子邮件,或者在模型训练完成后自动向团队负责人发送报告。这些以及许多其他事情都是你作为计算机视觉工程师可能想要做的事情。

要实现以上任何一种情况,我们需要一种回调函数。这就是在训练计算机视觉模型时回调函数的作用。好消息是,大多数SOTA YOLO模型默认实现了这些回调函数。例如,默认情况下,YOLOv8和YOLO-NAS实现了这些回调函数,你可以在训练或进行模型预测时有效地利用它们。在本文章中,我将向你展示一些示例,演示在训练YOLO模型时如何使用回调函数。在本例中,我将使用YOLOv8,但请注意,这可以扩展到其他一些YOLO模型,比如YOLO-NAS。

让我们继续演示如何在YOLOv8上实现回调函数。我们将编写代码并在自定义数据集上训练我们的模型。我们将实现回调函数。其中一个功能是在模型训练结束后向我们的团队工程师发送电子邮件。我们发送的电子邮件将包含受过训练模型的报告,如指标、训练模型所花费的时间等。

项目实施步骤

第1步:创建一个文件夹并给它命名(在我的案例中,我将我的文件夹命名为“yolo_with_callbacks”)。

在你创建的文件夹中,创建一个新的文本文件(requirements.txt)并添加以下内容:

opencv-python==4.8.1.78
Pillow==10.0.1
tqdm==4.66.1
ultralytics==8.1.2
python-dotenv==1.0.1

然后,在你的项目文件夹中创建一个Python虚拟环境,并安装requirements.txt文件中列出的依赖项。

python3 -m venv env

接下来,通过运行以下命令激活新创建的虚拟环境:

source env/bin/activate  # if you are using Ubuntu
source env/Scripts/activate  # if you are using Windows

然后,通过运行以下命令安装依赖项:

pip install -r requirements.txt

第2步:下载一个用于自定义模型训练的示例数据集。

你可以使用任何你选择的数据集,只要注释是以YOLO格式提供的即可。在我的案例中,为了本教程的目的,我将使用来自Roboflow的POTHOLE数据集,你可以从这个链接下载:POTHOLE数据集。下载数据集后,你将得到三个文件夹(train、val和test)。现在,在你的项目目录中创建一个数据集文件夹,并将你下载的数据集(train、val和test)复制到这个文件夹中。你的数据集文件夹应该如下所示:

Datasets
    └── train
        ├── images
        └── labels
    └── val
        ├── images
        └── labels
    └── test
        ├── images
        └── labels

接下来,在项目根目录中创建一个数据集配置文件(我们称之为data.yaml)并在YAML文件中添加以下内容:

train: ./dataset/train/images
val: ./dataset/val/images
test: ./dataset/test/images

nc: 1
names: ['pothole']

第3步:创建模型训练脚本。

接下来,我们需要编写代码来使用我们的自定义数据集训练模型。之后,我们将继续实现模型的回调函数,这是本教程的唯一目的。现在,在你的项目根目录中创建一个新文件(命名为training.py)。在这个training.py文件中,我们将实现模型训练和回调函数。首先,让我们编写一个用于训练YOLOV8模型的函数:

def train_yolov8_model(config_path, num_epochs, training_result_dir):
        model = YOLO("yolov8x.pt")
        model.add_callback("on_train_start", on_train_start)
        model.add_callback("on_train_epoch_end", on_train_epoch_end)
        model.add_callback("on_train_end", on_train_end)
        model.start_time = datetime.now()
        start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # Train the model
        model.train(
            data=config_path,
            name="Yolo_Model_Training",
            project=training_result_dir,
            task="detect",
            epochs=num_epochs,
            patience=20,
            batch=16,
            cache=True,
            imgsz=640,
            iou=0.5,
            augment=True,
            degrees=25.0,
            fliplr=0.0,
            lr0=0.0001,
            optimizer="Adam",
            device=device,
        )

注意:函数参数中的config_path是我们之前创建的数据集yaml配置文件。我们稍后将定义的回调函数,就像model.add_callback这样的调用,稍等一下。

接下来,让我们实现回调函数。在这种情况下,我们将要实现的回调函数包括:on_train_start、on_train_epoch_end和on_train_end。on_train_start回调是在模型开始训练时立即触发的回调函数。on_train_epoch_end是在每个epoch结束后立即触发的回调函数。on_train_end是在模型完成训练后触发的回调函数。

实现回调函数

   def on_train_start(trainer):
        start_time = datetime.now()

    def on_train_epoch_end(trainer):
        curr_epoch = trainer.epoch + 1
        text = f"Epoch Number: {curr_epoch}/{trainer.epochs} finished"
        print(text)
        print("-" * 50)

对于on_train_start回调,我们需要追踪模型开始训练的确切时间。你实际上可以在这里实现更复杂的逻辑。对于on_train_epoch_end,我们只是获取了当前epoch并打印出来。这只是一个简单的演示。我们可以在这里实现更复杂的逻辑。例如,如果我们有一个用户正在从中训练模型的前端应用程序,我们可以在每个epoch结束后更新GUI的训练进度条。我们可以在这个函数中实现这个功能。

现在,让我们继续实现本教程的主要逻辑。我们将继续实现on_train_end回调函数。如前所述,此函数仅在模型训练成功完成后触发。在我们的情况下,我们想要发送一个包含模型训练报告的电子邮件给我们的团队工程师。为了实现这一点,首先,让我们编写一个发送电子邮件的函数。我们将使用Gmail发送电子邮件。

以下是发送电子邮件的函数:

 def send_email(
        body,
        from_email=FROM_EMAIL,
        to_emails=RECIPENT_EMAIL,
        subject=subject,
        api=EMAIL_API_KEY,
    ):
        msg = MIMEMultipart()
        msg["From"] = from_email
        msg["To"] = to_emails
        msg["Subject"] = subject

        msg.attach(MIMEText(body, "html"))

        try:
            smtp_server = smtplib.SMTP("smtp.gmail.com", 587)
            smtp_server.starttls()
            smtp_server.login(from_email, api)
            smtp_server.sendmail(from_email, to_emails, msg.as_string())
            smtp_server.quit()
            print("Email sent.")
        except Exception as e:
            print("Email not sent", e)

但请注意,我们需要将诸如EMAIL API KEY、SENDER EMAIL等秘密凭证存储到一个环境文件中。基于此,请在你的项目根目录中创建一个新文件(命名为.env)。在.env文件中,添加以下示例内容。

EMAIL_API_KEY=your Gmail app password goes here
EMAIL_ACCOUNT=your Gmail account which you created app password goes here
RECIPENT_EMAIL=the email address you will be sending the report email goes here.

现在,让我们继续实现回调函数(on_train_end),该函数将在模型训练成功完成后触发发送电子邮件功能。


  def on_train_end(trainer):
        trainer_epoch = trainer.epoch
        trainer_metrics = trainer.metrics
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        end_time = datetime.now()
        time_taken = end_time - start_time
        hours, remainder = divmod(time_taken.total_seconds(), 3600)
        minutes, seconds = divmod(remainder, 60)

        time_taken_str = ""
        if int(hours) > 0:
            time_taken_str += f"{int(hours)} hr "
        if int(minutes) > 0:
            time_taken_str += f"{int(minutes)} mins "
        if int(seconds) > 0:
            time_taken_str += f"{int(seconds)} secs"

        time_taken_str = time_taken_str.strip()

        body = f"""
        <html>
            <head>
                <style>
                    table, th, td {{
                        border: 1px solid black;
                        border-collapse: collapse;
                        padding: 5px;
                    }}
</style>
            </head>
            <body>
                <h1>Training Report</h1>
                <p>Date and Time: {current_time}</p>
                <p>Total Epoch Trained: {trainer_epoch + 1} </p>
                <p>Time Taken to Train Model: {time_taken_str} </p>
                <table>
                    <tr>
                        <th>Metric</th>
                        <th>Value</th>
                    </tr>
                    {''.join([f'<tr><td>{k}</td><td>{v:.2f}</td></tr>' for k, v in trainer_metrics.items()])}
                </table>
            </body>
        </html>
        """

        send_email(body)

以上回调函数将在模型训练完成后向指定收件人发送报告邮件。现在,我们已经编写了所有必要的函数,将它们全部封装在一个名为ModelTraining的类中是一个好主意。所以,我们training.py文件中的完整代码现在应该如下所示:

import os
from datetime import datetime
from dotenv import find_dotenv, load_dotenv
import torch

from ultralytics import YOLO

import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart

load_dotenv(find_dotenv())

EMAIL_API_KEY = os.getenv("EMAIL_API_KEY")
FROM_EMAIL = os.getenv("EMAIL_ACCOUNT")
RECIPIENT_EMAIL = os.getenv("RECIPIENT_EMAIL")
subject = "Model Training Completed"


class ModelTraining:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.start_time = None
        self.end_time = None

    def send_email(
        self,
        body,
        from_email=FROM_EMAIL,
        to_emails=RECIPIENT_EMAIL,
        subject=subject,
        api=EMAIL_API_KEY,
    ):
        msg = MIMEMultipart()
        msg["From"] = from_email
        msg["To"] = to_emails
        msg["Subject"] = subject

        msg.attach(MIMEText(body, "html"))

        try:
            smtp_server = smtplib.SMTP("smtp.gmail.com", 587)
            smtp_server.starttls()
            smtp_server.login(from_email, api)
            smtp_server.sendmail(from_email, to_emails, msg.as_string())
            smtp_server.quit()
            print("Email sent.")
        except Exception as e:
            print("Email not sent", e)

    def on_train_start(self, trainer):
        self.start_time = datetime.now()

    def on_train_epoch_end(self, trainer):
        curr_epoch = trainer.epoch + 1
        text = f"Epoch Number: {curr_epoch}/{trainer.epochs} finished"
        print(text)
        print("-" * 50)

    def on_train_end(self, trainer):
        trainer_epoch = trainer.epoch
        trainer_metrics = trainer.metrics
        current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.end_time = datetime.now()
        time_taken = self.end_time - self.start_time
        hours, remainder = divmod(time_taken.total_seconds(), 3600)
        minutes, seconds = divmod(remainder, 60)

        time_taken_str = ""
        if int(hours) > 0:
            time_taken_str += f"{int(hours)} hr "
        if int(minutes) > 0:
            time_taken_str += f"{int(minutes)} mins "
        if int(seconds) > 0:
            time_taken_str += f"{int(seconds)} secs"

        time_taken_str = time_taken_str.strip()

        body = f"""
        <html>
            <head>
                <style>
                    table, th, td {{
                        border: 1px solid black;
                        border-collapse: collapse;
                        padding: 5px;
                    }}
                </style>
            </head>
            <body>
                <h1>Training Report</h1>
                <p>Date and Time: {current_time}</p>
                <p>Total Epochs Trained: {trainer_epoch + 1} </p>
                <p>Time Taken to Train Model: {time_taken_str} </p>
                <table>
                    <tr>
                        <th>Metric</th>
                        <th>Value</th>
                    </tr>
                    {''.join([f'<tr><td>{k}</td><td>{v:.2f}</td></tr>' for k, v in trainer_metrics.items()])}
                </table>
            </body>
        </html>
        """

        self.send_email(body)

    def train_yolov8_model(self, config_path, num_epochs, training_result_dir):
        model = YOLO("yolov8x.pt")
        model.add_callback("on_train_start", self.on_train_start)
        model.add_callback("on_train_epoch_end", self.on_train_epoch_end)
        model.add_callback("on_train_end", self.on_train_end)
        model.start_time = datetime.now()
        start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # Train the model
        model.train(
            data=config_path,
            name="Yolo_Model_Training",
            project=training_result_dir,
            task="detect",
            epochs=num_epochs,
            patience=20,
            batch=16,
            cache=True,
            imgsz=640,
            iou=0.5,
            augment=True,
            degrees=25.0,
            fliplr=0.0,
            lr0=0.0001,
            optimizer="Adam",
            device=self.device,
        )
        model.end_time = datetime.now()


if __name__ == "__main__":
    model_training = ModelTraining()

    # Load the dataset configuration file
    current_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(current_dir, "data.yaml")

    num_epochs = 40  # Change it to any number of epochs you want.
    training_result_path = "./results"
    os.makedirs(training_result_path, exist_ok=True)
    model_training.train_yolov8_model(config_path, num_epochs, training_result_path)

完整的项目结构应该如下所示:

yolo_with_callback/
│
├── dataset/            # Directory containing dataset files
│
├── env/                # python virtual environment directory
│          
│── .env                # Environment variables file containing secret keys
├── results/            # Directory for storing training results
│
├── data.yaml           # Dataset configuration file
│
├── requirements.txt    # File listing required Python packages
│
└── training.py         # Main script for model training

现在,你已经完成了实现,可以继续运行training.py代码。训练完成后,训练结果报告将发送到指定的收件人邮箱。

责任编辑:赵宁宁 来源: 小白玩转Python
相关推荐

2024-09-05 15:42:34

PyTorch回调日志

2023-12-05 15:44:46

计算机视觉FastAPI

2011-05-20 17:19:25

回调函数

2023-01-11 07:28:49

TensorFlow分类模型

2024-05-23 12:57:59

2017-08-28 21:31:37

TensorFlow深度学习神经网络

2023-01-09 08:00:00

迁移学习机器学习数据集

2022-04-12 08:30:52

回调函数代码调试

2024-11-25 06:25:00

YOLO数据标注目标检测

2024-09-19 16:04:41

YOLO数据标注

2023-06-06 15:42:13

Optuna开源

2024-10-30 16:34:56

2024-01-29 00:24:07

图像模型预训练

2024-11-25 07:00:00

箭头函数JavaScriptReact

2024-10-10 14:56:39

2011-07-27 14:10:43

javascript

2024-09-12 17:19:43

YOLO目标检测深度学习

2020-08-10 15:05:02

机器学习人工智能计算机

2024-10-29 16:18:32

YOLOOpenCV

2009-12-07 14:29:08

PHP array_w
点赞
收藏

51CTO技术栈公众号