结合ModelArts打造基于Yolov5的垃圾分类识别模型 | 您所在的位置:网站首页 › yolov5垃圾分类数据集 › 结合ModelArts打造基于Yolov5的垃圾分类识别模型 |
2020年6月参加了华为云AI算法大赛--垃圾分类识别大赛,虽然没有进入决赛,但是还是想将ModelArts在构建的模型经验分享出来,现在的开发的主流思想就是开源为王。yolov5是yolo系列的最新版本的目标检测模型,较之前的yolov3、yolov4的版本有很大的提升,官方给出的效果提升如下: 一、数据预处理本次采用的数据集是官方提供的VOC2007格式的垃圾分类数据集。 由于yolov5需要的训练数据格式如下: 每个图像的标签文件应该可以通过在其路径名中简单地替换/images/*.jpg为来定位/labels/*.txt。示例图像和标签对为: 因此数据处理脚本如下: # 运行成功后会生成如下目录结构的文件夹: # trainval/ # -images # -0001.jpg # -0002.jpg # -0003.jpg # -labels # -0001.txt # -0002.txt # -0003.txt # 将trainval文件夹打包并命名为trainval.zip, 上传到OBS中以备使用。 import os import codecs import xml.etree.ElementTree as ET from tqdm import tqdm import shutil import argparse def get_classes(classes_path): '''loads the classes''' with codecs.open(classes_path, 'r', 'utf-8') as f: class_names = f.readlines() class_names = [c.strip() for c in class_names] return class_names def convert(size, box): w = size[0] h = size[1] w_center = round((box[2] - box[0]) / w, 6) h_center = round((box[3] - box[1]) / h, 6) x_center = round(((box[2] + box[0]) / 2) / w, 6) y_center = round(((box[3] + box[1]) / 2) / h, 6) # dw = 1./(size[0]) # dh = 1./(size[1]) # x = (box[0] + box[2])/2.0 -1 # y = (box[1] + box[3])/2.0 -1 # w = box[2] - box[0] # h = box[3] - box[1] # x = x*dw # w = w*dw # y = y*dh # h = h*dh return x_center, y_center, w_center, h_center #return x_center, y_center, w_center, h def creat_label_txt(soure_datasets, new_datasets): annotations = os.path.join(soure_datasets, 'VOC2007/Annotations') txt_path = os.path.join(new_datasets, 'labels') class_names = get_classes(os.path.join(soure_datasets, 'train_classes.txt')) xmls = os.listdir(annotations) for xml in tqdm(xmls): txt_anno_path = os.path.join(txt_path, xml.replace('xml', 'txt')) xml = os.path.join(annotations, xml) tree = ET.parse(xml) root = tree.getroot() size = root.find('size') w = int(size.find('width').text) h = int(size.find('height').text) line = '' for obj in root.iter('object'): cls = obj.find('name').text if cls not in class_names: print('name error', xml) continue cls_id = class_names.index(cls) xmlbox = obj.find('bndbox') box = [int(xmlbox.find('xmin').text), int(xmlbox.find('ymin').text), int(xmlbox.find('xmax').text), int(xmlbox.find('ymax').text)] if box[2] > w or box[3] > h: print('Image with annotation error:', xml) if box[0] |
CopyRight 2018-2019 实验室设备网 版权所有 |