终于把机器学习中的混淆矩阵搞懂了!

人工智能 机器学习
混淆矩阵是用于评估分类模型性能的表格。它通过将实际(真实)标签与预测标签进行比较,提供分类问题的预测结果摘要。混淆矩阵本身是正方形(nxn),其中 n 是模型中的类别数。

大家好,我是小寒

今天给大家分享一个机器学习中一个重要的概念,混淆矩阵

混淆矩阵是用于评估分类模型性能的表格。它通过将实际(真实)标签与预测标签进行比较,提供分类问题的预测结果摘要。

混淆矩阵本身是正方形(nxn),其中 n 是模型中的类别数。

对于二元分类问题,混淆矩阵由四个主要部分组成:

  • True Positive (TP, 真阳性):实际为正类,预测也为正类的数量。
  • True Negative (TN, 真阴性):实际为负类,预测也为负类的数量。
  • False Positive (FP, 假阳性):实际为负类,预测却为正类的数量,通常称为"Type I 错误"或"误报"。
  • False Negative (FN, 假阴性):实际为正类,预测却为负类的数量,通常称为"Type II 错误"或"漏报"。

图片图片

为什么要使用混淆矩阵?

混淆矩阵是评估分类模型性能的基本工具。

  1. 错误分析
    它有助于识别模型所犯的错误类型,无论模型更容易出现假阳性还是假阴性,这在应用范围内(例如在医学诊断中)可能至关重要。
  2. 模型改进
    通过分析混淆矩阵,你可以专注于改进模型的特定方面,例如减少误报或提高召回率。
  3. 类别不平衡处理
    在类别不平衡的情况下,一个类别出现的频率高于另一个类别,单凭准确率可能会产生误导。
    混淆矩阵可让你更好地了解模型在每个类别中的表现。
  4. 性能指标计算

分类中的评估指标

1.准确率

准确率是分类任务中最简单的评估指标之一,用来衡量模型预测正确的比例。

准确率的局限性

当处理不平衡的数据集时,一个类别的数量远远超过其他类别,准确率可能会产生误导。

例如,在 95% 的样本属于同一类的数据集中,预测所有实例为多数类的模型的准确率为 95%,但在识别少数类时则无效。

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, accuracy_score

# Example true labels (ytest) and predicted labels (ypred)
ytest = [0, 1, 0, 1, 0, 1, 0, 0, 1, 1]
ypred = [0, 1, 0, 0, 0, 1, 0, 1, 1, 1]

# Calculate confusion matrix
cm = confusion_matrix(ytest, ypred)

# Create a heatmap
plt.figure(figsize=(6, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
            xticklabels=['1', '0'],
            yticklabels=['1', '0'])

# Add labels and title
plt.xlabel('Predicted Classes')
plt.ylabel('Actual Classes')
plt.title('Confusion Matrix')

# Calculate and display accuracy
accuracy = accuracy_score(ytest, ypred)
plt.text(2.3, 1.5, f'Accuracy: {accuracy:.2f}', fontsize=14, color='black', weight='bold')

plt.show()

图片图片

2.精度

精度用来衡量模型预测为正类的样本中实际为正类的比例。

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_score

# Example true labels (ytest) and predicted labels (ypred)
ytest = ['spam', 'spam', 'ham', 'spam', 'ham', 'spam', 'spam', 'ham', 'spam', 'spam', 'ham', 'spam', 'ham', 'ham', 'ham']
ypred = ['spam', 'spam', 'spam', 'spam', 'ham', 'spam', 'spam', 'ham', 'spam', 'spam', 'ham', 'ham', 'ham', 'ham', 'ham']

# Calculate the confusion matrix
cm = confusion_matrix(ytest, ypred, labels=['spam', 'ham'])
print("Confusion Matrix:\n", cm)

# Calculate precision
precision = precision_score(ytest, ypred, pos_label='spam')
print("Precision:", precision)

# Create a heatmap for the confusion matrix
plt.figure(figsize=(8, 6))
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='viridis', cbar=False,
                 xticklabels=['Predicted Spam', 'Predicted Ham'],
                 yticklabels=['Actual Spam', 'Actual Ham'])

# Set labels and title
plt.xlabel('Predicted Classes')
plt.ylabel('Actual Classes')
plt.title(f'Confusion Matrix\nPrecision: {precision:.2f}')

# Show the plot
plt.show()

图片图片

3.召回率

召回率用来衡量实际为正类的样本中模型预测为正类的比例。

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, recall_score

# Example true labels (ytest) and predicted labels (ypred)
ytest = ['positive', 'positive', 'negative', 'positive', 'negative']
ypred = ['positive', 'negative', 'negative', 'positive', 'positive']

# Calculate the confusion matrix
cm = confusion_matrix(ytest, ypred, labels=['positive', 'negative'])


# Calculate recall
recall = recall_score(ytest, ypred, pos_label='positive')


# Create a heatmap for the confusion matrix
plt.figure(figsize=(6, 4))
ax = sns.heatmap(cm, annot=True, fmt='d', cmap='viridis', cbar=False,
                 xticklabels=['Predicted Positive', 'Predicted Negative'],
                 yticklabels=['Actual Positive', 'Actual Negative'])

# Set labels and title
plt.xlabel('Predicted Classes')
plt.ylabel('Actual Classes')
plt.title(f'Confusion Matrix\nRecall: {recall:.2f}')

# Show the plot
plt.show()

图片图片

4.F1-score

F1-score 是精度和召回率的调和平均数,用来综合考虑精度和召回率的平衡。

责任编辑:武晓燕 来源: 程序员学长
相关推荐

2024-09-18 16:42:58

机器学习评估指标模型

2024-10-14 14:02:17

机器学习评估指标人工智能

2024-11-05 12:56:06

机器学习函数MSE

2024-10-08 15:09:17

2024-10-28 00:00:10

机器学习模型程度

2024-10-30 08:23:07

2024-10-28 15:52:38

机器学习特征工程数据集

2024-10-08 10:16:22

2024-11-25 08:20:35

2024-08-01 08:41:08

2024-12-03 08:16:57

2024-10-16 07:58:48

2024-09-23 09:12:20

2024-07-17 09:32:19

2024-11-21 10:07:40

2024-10-31 10:00:39

注意力机制核心组件

2024-12-02 13:28:44

2024-12-02 01:10:04

神经网络自然语言DNN

2024-07-24 08:04:24

神经网络激活函数

2024-11-07 08:26:31

神经网络激活函数信号
点赞
收藏

51CTO技术栈公众号