ML.NET是微软开发的开源机器学习框架,让.NET开发者能够直接在.NET应用程序中集成机器学习功能。本文将详细介绍如何使用ML.NET实现图像分类,包括环境搭建、数据准备、模型训练等完整流程。
环境准备
- Visual Studio 2022
- .NET 6.0或更高版本
- 需要安装的NuGet包:
Microsoft.ML
Microsoft.ML.Vision
Microsoft.ML.ImageAnalytics
SciSharp.TensorFlow.Redist (版本2.3.1)
图片
项目结构
ImageClassification/
├── Program.cs
├── assets/ # 存放训练图片
│ ├── CD/ # 有裂缝的图片
│ └── UD/ # 无裂缝的图片
└── workspace/ # 存放模型文件
代码实现
一、创建数据模型类
// 原始图像数据类
public class ImageData
{
public string ImagePath { get; set; }
public string Label { get; set; }
}
ImageData 类是用来表示和存储图像的基本信息的数据结构,主要用于数据加载和预处理阶段。它包含两个关键属性:
1.ImagePath
- 存储图像文件的完整路径
- 用于后续加载和访问图像文件
- 是一个字符串类型的属性
2.Label
- 存储图像的分类标签/类别
- 表示图像所属的类别(如示例中的"有裂缝"或"无裂缝")
- 是一个字符串类型的属性
主要用途:
- 数据加载:在从目录加载图像数据时,用于初步组织和存储图像信息
- 数据组织:将文件系统中的图像与其对应的分类标签关联起来
// 模型输入类
public class ModelInput
{
public byte[] Image { get; set; }
public UInt32 LabelAsKey { get; set; }
public string ImagePath { get; set; }
public string Label { get; set; }
}
3.Image (byte[] 类型)
- 存储图像的字节数组表示
- 这是模型训练和预测所必需的输入格式
- 模型需要这种类型的图像数据来进行训练
4.LabelAsKey (UInt32 类型)
- 是 Label 的数值表示形式
- 将分类标签转换为数值形式,因为机器学习模型要求输入采用数值格式
5.ImagePath (string 类型)
- 存储图像的完整路径
- 用于方便访问原始图像文件
6.Label (string 类型)
- 图像所属的类别
- 这是需要预测的目标值
- 用于训练时的标签信息
重要说明:
- 在实际训练和预测中,只有 Image 和 LabelAsKey 这两个属性被用于模型训练和预测
- ImagePath 和 Label 属性主要是为了方便访问和追踪原始数据,不直接参与模型计算
- 这个类是连接原始图像数据和模型训练需求的桥梁,将各种必要的信息整合在一起
// 模型输出类
public class ModelOutput
{
public string ImagePath { get; set; }
public string Label { get; set; }
public string PredictedLabel { get; set; }
}
7.ImagePath (string 类型)
- 存储图像的完整文件路径
- 用于追踪和引用原始图像文件
8.Label (string 类型)
- 存储图像的原始/真实类别标签
- 这是图像实际应该属于的类别
9.PredictedLabel (string 类型)
- 存储模型预测的类别标签
- 这是模型通过分析图像后预测出的类别
重要说明:
- 在实际预测过程中,只有 PredictedLabel 是必需的,因为它包含模型的预测结果
- ImagePath 和 Label 属性主要用于评估和验证目的,方便比较预测结果与实际标签的差异
二、 图像加载工具方法
private static IEnumerable<ImageData> LoadImagesFromDirectory(string folder, bool useFolderNameAsLabel = true)
{
var files = Directory.GetFiles(folder, "*", searchOption: SearchOption.AllDirectories);
foreach (var file in files)
{
if ((Path.GetExtension(file) != ".jpg") && (Path.GetExtension(file) != ".png"))
continue;
var label = Path.GetFileName(file);
if (useFolderNameAsLabel)
label = Directory.GetParent(file).Name;
else
{
for (int index = 0; index < label.Length; index++)
{
if (!char.IsLetter(label[index]))
{
label = label.Substring(0, index);
break;
}
}
}
yield return new ImageData()
{
ImagePath = file,
Label = label
};
}
}
三、主程序实现
class Program
{
static void Main(string[] args)
{
// 初始化ML.NET环境
MLContext mlContext = new MLContext();
// 设置路径
var projectDirectory = Path.GetFullPath(Path.Combine(AppContext.BaseDirectory, "../../../"));
var workspaceRelativePath = Path.Combine(projectDirectory, "workspace");
var assetsRelativePath = Path.Combine(projectDirectory, "assets");
// 加载数据
IEnumerable<ImageData> images = LoadImagesFromDirectory(folder: assetsRelativePath, useFolderNameAsLabel: true);
IDataView imageData = mlContext.Data.LoadFromEnumerable(images);
IDataView shuffledData = mlContext.Data.ShuffleRows(imageData);
// 数据预处理
var preprocessingPipeline = mlContext.Transforms.Conversion.MapValueToKey(
inputColumnName: "Label",
outputColumnName: "LabelAsKey")
.Append(mlContext.Transforms.LoadRawImageBytes(
outputColumnName: "Image",
imageFolder: assetsRelativePath,
inputColumnName: "ImagePath"));
IDataView preProcessedData = preprocessingPipeline
.Fit(shuffledData)
.Transform(shuffledData);
// 数据集分割
var trainSplit = mlContext.Data.TrainTestSplit(data: preProcessedData, testFraction: 0.3);
var validationTestSplit = mlContext.Data.TrainTestSplit(trainSplit.TestSet);
// 配置训练选项
var classifierOptions = new ImageClassificationTrainer.Options()
{
FeatureColumnName = "Image",
LabelColumnName = "LabelAsKey",
ValidationSet = validationTestSplit.TrainSet,
Arch = ImageClassificationTrainer.Architecture.ResnetV2101,
MetricsCallback = (metrics) => Console.WriteLine(metrics),
TestOnTrainSet = false,
ReuseTrainSetBottleneckCachedValues = true,
ReuseValidationSetBottleneckCachedValues = true,
WorkspacePath = workspaceRelativePath
};
// 定义训练管道
var trainingPipeline = mlContext.MulticlassClassification.Trainers
.ImageClassification(classifierOptions)
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
// 训练模型
Console.WriteLine("*** 开始训练模型 ***");
ITransformer trainedModel = trainingPipeline.Fit(trainSplit.TrainSet);
// 进行预测
ClassifySingleImage(mlContext, validationTestSplit.TestSet, trainedModel);
ClassifyImages(mlContext, validationTestSplit.TestSet, trainedModel);
}
}
数据预处理说明
第一步:标签转换
mlContext.Transforms.Conversion.MapValueToKey(
inputColumnName: "Label",
outputColumnName: "LabelAsKey")
- 将字符串类型的标签("Label")转换为数值类型("LabelAsKey")
- 例如:"CD" -> 0, "UD" -> 1
- 这是必需的,因为机器学习模型需要数值形式的标签
- 输入是 ImageData 类中的 Label 属性
- 输出存储在 ModelInput 类的 LabelAsKey 属性中
第二步:图像加载
mlContext.Transforms.LoadRawImageBytes(
outputColumnName: "Image",
imageFolder: assetsRelativePath,
inputColumnName: "ImagePath")
- 将图像文件转换为字节数组格式
- `outputColumnName: "Image"`: 输出到 ModelInput 类的 Image 属性
- `imageFolder: assetsRelativePath`: 指定图像文件所在的根目录
- `inputColumnName: "ImagePath"`: 使用 ImageData 类中的 ImagePath 属性
配置训练选项
1.FeatureColumnName = "Image"
- 指定用作模型输入的列名
- 这里使用"Image"列,它包含图像的字节数组数据
2.LabelColumnName = "LabelAsKey"
- 指定要预测的目标值列名
- 使用"LabelAsKey"列,它是标签的数值表示形式
3.ValidationSet = validationTestSplit.TrainSet
- 指定用于验证的数据集
- 用于在训练过程中评估模型性能
4.Arch = ImageClassificationTrainer.Architecture.ResnetV2101
- 指定使用的预训练模型架构
- 这里使用 ResNet v2 的101层变体
- ResNet是一个预训练模型,可以将图像分为1000个类别
5.MetricsCallback = (metrics) => Console.WriteLine(metrics)
- 用于在训练过程中跟踪和显示训练指标
- 通过控制台输出训练进度和性能指标
6.TestOnTrainSet = false
- 设置是否在训练集上测试模型
- false表示不在训练集上测试,避免过拟合
7.ReuseTrainSetBottleneckCachedValues = true
- 是否重用训练集的瓶颈层计算结果
- true表示缓存并重用这些值,可以显著减少训练时间
- 适用于训练数据不变但需要调整其他参数的情况
8.ReuseValidationSetBottleneckCachedValues = true
- 是否重用验证集的瓶颈层计算结果
- 与上面类似,但作用于验证数据集
9.WorkspacePath = workspaceRelativePath
- 指定存储工作文件的目录路径
- 用于保存计算的瓶颈值和模型的.pb版本
- 便于后续重用和模型部署
这些参数的配置对模型的训练效果和效率有重要影响,可以根据具体需求调整这些参数来优化模型性能。
定义训练管道
- 图像分类训练器
mlContext.MulticlassClassification.Trainers.ImageClassification(classifierOptions)
- 使用多分类分类器进行图像分类
- 基于之前定义的 classifierOptions 配置
- 使用迁移学习方法,基于预训练的 ResNet 模型
- 主要功能:
- 提取图像特征
- 训练分类器
- 生成预测模型
- 预测标签转换
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"))
- 将模型输出的数值预测结果转换回原始标签
- 与前面的 MapValueToKey 操作相反
- 例如:将 0 转回 "CD",1 转回 "UD"
- 确保最终输出是人类可读的标签
整个训练管道的工作流程:
- 接收预处理后的数据(图像字节数组和数值标签)
- 通过深度学习模型进行特征提取和分类
- 将数值预测结果转换为原始标签类别
- 输出最终的分类结果
四、预测方法实现
private static void ClassifySingleImage(MLContext mlContext, IDataView data, ITransformer trainedModel)
{
var predictionEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(trainedModel);
var image = mlContext.Data.CreateEnumerable<ModelInput>(data, reuseRowObject: true).First();
var prediction = predictionEngine.Predict(image);
Console.WriteLine($"单张图片分类结果:");
Console.WriteLine($"图片: {Path.GetFileName(prediction.ImagePath)}");
Console.WriteLine($"实际类别: {prediction.Label}");
Console.WriteLine($"预测类别: {prediction.PredictedLabel}");
}
private static void ClassifyImages(MLContext mlContext, IDataView data, ITransformer trainedModel)
{
IDataView predictionData = trainedModel.Transform(data);
var predictions = mlContext.Data.CreateEnumerable<ModelOutput>(predictionData, reuseRowObject: true)
.Take(10);
Console.WriteLine("\n批量图片分类结果:");
foreach (var prediction in predictions)
{
Console.WriteLine($"图片: {Path.GetFileName(prediction.ImagePath)}");
Console.WriteLine($"实际类别: {prediction.Label}");
Console.WriteLine($"预测类别: {prediction.PredictedLabel}\n");
}
}
五、执行
训练速度比较慢
图片
图片
图片
实际与预测都是CD。
图片
这里会发现预测与实际是有出入的。
模型优化建议
1.增加训练数据量: 收集更多的样本数据可以提高模型的泛化能力。
2.数据增强:
- 对现有图片进行旋转、翻转、缩放等操作
- 调整亮度、对比度
- 添加噪声
3.调整超参数:
- 增加训练轮数(Epoch)
- 调整学习率
- 尝试不同的批次大小
4.使用不同的预训练模型:
- ResNet不同版本
- Inception
- MobileNet
总结
本文详细介绍了如何使用ML.NET实现图像分类功能。通过使用迁移学习和预训练模型,我们可以快速构建高质量的图像分类应用。ML.NET提供了简单易用的API,让.NET开发者能够方便地将机器学习集成到应用程序中。