分类模型confusion matrix混淆矩阵可视化 您所在的位置:网站首页 混淆矩阵怎么看准确率 分类模型confusion matrix混淆矩阵可视化

分类模型confusion matrix混淆矩阵可视化

2024-07-13 07:31| 来源: 网络整理| 查看: 265

        之前写过一篇关于在scikit-learn工具包中,可视化estimator分类模型分类结果的confusion matrix混淆矩阵可视化的方法,具体可以参考看这里,看这里。今天这篇介绍一下如何使用scikit-learn工具中提供的相关方法,可视化其他任意框架(比如深度学习框架)的分类模型预测结果的混淆矩阵。

        下面先说一下几个关键步骤:

1、确定类别列表,类别列表和one-hot的编码顺序一致,这里使用cifar-10的类别列表作为演示的例子。

classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"

2、准备好样本的真实label,这里我手动构造一个1000个样本的label,每一类100个。

# 生成数据集的GT标签 gt_labels = np.zeros(1000).reshape(10, -1) for i in range(10): gt_labels[i] = i gt_labels = gt_labels.reshape(1, -1).squeeze() print("gt_labels.shape : {}".format(gt_labels.shape)) print("gt_labels : {}".format(gt_labels[::5]))

3、准备好样本的预测label,这里我也手动构造这1000个样本的预测label,构造时才用了一点规则,构造出来的预测结果保证从第0类到第9类的预测准确率是逐渐降低的。

# 生成数据集的预测标签 pred_labels = np.zeros(1000).reshape(10, -1) for i in range(10): # 标签生成规则:对于真值类别编号为i的数据,生成的预测类别编号为[0, i-1]之间的随机值 # 这样生成的预测准确率从0到9逐渐递减 pred_labels[i] = np.random.randint(0, i + 1, 100) pred_labels = pred_labels.reshape(1, -1).squeeze() print("pred_labels.shape : {}".format(pred_labels.shape)) print("pred_labels : {}".format(pred_labels[::5]))

4、计算真是label和预测label的混淆矩阵,直接调用scikit-learn中的confusion_matrix方法

# 使用sklearn工具中confusion_matrix方法计算混淆矩阵 confusion_mat = confusion_matrix(gt_labels, pred_labels) print("confusion_mat.shape : {}".format(confusion_mat.shape)) print("confusion_mat : {}".format(confusion_mat))

5、混淆矩阵可视化,在scikit-learn工具中有一个plot_confusion_matrix方法可以可视化sklearn训练的模型estimator的混淆矩阵,具体参数如下:

        但是,现在的问题是我们使用的是别的框架训练的模型,也就没有这个estimator参数可以供sklearn使用,怎么办?

        我们看一下plot_confusion_matrix函数的代码可以发现,他其实内部调用了以下方法:

         那么,我们也仿照这个调用方式来写一下试试,代码如下:

# 使用sklearn工具包中的ConfusionMatrixDisplay可视化混淆矩阵,参考plot_confusion_matrix disp = ConfusionMatrixDisplay(confusion_matrix=confusion_mat, display_labels=classes) disp.plot( include_values=True, # 混淆矩阵每个单元格上显示具体数值 cmap="viridis", # 不清楚啥意思,没研究,使用的sklearn中的默认值 ax=None, # 同上 xticks_rotation="horizontal", # 同上 values_format="d" # 显示的数值格式 )

 6、将以上代码整合一下,输入数据的真实label和预测label,就可以可视化混淆矩阵了,并且不仅局限于评估scikit-learn的estimator,可以适用于所有框架的输出结果,完整代码如下:

import numpy as np from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay from matplotlib import pyplot as plt classes = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"] # 生成数据集的GT标签 gt_labels = np.zeros(1000).reshape(10, -1) for i in range(10): gt_labels[i] = i gt_labels = gt_labels.reshape(1, -1).squeeze() print("gt_labels.shape : {}".format(gt_labels.shape)) print("gt_labels : {}".format(gt_labels[::5])) # 生成数据集的预测标签 pred_labels = np.zeros(1000).reshape(10, -1) for i in range(10): # 标签生成规则:对于真值类别编号为i的数据,生成的预测类别编号为[0, i-1]之间的随机值 # 这样生成的预测准确率从0到9逐渐递减 pred_labels[i] = np.random.randint(0, i + 1, 100) pred_labels = pred_labels.reshape(1, -1).squeeze() print("pred_labels.shape : {}".format(pred_labels.shape)) print("pred_labels : {}".format(pred_labels[::5])) # 使用sklearn工具中confusion_matrix方法计算混淆矩阵 confusion_mat = confusion_matrix(gt_labels, pred_labels) print("confusion_mat.shape : {}".format(confusion_mat.shape)) print("confusion_mat : {}".format(confusion_mat)) # 使用sklearn工具包中的ConfusionMatrixDisplay可视化混淆矩阵,参考plot_confusion_matrix disp = ConfusionMatrixDisplay(confusion_matrix=confusion_mat, display_labels=classes) disp.plot( include_values=True, # 混淆矩阵每个单元格上显示具体数值 cmap="viridis", # 不清楚啥意思,没研究,使用的sklearn中的默认值 ax=None, # 同上 xticks_rotation="horizontal", # 同上 values_format="d" # 显示的数值格式 ) plt.show()

7、混淆矩阵的可视化结果

        上图中的可视化结果符合我们在生成预测label标签时使用的规则,就是对于每个类别 i 的预测结果是0-i之间的随机值,这样的话,每个类别的预测误差只会出现在类别编号比它小的部分,也就是上图中展示的下三角矩阵。

        在混淆矩阵中,横轴上的标签标示样本的预测label,纵轴上的标签标示样本的实际label。所以,对角线上的数字表示预测label和真是label一致的数量,也就是预测正确的数量。对于其他位置的数字就表示预测错误的,举个例子,比如第2行、第1列,也就是对应着(airplane, automobile)位置的数字51,表示有51个真实label为automobile的样本被预测为了airplane。

        通过可视化的混淆矩阵,模型的误差,以及效果分类不好的类别,以及为什么不好,以及容易和哪个类之间出现误识别就一目了然了。

参考:https://blog.csdn.net/cxx654/article/details/107296343



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有