深度学习12 您所在的位置:网站首页 matlab导入数据集代码 深度学习12

深度学习12

2023-06-10 17:08| 来源: 网络整理| 查看: 265

目录

VGG19实现

1.为数据打标签的generate_txt.py

2.对图像进行预处理的data_process.py

3.VGG19的网络构建代码net_VGG19.py

4.训练得到pth模型参数文件的get_pth_file.py

5.预测代码predict.py

6.预测VGG16与VGG19结果对比

VGG19实现 1.为数据打标签的generate_txt.py

这里的程序设计思想还是可以学一下的。

import os from os import getcwd # 文件夹操作 # 写入数据集对应的文件夹 classes = ['cat','dog'] sets = ['train'] # 主程序执行 if __name__ == "__main__": wd = getcwd() # 获取当前工作目录 # 说明:当前代码的目录关系是:sets-->subset(当前只有一个“train”)-->type_name(“train”下的关于类别的文件夹,比如“dog”)-->具体样本数据 # 遍历sets中的每个文件夹,当前sets中只有"train"一个文件夹 for subset in sets: list_file = open("cls_"+subset+".txt",'w') # 拿到每个子文件夹的目录 path_subset = subset type_names = os.listdir(path_subset) # 拿到subset文件夹下的所有动物分类文件夹type_name,存到type_names列表中 # 遍历subset中的每个文件夹type_name """ 它遍历名为type_names的列表中的每个元素。 代码的目的是检查每个元素是否存在于名为classes的集合中。 如果存在,代码会继续执行下一次循环,处理下一个元素; 如果不存在,代码会跳过当前循环并继续执行下一次循环。 """ for type_name in type_names: if type_name not in classes: continue # 打标签 type_id = classes.index(type_name) # 按type_name文件夹在classes文件夹中的索引,为type_name编号 # 生成每个type_name文件夹的路径 type_path = os.path.join(path_subset,type_name) photo_names = os.listdir(type_path) # 拿到type_name文件夹下的所有图片,组成一个列表phto_names # 处理每一张图片 """ 这段代码的作用是遍历名为photos_name的列表中的每个元素, 并根据文件名的扩展名来过滤文件。代码会判断文件的扩展名是否为.jpg、.png或.jpeg, 如果不是这些扩展名之一,则跳过当前文件的处理。对于符合条件的文件, 代码会将其写入到名为list_file的文件中,并写入文件的类别ID和路径信息。 """ for photo_name in photo_names: """ 这一行代码使用os.path.splitext()函数将文件名photo_name分成文件名部分和扩展名部分, 并将扩展名赋值给变量postfix。下划线_表示不使用文件名部分,只关注扩展名。 """ _,postfit = os.path.splitext(photo_name) # #该函数用于分离文件名与拓展名 # 如果拓展名不在如下的列表中,则跳过当前循环;如果在,则继续 if postfit not in ['.jpg','.png','.jpeg']: continue # 将文件的类别ID和完整路径信息写入到名为list_file的文件中 photo_path = os.path.join(type_path,photo_name) # print(wd) # C:\Users\ZARD\PycharmProjects\pythonProject\AAA_FX\revise_VGG19 # 如上可知,wd为该项目的路径 list_file.write(str(type_id)+';'+'%s/%s'%(wd,photo_path)) list_file.write('\n') # 这一行代码写入一个换行符,将下一个文件的记录写入到新的一行 list_file.close() 2.对图像进行预处理的data_process.py

对数据做一些基本操作,可根据实际需求进行更改。

import cv2 import numpy as np import torch.utils.data as data from PIL import Image def preprocess_input(x): x/=127.5 x-=1. return x def cvtColor(image): if len(np.shape(image))==3 and np.shape(image)[-2]==3: return image else: image=image.convert('RGB') return image class DataGenerator(data.Dataset): def __init__(self,annotation_lines,inpt_shape,random=True): self.annotation_lines=annotation_lines self.input_shape=inpt_shape self.random=random def __len__(self): return len(self.annotation_lines) def __getitem__(self, index): annotation_path=self.annotation_lines[index].split(';')[1].split()[0] image=Image.open(annotation_path) image=self.get_random_data(image,self.input_shape,random=self.random) image=np.transpose(preprocess_input(np.array(image).astype(np.float32)),[2,0,1]) y=int(self.annotation_lines[index].split(';')[0]) return image,y def rand(self,a=0,b=1): return np.random.rand()*(b-a)+a def get_random_data(self,image,inpt_shape,jitter=.3,hue=.1,sat=1.5,val=1.5,random=True): image=cvtColor(image) iw,ih=image.size h,w=inpt_shape if not random: scale=min(w/iw,h/ih) nw=int(iw*scale) nh=int(ih*scale) dx=(w-nw)//2 dy=(h-nh)//2 image=image.resize((nw,nh),Image.BICUBIC) new_image=Image.new('RGB',(w,h),(128,128,128)) new_image.paste(image,(dx,dy)) image_data=np.array(new_image,np.float32) return image_data new_ar=w/h*self.rand(1-jitter,1+jitter)/self.rand(1-jitter,1+jitter) scale=self.rand(.75,1.25) if new_ar


【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有