使用 ML.NET 实现图像分类:从入门到实践

开发 前端
本文详细介绍了如何使用ML.NET实现图像分类功能。通过使用迁移学习和预训练模型,我们可以快速构建高质量的图像分类应用。ML.NET提供了简单易用的API,让.NET开发者能够方便地将机器学习集成到应用程序中。

ML.NET是微软开发的开源机器学习框架,让.NET开发者能够直接在.NET应用程序中集成机器学习功能。本文将详细介绍如何使用ML.NET实现图像分类,包括环境搭建、数据准备、模型训练等完整流程。

环境准备

  1. Visual Studio 2022
  2. .NET 6.0或更高版本
  3. 需要安装的NuGet包:

Microsoft.ML

Microsoft.ML.Vision

Microsoft.ML.ImageAnalytics

SciSharp.TensorFlow.Redist (版本2.3.1)

图片图片

项目结构

ImageClassification/
├── Program.cs
├── assets/           # 存放训练图片
│   ├── CD/          # 有裂缝的图片
│   └── UD/          # 无裂缝的图片  
└── workspace/        # 存放模型文件

代码实现

一、创建数据模型类

// 原始图像数据类
public class ImageData
{
    public string ImagePath { get; set; }
    public string Label { get; set; }
}

ImageData 类是用来表示和存储图像的基本信息的数据结构,主要用于数据加载和预处理阶段。它包含两个关键属性:

1.ImagePath

  • 存储图像文件的完整路径
  • 用于后续加载和访问图像文件
  • 是一个字符串类型的属性

2.Label

  • 存储图像的分类标签/类别
  • 表示图像所属的类别(如示例中的"有裂缝"或"无裂缝")
  • 是一个字符串类型的属性

主要用途:

  • 数据加载:在从目录加载图像数据时,用于初步组织和存储图像信息
  • 数据组织:将文件系统中的图像与其对应的分类标签关联起来
// 模型输入类
public class ModelInput
{
    public byte[] Image { get; set; }
    public UInt32 LabelAsKey { get; set; }
    public string ImagePath { get; set; }
    public string Label { get; set; }
}

3.Image (byte[] 类型)

  • 存储图像的字节数组表示
  • 这是模型训练和预测所必需的输入格式
  • 模型需要这种类型的图像数据来进行训练

4.LabelAsKey (UInt32 类型)

  • 是 Label 的数值表示形式
  • 将分类标签转换为数值形式,因为机器学习模型要求输入采用数值格式

5.ImagePath (string 类型)

  • 存储图像的完整路径
  • 用于方便访问原始图像文件

6.Label (string 类型)

  • 图像所属的类别
  • 这是需要预测的目标值
  • 用于训练时的标签信息

重要说明:

  • 在实际训练和预测中,只有 Image 和 LabelAsKey 这两个属性被用于模型训练和预测
  • ImagePath 和 Label 属性主要是为了方便访问和追踪原始数据,不直接参与模型计算
  • 这个类是连接原始图像数据和模型训练需求的桥梁,将各种必要的信息整合在一起
// 模型输出类
public class ModelOutput
{
    public string ImagePath { get; set; }
    public string Label { get; set; }
    public string PredictedLabel { get; set; }
}

7.ImagePath (string 类型)

  • 存储图像的完整文件路径
  • 用于追踪和引用原始图像文件

8.Label (string 类型)

  • 存储图像的原始/真实类别标签
  • 这是图像实际应该属于的类别

9.PredictedLabel (string 类型)

  • 存储模型预测的类别标签
  • 这是模型通过分析图像后预测出的类别

重要说明:

  • 在实际预测过程中,只有 PredictedLabel 是必需的,因为它包含模型的预测结果
  • ImagePath 和 Label 属性主要用于评估和验证目的,方便比较预测结果与实际标签的差异

二、 图像加载工具方法

private static IEnumerable<ImageData> LoadImagesFromDirectory(string folder, bool useFolderNameAsLabel = true)
{
    var files = Directory.GetFiles(folder, "*", searchOption: SearchOption.AllDirectories);

    foreach (var file in files)
    {
        if ((Path.GetExtension(file) != ".jpg") && (Path.GetExtension(file) != ".png"))
            continue;

        var label = Path.GetFileName(file);
        if (useFolderNameAsLabel)
            label = Directory.GetParent(file).Name;
        else
        {
            for (int index = 0; index < label.Length; index++)
            {
                if (!char.IsLetter(label[index]))
                {
                    label = label.Substring(0, index);
                    break;
                }
            }
        }

        yield return new ImageData()
        {
            ImagePath = file,
            Label = label
        };
    }
}

三、主程序实现

class Program
{
    static void Main(string[] args)
    {
        // 初始化ML.NET环境
        MLContext mlContext = new MLContext();

        // 设置路径
        var projectDirectory = Path.GetFullPath(Path.Combine(AppContext.BaseDirectory, "../../../"));
        var workspaceRelativePath = Path.Combine(projectDirectory, "workspace");
        var assetsRelativePath = Path.Combine(projectDirectory, "assets");

        // 加载数据
        IEnumerable<ImageData> images = LoadImagesFromDirectory(folder: assetsRelativePath, useFolderNameAsLabel: true);
        IDataView imageData = mlContext.Data.LoadFromEnumerable(images);
        IDataView shuffledData = mlContext.Data.ShuffleRows(imageData);

        // 数据预处理
        var preprocessingPipeline = mlContext.Transforms.Conversion.MapValueToKey(
            inputColumnName: "Label",
            outputColumnName: "LabelAsKey")
        .Append(mlContext.Transforms.LoadRawImageBytes(
            outputColumnName: "Image",
            imageFolder: assetsRelativePath,
            inputColumnName: "ImagePath"));

        IDataView preProcessedData = preprocessingPipeline
            .Fit(shuffledData)
            .Transform(shuffledData);

        // 数据集分割
        var trainSplit = mlContext.Data.TrainTestSplit(data: preProcessedData, testFraction: 0.3);
        var validationTestSplit = mlContext.Data.TrainTestSplit(trainSplit.TestSet);

        // 配置训练选项
        var classifierOptions = new ImageClassificationTrainer.Options()
        {
            FeatureColumnName = "Image",
            LabelColumnName = "LabelAsKey",
            ValidationSet = validationTestSplit.TrainSet,
            Arch = ImageClassificationTrainer.Architecture.ResnetV2101,
            MetricsCallback = (metrics) => Console.WriteLine(metrics),
            TestOnTrainSet = false,
            ReuseTrainSetBottleneckCachedValues = true,
            ReuseValidationSetBottleneckCachedValues = true,
            WorkspacePath = workspaceRelativePath
        };

        // 定义训练管道
        var trainingPipeline = mlContext.MulticlassClassification.Trainers
            .ImageClassification(classifierOptions)
            .Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));

        // 训练模型
        Console.WriteLine("*** 开始训练模型 ***");
        ITransformer trainedModel = trainingPipeline.Fit(trainSplit.TrainSet);

        // 进行预测
        ClassifySingleImage(mlContext, validationTestSplit.TestSet, trainedModel);
        ClassifyImages(mlContext, validationTestSplit.TestSet, trainedModel);
    }
}

数据预处理说明

第一步:标签转换

mlContext.Transforms.Conversion.MapValueToKey(
    inputColumnName: "Label",
    outputColumnName: "LabelAsKey")
- 将字符串类型的标签("Label")转换为数值类型("LabelAsKey")
- 例如:"CD" -> 0, "UD" -> 1
- 这是必需的,因为机器学习模型需要数值形式的标签
- 输入是 ImageData 类中的 Label 属性
- 输出存储在 ModelInput 类的 LabelAsKey 属性中

第二步:图像加载

mlContext.Transforms.LoadRawImageBytes(
    outputColumnName: "Image",
    imageFolder: assetsRelativePath,
    inputColumnName: "ImagePath")
- 将图像文件转换为字节数组格式
- `outputColumnName: "Image"`: 输出到 ModelInput 类的 Image 属性
- `imageFolder: assetsRelativePath`: 指定图像文件所在的根目录
- `inputColumnName: "ImagePath"`: 使用 ImageData 类中的 ImagePath 属性

配置训练选项

1.FeatureColumnName = "Image"

  • 指定用作模型输入的列名
  • 这里使用"Image"列,它包含图像的字节数组数据

2.LabelColumnName = "LabelAsKey"

  • 指定要预测的目标值列名
  • 使用"LabelAsKey"列,它是标签的数值表示形式

3.ValidationSet = validationTestSplit.TrainSet

  • 指定用于验证的数据集
  • 用于在训练过程中评估模型性能

4.Arch = ImageClassificationTrainer.Architecture.ResnetV2101

  • 指定使用的预训练模型架构
  • 这里使用 ResNet v2 的101层变体
  • ResNet是一个预训练模型,可以将图像分为1000个类别

5.MetricsCallback = (metrics) => Console.WriteLine(metrics)

  • 用于在训练过程中跟踪和显示训练指标
  • 通过控制台输出训练进度和性能指标

6.TestOnTrainSet = false

  • 设置是否在训练集上测试模型
  • false表示不在训练集上测试,避免过拟合

7.ReuseTrainSetBottleneckCachedValues = true

  • 是否重用训练集的瓶颈层计算结果
  • true表示缓存并重用这些值,可以显著减少训练时间
  • 适用于训练数据不变但需要调整其他参数的情况

8.ReuseValidationSetBottleneckCachedValues = true

  • 是否重用验证集的瓶颈层计算结果
  • 与上面类似,但作用于验证数据集

9.WorkspacePath = workspaceRelativePath

  • 指定存储工作文件的目录路径
  • 用于保存计算的瓶颈值和模型的.pb版本
  • 便于后续重用和模型部署

这些参数的配置对模型的训练效果和效率有重要影响,可以根据具体需求调整这些参数来优化模型性能。

定义训练管道

  • 图像分类训练器
mlContext.MulticlassClassification.Trainers.ImageClassification(classifierOptions)
- 使用多分类分类器进行图像分类
- 基于之前定义的 classifierOptions 配置
- 使用迁移学习方法,基于预训练的 ResNet 模型
- 主要功能:
    - 提取图像特征
    - 训练分类器
    - 生成预测模型
  • 预测标签转换
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"))
- 将模型输出的数值预测结果转换回原始标签
- 与前面的 MapValueToKey 操作相反
- 例如:将 0 转回 "CD",1 转回 "UD"
- 确保最终输出是人类可读的标签

整个训练管道的工作流程:

  • 接收预处理后的数据(图像字节数组和数值标签)
  • 通过深度学习模型进行特征提取和分类
  • 将数值预测结果转换为原始标签类别
  • 输出最终的分类结果

四、预测方法实现

private static void ClassifySingleImage(MLContext mlContext, IDataView data, ITransformer trainedModel)
{
    var predictionEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(trainedModel);
    var image = mlContext.Data.CreateEnumerable<ModelInput>(data, reuseRowObject: true).First();
    var prediction = predictionEngine.Predict(image);

    Console.WriteLine($"单张图片分类结果:");
    Console.WriteLine($"图片: {Path.GetFileName(prediction.ImagePath)}");
    Console.WriteLine($"实际类别: {prediction.Label}");
    Console.WriteLine($"预测类别: {prediction.PredictedLabel}");
}

private static void ClassifyImages(MLContext mlContext, IDataView data, ITransformer trainedModel)
{
    IDataView predictionData = trainedModel.Transform(data);
    var predictions = mlContext.Data.CreateEnumerable<ModelOutput>(predictionData, reuseRowObject: true)
        .Take(10);

    Console.WriteLine("\n批量图片分类结果:");
    foreach (var prediction in predictions)
    {
        Console.WriteLine($"图片: {Path.GetFileName(prediction.ImagePath)}");
        Console.WriteLine($"实际类别: {prediction.Label}");
        Console.WriteLine($"预测类别: {prediction.PredictedLabel}\n");
    }
}

五、执行

训练速度比较慢

图片图片

图片图片

图片图片

实际与预测都是CD。

图片图片

这里会发现预测与实际是有出入的。

模型优化建议

1.增加训练数据量: 收集更多的样本数据可以提高模型的泛化能力。

2.数据增强:

  • 对现有图片进行旋转、翻转、缩放等操作
  • 调整亮度、对比度
  • 添加噪声

3.调整超参数:

  • 增加训练轮数(Epoch)
  • 调整学习率
  • 尝试不同的批次大小

4.使用不同的预训练模型:

  • ResNet不同版本
  • Inception
  • MobileNet

总结

本文详细介绍了如何使用ML.NET实现图像分类功能。通过使用迁移学习和预训练模型,我们可以快速构建高质量的图像分类应用。ML.NET提供了简单易用的API,让.NET开发者能够方便地将机器学习集成到应用程序中。

责任编辑:武晓燕 来源: 技术老小子
相关推荐

2024-07-03 10:09:29

2025-01-07 08:42:54

2024-03-18 08:38:34

ML.NET机器学习开源

2023-11-29 21:21:57

微软ML.NET 3.0机器学习

2020-11-18 18:21:49

.Net 5大数据机器学习

2023-12-26 08:40:06

分类算法数据分析Python

2023-01-05 16:51:04

机器学习人工智能

2023-11-07 14:30:28

Python开发

2020-09-23 07:45:32

Docker前端

2019-09-02 13:57:07

Helm Chart工具Kubernetes

2021-11-24 22:42:15

WorkManagerAPI

2017-06-26 09:15:39

SQL数据库基础

2021-02-18 09:06:39

数据访问者模式

2025-01-07 14:42:09

2023-11-29 14:47:47

微软ML.NET 3.0

2012-02-29 00:49:06

Linux学习

2024-04-11 14:00:28

2013-06-06 13:42:48

OSPF入门配置

2010-02-06 15:31:18

ibmdwAndroid

2009-07-22 14:55:16

ibmdwAndroid
点赞
收藏

51CTO技术栈公众号