sklearn(五)计算acc:使用metrics.accuracy 您所在的位置:网站首页 多标签分类器的作用是 sklearn(五)计算acc:使用metrics.accuracy

sklearn(五)计算acc:使用metrics.accuracy

2024-07-16 18:54| 来源: 网络整理| 查看: 265

1.acc计算原理

sklearn中accuracy_score函数计算了准确率。

在二分类或者多分类中,预测得到的label,跟真实label比较,计算准确率。

在multilabel(多标签问题)分类中,该函数会返回子集的准确率。如果对于一个样本来说,必须严格匹配真实数据集中的label,整个集合的预测标签返回1.0;否则返回0.0.

2.acc的不适用场景:

在正负样本不平衡的情况下,准确率这个评价指标有很大的缺陷。比如在互联网广告里面,点击的数量是很少的,一般只有千分之几,如果用acc,即使全部预测成负类(不点击)acc也有 99% 以上,没有意义。因此,单纯靠准确率来评价一个算法模型是远远不够科学全面的。在类别不平衡没那么太严重时,该指标具有一定的参考意义。

3.metrics.accuracy_score()的使用方法

不管是二分类还是多分类,还是多标签问题,计算公式都为:

这里写图片描述

只是在多标签问题中,TP、TN要求更加严格,必须严格匹配真实数据集中的label。

sklearn.metrics.accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None)

输入参数:

y_true:真是标签。二分类和多分类情况下是一列,多标签情况下是标签的索引。

y_pred:预测标签。二分类和多分类情况下是一列,多标签情况下是标签的索引。

normalize:bool, optional (default=True),如果是false,正确分类的样本的数目(int);如果为true,返回正确分类的样本的比例,必须严格匹配真实数据集中的label,才为1,否则为0。

sample_weight:array-like of shape (n_samples,), default=None。Sample weights.

输出:

如果normalize == True,返回正确分类的样本的比例,否则返回正确分类的样本的数目(int)。

4.例子 

举一个多标签的例子,这里假设有21个标签。

数据格式:预测label是有阈值硬阶段得来,比如当预测得分大于0.5,则在这个索引下的label为1,否则为0。

idlabelpred_labelpred_label_scorepred_scores(模型输出的scores)1559592780[3][3][0.9060243964195251][0.03569700941443443, 0.025016790255904198, 0.010681516490876675, 0.9060243964195251, 0.03405195102095604, 0.01652703806757927, 0.01057326141744852, 0.015285834670066833, 0.03219904750585556, 0.01710071600973606, 0.015052232891321182, 0.012746844440698624, 0.009399563074111938, 0.012753037735819817, 0.008887830190360546, 0.011201461777091026, 0.013154321350157261, 0.010007181204855442, 0.015232570469379425, 0.011832496151328087, 0.014289622195065022]1559950270[3][][][0.0441354475915432, 0.07238972187042236, 0.011645170859992504, 0.007589259184896946, 0.25604453682899475, 0.08702245354652405, 0.27572867274284363, 0.00486581027507782, 0.01071715448051691, 0.010638655163347721, 0.005942077841609716, 0.03388604149222374, 0.003174690529704094, 0.006336248945444822, 0.007447054609656334, 0.004069846123456955, 0.06864038109779358, 0.003221432212740183, 0.010166178457438946, 0.014550245366990566, 0.018491217866539955]1559394894[3][3][0.6821054816246033][0.2968560457229614, 0.0307493656873703, 0.005526685621589422, 0.6821054816246033, 0.019207751378417015, 0.011433916166424751, 0.00833720900118351, 0.011756493709981441, 0.028093582019209862, 0.008476401679217815, 0.00896463356912136, 0.007736032363027334, 0.006790427025407553, 0.009148293174803257, 0.006993972696363926, 0.006845239549875259, 0.008285323157906532, 0.005908709950745106, 0.009022236801683903, 0.008929350413382053, 0.019131703302264214]1559782048[3][3][0.9018600583076477][0.04472490772604942, 0.0243248138576746, 0.011095968075096607, 0.9018600583076477, 0.02759535051882267, 0.01639750227332115, 0.010229885578155518, 0.01442675106227398, 0.03185756132006645, 0.01614650897681713, 0.014211165718734264, 0.011741148307919502, 0.00937943160533905, 0.013027109205722809, 0.008298314176499844, 0.010878310538828373, 0.012541105970740318, 0.009680655784904957, 0.014786235056817532, 0.01098882406949997, 0.014351315796375275]1560480983[3][6][0.5473132729530334][0.07873011380434036, 0.02117929421365261, 0.00462101586163044, 0.007679674308747053, 0.006423152983188629, 0.003737745573744178, 0.5473132729530334, 0.010648651979863644, 0.2306162267923355, 0.033958908170461655, 0.009718521498143673, 0.03945154696702957, 0.0667884573340416, 0.010746568441390991, 0.008459050208330154, 0.012853718362748623, 0.006122407037764788, 0.005631749518215656, 0.006334631238132715, 0.01488021295517683, 0.08340618759393692]

demo:

目的:计算标签的整体acc、precision、recall。

如果想计算某一个类别的precision和recall,则在评价函数中加上这个参数:pos_label = [4],这里的4表示索引的第4列。

def calculate_acc_multi_label(read_path, sheet_name): workbook = xlrd.open_workbook(read_path) # 打开工作簿 sheets = workbook.sheet_names() # 获取工作簿中的所有表格 worksheet = workbook.sheet_by_name(sheets[0]) # 获取工作簿中所有表格中的的第一个表格 print(worksheet.nrows) print(worksheet.ncols) true_label = [] pred_label = [] for i in range(1, 501): label_str = worksheet.cell_value(i, 1) label = [0 for x in range(0, 21)] label_str = label_str[1:-1] label_list = label_str.split(',') for j in label_list: label[int(j)] = 1 true_label.append(label) pred_list = worksheet.cell_value(i, 2) pred_lab = [0 for x in range(0, 21)] # print('--length of pred: ', len(pred_list)) pred_list = pred_list[1:-1] print('---index: {0} pred_list {1}: '.format(i, pred_list)) if pred_list != '': pred_list = pred_list.split(',') for g in pred_list: pred_lab[int(g)] = 1 pred_label.append(pred_lab) acc = metrics.accuracy_score(true_label, pred_label) print('--acc:', acc) # acc_list = hamming_score(true_label, pred_label) # hamming = np.mean(acc_list) # print('--hamming:', hamming) precision = metrics.precision_score(true_label, pred_label, average='micro') print('--precision:', precision) recall = metrics.recall_score(true_label, pred_label, average='micro') print('--recall:', recall) f1 = metrics.f1_score(np.array(true_label), np.array(pred_label), average='micro') print('--f1:', f1) mcm = metrics.multilabel_confusion_matrix(true_label, pred_label) tn = mcm[:, 0, 0] tp = mcm[:, 1, 1] fn = mcm[:, 1, 0] fp = mcm[:, 0, 1] print('tp: {0} fn: {1} fp: {2}'.format(tp, fn, fp)) sum_tp = sum(tp) sum_fn = sum(fn) sum_fp = sum(fp) print('sum_tp: {0} sum_fn: {1} sum_fp: {2}'.format(sum_tp, sum_fn, sum_fp)) recall_list = tp / (tp + fn) print('--recall_list', recall_list) precision_list = tp / (tp + fp) print('--precision_list', precision_list) print('--precision_list length', len(precision_list)) print('---mcm :', mcm) if __name__ == '__main__': save_path = './multi_label_all_0.5_2.xlsx' sheet_name = 'predict' calculate_acc_multi_label(save_path, sheet_name)

这里打印了多标签的混淆矩阵,用来验证acc、precision、recall是怎么计算得到的,运行后,返回结果如下:

--acc: 0.704 --precision: 0.7960396039603961 --recall: 0.7052631578947368 --f1: 0.747906976744186 tp: [ 2 2 0 393 4 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0] fn: [ 0 14 2 107 6 2 0 13 16 0 5 1 0 0 1 0 1 0 0 0 0] fp: [45 18 4 0 13 0 13 1 2 5 0 1 0 0 0 1 0 0 0 0 0] sum_tp: 402 sum_fn: 168 sum_fp: 103 --recall_list [1. 0.125 0. 0.786 0.4 0. nan 0.07142857 0. nan 0. 0. nan nan 0. nan 0. nan nan nan nan] --precision_list [0.04255319 0.1 0. 1. 0.23529412 nan 0. 0.5 0. 0. nan 0. nan nan nan 0. nan nan nan nan nan] --precision_list length 21 ---mcm : [[[453 45] [ 0 2]] [[466 18] [ 14 2]] [[494 4] [ 2 0]] [[ 0 0] [107 393]] [[477 13] [ 6 4]] [[498 0] [ 2 0]] [[487 13] [ 0 0]] [[485 1] [ 13 1]] [[482 2] [ 16 0]] [[495 5] [ 0 0]] [[495 0] [ 5 0]] [[498 1] [ 1 0]] [[500 0] [ 0 0]] [[500 0] [ 0 0]] [[499 0] [ 1 0]] [[499 1] [ 0 0]] [[499 0] [ 1 0]] [[500 0] [ 0 0]] [[500 0] [ 0 0]] [[500 0] [ 0 0]] [[500 0] [ 0 0]]]

可以用过混淆矩阵计算acc 、precision、recall等指标。

 



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有