大家好,我是小寒
今天给大家分享一个机器学习中一个重要的概念,混淆矩阵
混淆矩阵是用于评估分类模型性能的表格。它通过将实际(真实)标签与预测标签进行比较,提供分类问题的预测结果摘要。
混淆矩阵本身是正方形(nxn),其中 n 是模型中的类别数。
对于二元分类问题,混淆矩阵由四个主要部分组成:
- True Positive (TP, 真阳性):实际为正类,预测也为正类的数量。
- True Negative (TN, 真阴性):实际为负类,预测也为负类的数量。
- False Positive (FP, 假阳性):实际为负类,预测却为正类的数量,通常称为"Type I 错误"或"误报"。
- False Negative (FN, 假阴性):实际为正类,预测却为负类的数量,通常称为"Type II 错误"或"漏报"。
图片
为什么要使用混淆矩阵?
混淆矩阵是评估分类模型性能的基本工具。
- 错误分析
它有助于识别模型所犯的错误类型,无论模型更容易出现假阳性还是假阴性,这在应用范围内(例如在医学诊断中)可能至关重要。 - 模型改进
通过分析混淆矩阵,你可以专注于改进模型的特定方面,例如减少误报或提高召回率。 - 类别不平衡处理
在类别不平衡的情况下,一个类别出现的频率高于另一个类别,单凭准确率可能会产生误导。
混淆矩阵可让你更好地了解模型在每个类别中的表现。 - 性能指标计算
分类中的评估指标
1.准确率
准确率是分类任务中最简单的评估指标之一,用来衡量模型预测正确的比例。
准确率的局限性
当处理不平衡的数据集时,一个类别的数量远远超过其他类别,准确率可能会产生误导。
例如,在 95% 的样本属于同一类的数据集中,预测所有实例为多数类的模型的准确率为 95%,但在识别少数类时则无效。
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score
# Example true labels (ytest) and predicted labels (ypred)
ytest = [0, 1, 0, 1, 0, 1, 0, 0, 1, 1]
ypred = [0, 1, 0, 0, 0, 1, 0, 1, 1, 1]
# Calculate confusion matrix
cm = confusion_matrix(ytest, ypred)
# Create a heatmap
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
xticklabels=['1', '0'],
yticklabels=['1', '0'])
# Add labels and title
plt.xlabel('Predicted Classes')
plt.ylabel('Actual Classes')
plt.title('Confusion Matrix')
# Calculate and display accuracy
accuracy = accuracy_score(ytest, ypred)
plt.text(2.3, 1.5, f'Accuracy: {accuracy:.2f}', fontsize=14, color='black', weight='bold')
plt.show()
图片
2.精度
精度用来衡量模型预测为正类的样本中实际为正类的比例。
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_score
# Example true labels (ytest) and predicted labels (ypred)
ytest = ['spam', 'spam', 'ham', 'spam', 'ham', 'spam', 'spam', 'ham', 'spam', 'spam', 'ham', 'spam', 'ham', 'ham', 'ham']
ypred = ['spam', 'spam', 'spam', 'spam', 'ham', 'spam', 'spam', 'ham', 'spam', 'spam', 'ham', 'ham', 'ham', 'ham', 'ham']
# Calculate the confusion matrix
cm = confusion_matrix(ytest, ypred, labels=['spam', 'ham'])
print("Confusion Matrix:\n", cm)
# Calculate precision
precision = precision_score(ytest, ypred, pos_label='spam')
print("Precision:", precision)
# Create a heatmap for the confusion matrix
plt.figure(figsize=(8, 6))
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='viridis', cbar=False,
xticklabels=['Predicted Spam', 'Predicted Ham'],
yticklabels=['Actual Spam', 'Actual Ham'])
# Set labels and title
plt.xlabel('Predicted Classes')
plt.ylabel('Actual Classes')
plt.title(f'Confusion Matrix\nPrecision: {precision:.2f}')
# Show the plot
plt.show()
图片
3.召回率
召回率用来衡量实际为正类的样本中模型预测为正类的比例。
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, recall_score
# Example true labels (ytest) and predicted labels (ypred)
ytest = ['positive', 'positive', 'negative', 'positive', 'negative']
ypred = ['positive', 'negative', 'negative', 'positive', 'positive']
# Calculate the confusion matrix
cm = confusion_matrix(ytest, ypred, labels=['positive', 'negative'])
# Calculate recall
recall = recall_score(ytest, ypred, pos_label='positive')
# Create a heatmap for the confusion matrix
plt.figure(figsize=(6, 4))
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='viridis', cbar=False,
xticklabels=['Predicted Positive', 'Predicted Negative'],
yticklabels=['Actual Positive', 'Actual Negative'])
# Set labels and title
plt.xlabel('Predicted Classes')
plt.ylabel('Actual Classes')
plt.title(f'Confusion Matrix\nRecall: {recall:.2f}')
# Show the plot
plt.show()
图片
4.F1-score
F1-score 是精度和召回率的调和平均数,用来综合考虑精度和召回率的平衡。