一文彻底搞懂机器学习 - 混淆矩阵(Confusion Matrix) 原创
混淆矩阵(Confusion Matrix)是机器学习中评估分类模型性能的重要工具。通过混淆矩阵,可以直观地了解模型在各个类别上的表现,包括正确分类和错误分类的样本数量。
基于混淆矩阵,我们可以计算准确率、精确率、召回率、F1分数以及真正率和假正率等多个评估指标,用于评估分类模型的性能。
Confusion Matrix
一、混淆矩阵
混淆矩阵(Confusion Matrix)是什么?混淆矩阵是一个表格,用于描述分类模型的预测结果与实际标签之间的关系。
对于一个二分类问题,混淆矩阵是一个2x2的矩阵。
对于多分类问题,混淆矩阵的大小为类别数乘以类别数。
混淆矩阵的评估指标有哪些?混淆矩阵可用于计算准确率、精确率、召回率、F1分数以及真正率和假正率等多个评估指标,这些指标共同构成了评估分类模型性能的完整体系。
1. 准确率(Accuracy)
准确率是模型正确分类的样本数占总样本数的比例。
对于多分类问题,准确率同样适用,只需将TP、TN、FP、FN替换为对应类别的数量总和。
2. 精确率(Precision)
精确率是针对预测为正类的样本,模型预测正确的比例。
对于多分类问题,可以计算每个类别的精确率。
3. 召回率(Recall)
召回率是针对实际为正类的样本,模型预测正确的比例。
同样,对于多分类问题,可以计算每个类别的召回率。
4. F1分数(F1 Score)
F1分数是精确率和召回率的调和平均数,用于综合评估模型的性能。
对于多分类问题,可以计算每个类别的F1分数,或者计算宏平均(Macro-average)和微平均(Micro-average)F1分数。
5. 真正率(True Positive Rate, TPR)和假正率(False Positive Rate, FPR)
真正率也称为灵敏度(Sensitivity)或召回率(Recall)。
假正率也称为1-特异度(1-Specificity)。
二、二分类问题
二分类问题的混淆矩阵是什么?对于二分类问题,混淆矩阵是一个2x2的表格,用于描述分类模型预测结果与实际标签之间的关系,包括真正类(TP)、假正类(FP)、假负类(FN)和真负类(TN)四种情况。
- 真正类(TP):模型预测为正,实际也为正。
- 假正类(FP):模型预测为正,实际为负。
- 假负类(FN):模型预测为负,实际为正。
- 真负类(TN):模型预测为负,实际也为负。
在Python中,使用sklearn.metrics中的confusion_matrix函数计算了实际标签y_true与预测标签y_pred之间的混淆矩阵,并利用seaborn库的heatmap函数以及matplotlib.pyplot库的相关函数对混淆矩阵进行了可视化展示。
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
# 假设y_true是实际标签,y_pred是预测标签
y_true = [0, 1, 1, 0, 1, 0, 1, 0, 0, 1]
y_pred = [0, 1, 0, 0, 1, 0, 1, 1, 0, 1]
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
# 使用seaborn绘制混淆矩阵
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()
三、多分类问题
多分类问题的混淆矩阵是什么?多分类问题的混淆矩阵是一个表格,其行表示实际类别,列表示预测类别,每个单元格的值表示实际类别与预测类别相匹配的样本数量。
在Python中,使用seaborn和matplotlib库,基于给定的实际标签数组y_true和预测标签数组y_pred,生成并可视化了一个三分类问题的混淆矩阵热力图。
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# 假设y_true是实际标签数组, y_pred是预测标签数组
y_true = [0, 1, 2, 2, 0, 1, 0, 2, 1, 0] # 示例实际标签
y_pred = [0, 2, 1, 2, 0, 0, 0, 1, 2, 0] # 示例预测标签
# 生成混淆矩阵
conf_mat = confusion_matrix(y_true, y_pred)
# 使用seaborn绘制热力图
sns.heatmap(conf_mat, annot=True, cmap='Blues', xticklabels=['Class 0', 'Class 1', 'Class 2'], yticklabels=['Class 0', 'Class 1', 'Class 2'])
plt.xlabel('Predicted Class')
plt.ylabel('True Class')
plt.title('Confusion Matrix')
plt.show()
本文转载自公众号架构师带你玩转AI 作者:AllenTang