【tensorflow】将训练数据转为tfrecord 您所在的位置:网站首页 csv转tfrecord 【tensorflow】将训练数据转为tfrecord

【tensorflow】将训练数据转为tfrecord

2024-07-16 01:30| 来源: 网络整理| 查看: 265

目标检测训练数据的一般包括图像和对应的标注xml文件,这里以四边形标注目标,如下:

转换为tfrecord文件 def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def read_xml_gtbox_and_label(xml_path): """ :param xml_path: the path of voc xml :return: a list contains gtboxes and labels, shape is [num_of_gtboxes, 5], and has [xmin, ymin, xmax, ymax, label] in a per row """ tree = ET.parse(xml_path) root = tree.getroot() img_width = None img_height = None box_list = [] for child_of_root in root: # if child_of_root.tag == 'filename': # assert child_of_root.text == xml_path.split('/')[-1].split('.')[0] \ # + FLAGS.img_format, 'xml_name and img_name cannot match' if child_of_root.tag == 'size': for child_item in child_of_root: if child_item.tag == 'width': img_width = int(child_item.text) if child_item.tag == 'height': img_height = int(child_item.text) if child_of_root.tag == 'object': label = None for child_item in child_of_root: if child_item.tag == 'name': category = child_item.text.encode("utf-8") #如果xml文件中目标的类别是中文,那么就需要对child_item.text进行‘utf-8’编码转换为str格式(child_item.text是Unicode格式) #category = child_item.text #如果xml文件中目标的类别是英文 label = NAME_LABEL_MAP[category] if child_item.tag == 'bndbox': tmp_box = [] for node in child_item: tmp_box.append(int(node.text)) # [x1, y1. x2, y2, x3, y3, x4, y4] assert label is not None, 'label is none, error' tmp_box.append(label) # [x1, y1. x2, y2, x3, y3, x4, y4, label] box_list.append(tmp_box) gtbox_label = np.array(box_list, dtype=np.int32) # [x1, y1. x2, y2, x3, y3, x4, y4, label] return img_height, img_width, gtbox_label def convert_pascal_to_tfrecord(): ''' 每一张样本图片可以看做是一个example,每个Example中包含features features里包含feature(这里没s)的字典,feature分为FloatList,或ByteList,或Int64List 的格式 例如该例子中,首先利用tf.train.Features函数来创建每一个样本的features features中包括样本的名称(img_name)、高度(img_height)等字典信息,这些字典信息要利用tf.train.Feature函数创建 例如样本名称是二进制的格式,因此在开头创建了_bytes_feature函数,其中调用tf.train.Feature函数,并设置为bytes_list 高度是int64形式,因此创建了_int64_feature函数,其中调用tf.train.Feature函数,并设置为Int64List 最后再利用tf.train.Example函数,将上述的features赋给Example ''' xml_path = FLAGS.VOC_dir + FLAGS.xml_dir image_path = FLAGS.VOC_dir + FLAGS.image_dir save_path = FLAGS.save_dir + FLAGS.dataset + '_' + FLAGS.save_name + '.tfrecord' #os.mkdir(FLAGS.save_dir) #writer_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB) #定义了tfrecords文件压缩类型:无,ZLIB,GZIP三种 #writer = tf.python_io.TFRecordWriter(path=save_path, options=writer_options) #建立TFRecord存储器,path是TFRecords文件的路径 writer = tf.python_io.TFRecordWriter(path=save_path) #可以用该行代码代替前两个 for count, xml in enumerate(glob.glob(xml_path + '/*.xml')): # to avoid path error in different development platform xml = xml.replace('\\', '/') img_name = xml.split('/')[-1].split('.')[0] + FLAGS.img_format img_path = image_path + '/' + img_name if not os.path.exists(img_path): print('{} is not exist!'.format(img_path)) continue img_height, img_width, gtbox_label = read_xml_gtbox_and_label(xml) # img = np.array(Image.open(img_path)) img = cv2.imread(img_path) feature = tf.train.Features(feature={ # maybe do not need encode() in linux # 'img_name': _bytes_feature(img_name.encode()), 'img_name': _bytes_feature(img_name), 'img_height': _int64_feature(img_height), 'img_width': _int64_feature(img_width), 'img': _bytes_feature(img.tostring()), # 'gtboxes_and_label': _bytes_feature(gtbox_label.tostring()), 'num_objects': _int64_feature(gtbox_label.shape[0]) }) example = tf.train.Example(features=feature) writer.write(example.SerializeToString()) #把example序列化为一个字符串,因为在写入到TFRcorde的时候,write方法的参数是字符串 view_bar('Conversion progress', count + 1, len(glob.glob(xml_path + '/*.xml'))) print('\nConversion is complete!') 检查tfrecord文件是否有问题 import os import tensorflow as tf import sys stdi, stdo, stde = sys.stdin, sys.stdout, sys.stderr reload(sys) sys.setdefaultencoding('utf-8') sys.stdin, sys.stdout, sys.stderr = stdi, stdo, stde def read_single_example_and_decode(filename_queue): #如果你在上面转换的代码中采用了前面两行,那么相应的就采用下面这两行 #tfrecord_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB) #reader = tf.TFRecordReader(options=tfrecord_options) #构造阅读器 #否则采用: reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) #返回文件名和文件 #解析协议块,返回的值是字典 features = tf.parse_single_example( serialized=serialized_example, features={ 'img_name': tf.FixedLenFeature([], tf.string), 'img_height': tf.FixedLenFeature([], tf.int64), 'img_width': tf.FixedLenFeature([], tf.int64), 'img': tf.FixedLenFeature([], tf.string), 'gtboxes_and_label': tf.FixedLenFeature([], tf.string), 'num_objects': tf.FixedLenFeature([], tf.int64) } ) img_name = features['img_name'] img_height = tf.cast(features['img_height'], tf.int32) #将数据类型int64 转换为int32 img_width = tf.cast(features['img_width'], tf.int32) #将数据类型int64 转换为int32 img = tf.decode_raw(features['img'], tf.uint8) ##如果之前用了tostring(),那么必须要用decode_raw()转换为最初的int类型,decode_raw()可以将数据从string,bytes转换为int,float类型的 img = tf.reshape(img, shape=[img_height, img_width, 3]) ##转换图片的形状,此处需要用动态形状进行转换 gtboxes_and_label = tf.decode_raw(features['gtboxes_and_label'], tf.int32) gtboxes_and_label = tf.reshape(gtboxes_and_label, [-1, 9]) num_objects = tf.cast(features['num_objects'], tf.int32) return img_name, img, gtboxes_and_label, num_objects directory = os.path.join('/home/yantianwang/rdfpn/data/tfrecord', 'hangtian_ship_train.tfrecord') if not os.path.exists(directory): print('不存在') filename_tensorlist = tf.train.match_filenames_once(directory) # 获取文件列表 filename_queue = tf.train.string_input_producer(filename_tensorlist)# 创建文件输入队列 img_name, img, gtboxes_and_label, num_objects = read_single_example_and_decode(filename_queue) #解析数据 img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch = tf.train.batch( [img_name, img, gtboxes_and_label, num_objects], batch_size = 1, capacity=100, num_threads=16, dynamic_pad=True) init = (tf.global_variables_initializer(), tf.local_variables_initializer()) with tf.Session() as sess: sess.run(init) coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess, coord) for step in range(10000): print(step,sess.run(img_name_batch)) coord.request_stop() coord.join(threads)

如果测试的代码运行无误,那么就说明tfrecord文件没有问题。如果出现问题:

PaddingFIFOQueue '_2_batch/padding_fifo_queue' is closed and has insufficient elements (requested 1, current size 0)

 那么和那程度上说明的你准备的数据有问题,需要检查一下样本和相应的xml文件有无问题,比如xml文件中记录的图像长宽与图像不一致、目标的标注超过了图像的范围等等....

读取tfrecord文件生成batch import tensorflow as tf import os from data.io import image_preprocess def read_single_example_and_decode(filename_queue): #tfrecord_options = tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB) #reader = tf.TFRecordReader(options=tfrecord_options) #构造阅读器 reader = tf.TFRecordReader() _, serialized_example = reader.read(filename_queue) #返回文件名和文件 #解析协议块,返回的值是字典 features = tf.parse_single_example( serialized=serialized_example, features={ 'img_name': tf.FixedLenFeature([], tf.string), 'img_height': tf.FixedLenFeature([], tf.int64), 'img_width': tf.FixedLenFeature([], tf.int64), 'img': tf.FixedLenFeature([], tf.string), 'gtboxes_and_label': tf.FixedLenFeature([], tf.string), 'num_objects': tf.FixedLenFeature([], tf.int64) } ) img_name = features['img_name'] img_height = tf.cast(features['img_height'], tf.int32) #将数据类型int64 转换为int32 img_width = tf.cast(features['img_width'], tf.int32) #将数据类型int64 转换为int32 img = tf.decode_raw(features['img'], tf.uint8) ##如果之前用了tostring(),那么必须要用decode_raw()转换为最初的int类型,decode_raw()可以将数据从string,bytes转换为int,float类型的 img = tf.reshape(img, shape=[img_height, img_width, 3]) ##转换图片的形状,此处需要用动态形状进行转换 gtboxes_and_label = tf.decode_raw(features['gtboxes_and_label'], tf.int32) gtboxes_and_label = tf.reshape(gtboxes_and_label, [-1, 9]) num_objects = tf.cast(features['num_objects'], tf.int32) return img_name, img, gtboxes_and_label, num_objects def read_and_prepocess_single_img(filename_queue, shortside_len, is_training): img_name, img, gtboxes_and_label, num_objects = read_single_example_and_decode(filename_queue) # img = tf.image.per_image_standardization(img) img = tf.cast(img, tf.float32) img = img - tf.constant([103.939, 116.779, 123.68]) if is_training: img, gtboxes_and_label = image_preprocess.short_side_resize(img_tensor=img, gtboxes_and_label=gtboxes_and_label, target_shortside_len=shortside_len) img, gtboxes_and_label = image_preprocess.random_flip_left_right(img_tensor=img, gtboxes_and_label=gtboxes_and_label) else: img, gtboxes_and_label = image_preprocess.short_side_resize(img_tensor=img, gtboxes_and_label=gtboxes_and_label, target_shortside_len=shortside_len) return img_name, img, gtboxes_and_label, num_objects def next_batch(dataset_name, batch_size, shortside_len, is_training): if dataset_name not in ['ship', 'spacenet', 'pascal', 'coco','hangtian_ship']: #增加自己的数据库名称 raise ValueError('dataSet name must be in pascal or coco') if is_training: pattern = os.path.join('../data/tfrecord', dataset_name + '_train*') else: pattern = os.path.join('../data/tfrecord', dataset_name + '_test*') print('tfrecord path is -->', os.path.abspath(pattern)) filename_tensorlist = tf.train.match_filenames_once(pattern) filename_queue = tf.train.string_input_producer(filename_tensorlist) img_name, img, gtboxes_and_label, num_obs = read_and_prepocess_single_img(filename_queue, shortside_len, is_training=is_training) img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch = \ tf.train.batch( [img_name, img, gtboxes_and_label, num_obs], batch_size=batch_size, capacity=100, num_threads=16, dynamic_pad=True) return img_name_batch, img_batch, gtboxes_and_label_batch, num_obs_batch

该部分代码包括了对数据的处理



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有