Tensorflow(3):创建画板,实时在线手写体识别 您所在的位置:网站首页 在线手写体识别 Tensorflow(3):创建画板,实时在线手写体识别

Tensorflow(3):创建画板,实时在线手写体识别

2022-05-07 12:17| 来源: 网络整理| 查看: 265

   学习Tensorflow已经有一段时间了,就想能不能实现手写体的实时在线识别,于是进行了一番探索。本文源代码可以在这里下载【Python3+PyQt5+Tensorflow】创建画板,实时在线手写体识别】    用到的库:Python3.6.5 + PyQt5 + PIL,编写GUI程序,实现手写体实时在线识别。最终实现的效果如下图所示,在方框里用鼠标手写数字,左下角显示识别结果,准确率可以达到99.2%。

1.画板GUI及模型加载(MyMnistWindow.py)

  使用PyQt5制作了一个交互式画板,可以用鼠标在上面写字。画板的程序部分参考了【Python3使用PyQt5制作简单的画板/手写板】。

''' 功能: 利用训练好的模型,进行实时手写体识别 作者:yuhansgg 博客: https://blog.csdn.net/u011389706 日期: 2018/08/06 ''' import tensorflow as tf from PyQt5.QtWidgets import (QWidget, QPushButton, QLabel) from PyQt5.QtGui import (QPainter, QPen, QFont) from PyQt5.QtCore import Qt from PIL import ImageGrab, Image class MyMnistWindow(QWidget): def __init__(self): super(MyMnistWindow, self).__init__() self.resize(284, 330) # resize设置宽高 self.move(100, 100) # move设置位置 self.setWindowFlags(Qt.FramelessWindowHint) # 窗体无边框 #setMouseTracking设置为False,否则不按下鼠标时也会跟踪鼠标事件 self.setMouseTracking(False) self.pos_xy = [] #保存鼠标移动过的点 # 添加一系列控件 self.label_draw = QLabel('', self) self.label_draw.setGeometry(2, 2, 280, 280) self.label_draw.setStyleSheet("QLabel{border:1px solid black;}") self.label_draw.setAlignment(Qt.AlignCenter) self.label_result_name = QLabel('识别结果:', self) self.label_result_name.setGeometry(2, 290, 60, 35) self.label_result_name.setAlignment(Qt.AlignCenter) self.label_result = QLabel(' ', self) self.label_result.setGeometry(64, 290, 35, 35) self.label_result.setFont(QFont("Roman times", 8, QFont.Bold)) self.label_result.setStyleSheet("QLabel{border:1px solid black;}") self.label_result.setAlignment(Qt.AlignCenter) self.btn_recognize = QPushButton("识别", self) self.btn_recognize.setGeometry(110, 290, 50, 35) self.btn_recognize.clicked.connect(self.btn_recognize_on_clicked) self.btn_clear = QPushButton("清空", self) self.btn_clear.setGeometry(170, 290, 50, 35) self.btn_clear.clicked.connect(self.btn_clear_on_clicked) self.btn_close = QPushButton("关闭", self) self.btn_close.setGeometry(230, 290, 50, 35) self.btn_close.clicked.connect(self.btn_close_on_clicked) def paintEvent(self, event): painter = QPainter() painter.begin(self) pen = QPen(Qt.black, 30, Qt.SolidLine) painter.setPen(pen) if len(self.pos_xy) > 1: point_start = self.pos_xy[0] for pos_tmp in self.pos_xy: point_end = pos_tmp if point_end == (-1, -1): point_start = (-1, -1) continue if point_start == (-1, -1): point_start = point_end continue painter.drawLine(point_start[0], point_start[1], point_end[0], point_end[1]) point_start = point_end painter.end() def mouseMoveEvent(self, event): ''' 按住鼠标移动事件:将当前点添加到pos_xy列表中 ''' #中间变量pos_tmp提取当前点 pos_tmp = (event.pos().x(), event.pos().y()) #pos_tmp添加到self.pos_xy中 self.pos_xy.append(pos_tmp) self.update() def mouseReleaseEvent(self, event): ''' 重写鼠标按住后松开的事件 在每次松开后向pos_xy列表中添加一个断点(-1, -1) ''' pos_test = (-1, -1) self.pos_xy.append(pos_test) self.update() def btn_recognize_on_clicked(self): bbox = (104, 104, 380, 380) im = ImageGrab.grab(bbox) # 截屏,手写数字部分 im = im.resize((28, 28), Image.ANTIALIAS) # 将截图转换成 28 * 28 像素 recognize_result = self.recognize_img(im) # 调用识别函数 self.label_result.setText(str(recognize_result)) # 显示识别结果 self.update() def btn_clear_on_clicked(self): self.pos_xy = [] self.label_result.setText('') self.update() def btn_close_on_clicked(self): self.close() def recognize_img(self, img): # 手写体识别函数 myimage = img.convert('L') # 转换成灰度图 tv = list(myimage.getdata()) # 获取图片像素值 tva = [(255 - x) * 1.0 / 255.0 for x in tv] # 转换像素范围到[0 1], 0是纯白 1是纯黑 init = tf.global_variables_initializer() saver = tf.train.Saver with tf.Session() as sess: sess.run(init) saver = tf.train.import_meta_graph('minst_cnn_model.ckpta') # 载入模型结构 saver.restore(sess, 'minst_cnn_model.ckpt') # 载入模型参数 graph = tf.get_default_graph() # 加载计算图 x = graph.get_tensor_by_name("x:0") # 从模型中读取占位符变量 keep_prob = graph.get_tensor_by_name("keep_prob:0") y_conv = graph.get_tensor_by_name("y_conv:0") # 关键的一句 从模型中读取占位符变量 prediction = tf.argmax(y_conv, 1) predint = prediction.eval(feed_dict={x: [tva], keep_prob: 1.0}, session=sess) # feed_dict输入数据给placeholder占位符 print(predint[0]) return predint[0]

   识别时,先利用函数ImageGrab.grab(bbox),对屏幕画板部分进行截图。然后对截图进行预处理(缩放到28*28像素,转换成灰度图等)。    在最后,最重要的手写体识别函数里recognize_img(self, img),我们调用了已经训练好的模型minst_cnn_model.ckpt,具体模型训练过程,参见【Tensorflow(2):MNIST识别自己手写的数字–进阶篇(CNN)】

2.主程序(main.py)

   实例化我们上面定义的窗体类MyMnistWindow,实现窗体显示。

import sys from PyQt5.QtWidgets import QApplication from MyMnistWindow import MyMnistWindow if __name__ == "__main__": app = QApplication(sys.argv) mymnist = MyMnistWindow() mymnist.show() app.exec_() 3.实验结果

   最终测试结果如下所示,左下角显示识别结果。可以看到,基本都能正确识别:

4.注意事项

    若识别准确率不高,一般是由于我们手写数字和训练数据相差太大导致的。     一般原因是:训练集是西方的手写数字,和中国的手写数字习惯不同。下面是官方的训练数据中的部分数字。

    在画图时,笔法尽量和上面训练集保持一致,就会得到较高的识别率!。     本文源代码可以在这里下载【Python3+PyQt5+Tensorflow】创建画板,实时在线手写体识别】     是以为记!



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有