修改onnx模型输出示例 您所在的位置:网站首页 onnx怎么读 修改onnx模型输出示例

修改onnx模型输出示例

2024-02-19 06:58| 来源: 网络整理| 查看: 265

前言

在这里插入图片描述 如图是netron(github链接)软件中打开的onnx模型,可以看到右边模型的最终输出结果是分类值predict_0而非概率值,那么如何获取中间过程的概率值,或者说怎么把右边的图砍掉一截变成左边的图呢?

代码

读入模型

import onnx onnx_model = onnx.load("xxx.onnx") graph = onnx_model.graph

首先以图的形式读入你的模型,图一般包括node(节点),initializer(初始化),input(输入),output(输出)四部分,全部打印出来的话非常长,在这里我们主要涉及到删除节点和修改输出两部分。

查看output 在这里插入图片描述 打印output,发现其中的主要信息为name和elem_type两部分。 使用netron选中上一层节点即identity,查看详细信息。 在这里插入图片描述 可以看到上一层节点的输出名即为output的输出名,这两个需要保持一致。 同时elem_type=7,表示输出类型为int64(参照我上一篇文章),这个也需要和上一层节点输出的类型保持一致。

删除node节点

nodes = graph.node for i in range(len(nodes)): print(i,nodes[i])

由于是修改输出,我们只关心最后的几个节点。 在这里插入图片描述 argmax函数不用说了,是经典转化为0,1输出的函数,那么要取概率值肯定在argmax之前,mul是相乘,乘什么我们需要用netron打开模型查看mul节点如下: 在这里插入图片描述 乘0,1矩阵,应该是要转为2列方便进行后续处理,那么我们只要取A的值,也就是上一层节点输出的"add_result_0"就行了。 也就是说,删除序号为261,262,263,264的节点:

graph.node.remove(nodes[264]) graph.node.remove(nodes[263]) graph.node.remove(nodes[262]) graph.node.remove(nodes[261])

倒着顺序删除,再次打印节点: 在这里插入图片描述 完成删除。

修改output

graph.output[0].name = 'add_result_0' graph.output[0].type.tensor_type.elem_type = 1

由于add_result_0输出的值为float32,因此需要在output中修改elem_type为对应的数据类型,1为float32,更多类型查看我上一篇文章。

结果比较

onnx.save(onnx_model, 'modify_xxx.onnx')

先保存模型,再调用onnx runtime进行模型调用:

import onnxruntime as rt sess = rt.InferenceSession("modify_xxx.onnx") onnx_pred = sess.run() #具体里面填什么根据你的模型填,此处为伪代码

在这里插入图片描述 和未修改之前的模型比较: 在这里插入图片描述 发现修改后的模型概率值大于0的都被分类成了1,小于0的都被分类成了0,表示我们这次修改取得了成功。



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有