一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix) 原创

发布于 2024-12-16 14:24
浏览
0收藏

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区

混淆矩阵(Confusion Matrix)是机器学习中评估分类模型性能的重要工具。通过混淆矩阵,可以直观地了解模型在各个类别上的表现,包括正确分类和错误分类的样本数量。

基于混淆矩阵,我们可以计算准确率、精确率、召回率、F1分数以及真正率和假正率等多个评估指标,用于评估分类模型的性能。

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区

Confusion Matrix

一、混淆矩阵

混淆矩阵(Confusion Matrix)是什么?混淆矩阵是一个表格,用于描述分类模型的预测结果与实际标签之间的关系。

对于一个二分类问题,混淆矩阵是一个2x2的矩阵。

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区

对于多分类问题,混淆矩阵的大小为类别数乘以类别数。

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区

混淆矩阵的评估指标有哪些?混淆矩阵可用于计算准确率、精确率、召回率、F1分数以及真正率和假正率等多个评估指标,这些指标共同构成了评估分类模型性能的完整体系。

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区

1. 准确率(Accuracy)

准确率是模型正确分类的样本数占总样本数的比例。

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区

对于多分类问题,准确率同样适用,只需将TP、TN、FP、FN替换为对应类别的数量总和。

2. 精确率(Precision)

精确率是针对预测为正类的样本,模型预测正确的比例。

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区

对于多分类问题,可以计算每个类别的精确率。

3. 召回率(Recall)

召回率是针对实际为正类的样本,模型预测正确的比例。

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区

同样,对于多分类问题,可以计算每个类别的召回率。

4. F1分数(F1 Score)

F1分数是精确率和召回率的调和平均数,用于综合评估模型的性能。

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区

对于多分类问题,可以计算每个类别的F1分数,或者计算宏平均(Macro-average)和微平均(Micro-average)F1分数。

5. 真正率(True Positive Rate, TPR)和假正率(False Positive Rate, FPR)

真正率也称为灵敏度(Sensitivity)或召回率(Recall)。

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区

假正率也称为1-特异度(1-Specificity)。

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区

二、二分类问题

二分类问题的混淆矩阵是什么?对于二分类问题,混淆矩阵是一个2x2的表格,用于描述分类模型预测结果与实际标签之间的关系,包括真正类(TP)、假正类(FP)、假负类(FN)和真负类(TN)四种情况。

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区

  • 真正类(TP):模型预测为正,实际也为正。
  • 假正类(FP):模型预测为正,实际为负。
  • 假负类(FN):模型预测为负,实际为正。
  • 真负类(TN):模型预测为负,实际也为负。

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区

在Python中,使用sklearn.metrics中的confusion_matrix函数计算了实际标签y_true与预测标签y_pred之间的混淆矩阵,并利用seaborn库的heatmap函数以及matplotlib.pyplot库的相关函数对混淆矩阵进行了可视化展示。

from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt


# 假设y_true是实际标签,y_pred是预测标签
y_true = [0, 1, 1, 0, 1, 0, 1, 0, 0, 1]
y_pred = [0, 1, 0, 0, 1, 0, 1, 1, 0, 1]


# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)


# 使用seaborn绘制混淆矩阵
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区

三、多分类问题

多分类问题的混淆矩阵是什么?多分类问题的混淆矩阵是一个表格,其行表示实际类别,列表示预测类别,每个单元格的值表示实际类别与预测类别相匹配的样本数量。

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区

在Python中,使用seaborn和matplotlib库,基于给定的实际标签数组y_true和预测标签数组y_pred,生成并可视化了一个三分类问题的混淆矩阵热力图。

import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix


# 假设y_true是实际标签数组, y_pred是预测标签数组
y_true = [0, 1, 2, 2, 0, 1, 0, 2, 1, 0]  # 示例实际标签
y_pred = [0, 2, 1, 2, 0, 0, 0, 1, 2, 0]  # 示例预测标签


# 生成混淆矩阵
conf_mat = confusion_matrix(y_true, y_pred)


# 使用seaborn绘制热力图
sns.heatmap(conf_mat, annot=True, cmap='Blues', xticklabels=['Class 0', 'Class 1', 'Class 2'], yticklabels=['Class 0', 'Class 1', 'Class 2'])
plt.xlabel('Predicted Class')
plt.ylabel('True Class')
plt.title('Confusion Matrix')
plt.show()

一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix)-AI.x社区



本文转载自公众号架构师带你玩转AI 作者:AllenTang

原文链接:​​https://mp.weixin.qq.com/s/7T25nGz1dJtF_-tuC4w7bA​

©著作权归作者所有,如需转载,请注明出处,否则将追究法律责任
收藏
回复
举报
回复
相关推荐