【学习笔记】VGG 网络结构

您所在的位置:网站首页 卷积神经网络的基本结构的图表 【学习笔记】VGG 网络结构

【学习笔记】VGG 网络结构

2024-07-10 12:06:38| 来源: 网络整理| 查看: 265

跟着大佬学图像分类系列,→ 传送门 ← 本博客图像分类系列文章传送门:

AlexNet VGG(当前) GoogleNet ResNet 前言

图像分类是学习目标检测的“量变”内容,那么,废话不多说,开搞!

一、VGG 是什么?

        VGG 网络是14年被牛津大学的著名研究组 VGG(Visual Geometry Group)提出,斩获该年 ImageNet 竞赛中 Localization Task(定位任务)第一名和 Classification Task(分类任务)第二名。

二、网络结构 (VGG网络论文中提供的6种网络配置) 1.网络特点 通过堆叠多个 3*3 的卷积核来代替大尺度卷积核(减少所需的参数)     论文提出可以通过堆叠两个 3*3 的卷积核来代替 5*5 的卷积核;堆叠三个 3*3 的卷积核来代替 7*7 的卷积核。虽然用了小的卷积核来替换大的卷积核,但并不会影响感受野,即感受野是相同的。 2.感受野(拓展)

        在卷积神经网络中,决定某一层输出结果中一个元素所对应的输入层的区域大小,被称为感受野。通俗来说就是输出的 feature map 上的一个单元对应输入层上的区域大小。

        如上图所示,自下向上,输入一个 9*9*1 的特征图,经过卷积层 Conv1,得到 4*4*1 大小的第一个输出层,再经过池化层 MaxPool1,得到 2*2*1 大小的输出层,。那么第二个输出层的一个单元(绿色方块)的感受野就是 2*2 大小的区域;第一层输出层的一个单元(蓝色方块)的感受野就是 5*5 大小的区域。

感受野计算公式: F ( i ) F(i) F(i) = ( F ( i + 1 ) − 1 ) ∗ S t r i d e + K s i z e ( F(i + 1) - 1 )*Stride + Ksize (F(i+1)−1)∗Stride+Ksize 式中, F ( i ) F(i) F(i) 为第 i 层感受野,Stride 为第 i 层的步距,Ksize 为卷积核或池化核的尺寸 以上图为例:     Feature Map: F ( 3 ) = 1 F(3) = 1 F(3)=1(因为上面没有了,所以是1个单元格)     Pool1:   F ( 2 ) F(2) F(2) = ( F ( 3 ) − 1 ) ∗ 2 + 2 = 2 (F(3) - 1 )*2+ 2 = 2 (F(3)−1)∗2+2=2     Conv1: F ( 1 ) F(1) F(1) = ( F ( 2 ) − 1 ) ∗ 2 + 3 = 5 ( F(2) - 1 )*2+ 3 = 5 (F(2)−1)∗2+3=5

那么放在 VGG中就是:     Feature Map: F = 1 F = 1 F=1(顶层)     Conv3:   F F F = ( 1 − 1 ) ∗ 1 + 3 = 3 (1 - 1 )*1+ 3 = 3 (1−1)∗1+3=3(VGG的卷积核默认步长为1,大小为 3*3)     Conv2: F ( 1 ) F(1) F(1) = ( 3 − 1 ) ∗ 1 + 3 = 5 (3 - 1 )*1+ 3 = 5 (3−1)∗1+3=5 (所以堆叠两层卷积核,感受野与一个 5*5 大小的卷积核是一样的)     Conv1: F ( 1 ) F(1) F(1) = ( 5 − 1 ) ∗ 1 + 3 = 7 (5 - 1 )*1+ 3 = 7 (5−1)∗1+3=7 (堆叠三层卷积核,感受野与一个 7*7 大小的卷积核是一样的)

3.结构 在前面提到的“VGG网络论文中提供的6种网络配置”中,配置D是常用的结构(VGG16),因此这里也主要分析 VGG16 的结构。(该结构中使用的所有卷积核步长均为1,padding 均为1;池化核大小均为2,步长为2) 在这里插入图片描述 numberInput_sizeoutput_sizekernelskernels_sizeConv1[224, 224, 3][224, 224, 64]643Conv2[224, 224, 64][224, 224, 64]643MaxPooling1[224, 224, 64][112, 112, 64]\2Conv3[112, 112, 64][112, 112, 128]1283Conv4[112, 112, 128][112, 112, 128]1283MaxPooling2[112, 112, 128][56, 56, 128]\2Conv5[56, 56, 128][56, 56, 256]2563Conv6[56, 56, 256][56, 56, 256]2563Conv7[56, 56, 256][56, 56, 256]2563MaxPooling3[56, 56, 256][28, 28, 256]\2Conv8[28, 28, 256][28, 28, 512]5123Conv9[28, 28, 512][28, 28, 512]5123Conv10[28, 28, 512][28, 28, 512]5123MaxPooling4[28, 28, 512][14, 14, 512]\2Conv11[14, 14, 512][14, 14, 512]5123Conv12[14, 14, 512][14, 14, 512]5123Conv13[14, 14, 512][14, 14, 512]5123MaxPooling5[28, 28, 512][7, 7, 512]\2FC17*7*512(展平)\\4096FC24096\\4096FC34096\\1000 三、使用 Pytorch 搭建 VGG 网络

本代码使用的数据集来自 “花分类” 数据集,→ 传送门 ←(具体内容看 data_set文件夹下的 README.md)

model.py ( 搭建 VGG 网络模型 ) import torch.nn as nn import torch class VGG(nn.Module): def __init__(self, features, class_num=1000, init_weight=False): super(VGG, self).__init__() # 卷积层和池化层,来自 make_features 生成的特征提取网络 self.features = features # 三层全连接层 self.classifier = nn.Sequential( nn.Dropout(p=0.5), nn.Linear(512*7*7, 2048), nn.ReLU(True), nn.Dropout(p=0.5), nn.Linear(2048, 2048), nn.ReLU(True), nn.Linear(2048, class_num) ) if init_weight: self._initialize_weight() # 详见 AlexNet 学习笔记 def forward(self, x): x = self.features(x) x = torch.flatten(x, start_dim=1) # 展平,进入全连接层 x = self.classifier(x) return x def _initialize_weight(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) nn.init.constant_(m.bias, 0) """ VGG网络几种不同的卷积网络配置(A,B,D,E) """ configs = { # A 数字代表卷积核的数量,'M' 表示池化层 'vgg11':[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], # B 'vgg13':[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], # D 'vgg16':[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], # E 'vgg19':[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'] } # 根据选择的网络配置,生成提取特征网络结构 def make_features(cfg: list): layers = [] in_channels = 3 # 初始输入通道(即 RGB 3通道) for v in cfg: if v == 'M': # 数组第i个元素为M,表示需要创建池化层 layers += [nn.MaxPool2d(kernel_size=2, stride=2)] # 池化核固定大小为2,步长为2 else: # 元素不为M,表示需要创建卷积层 conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) # 卷积核固定大小为3,步长为1,padding为1 layers += [conv2d, nn.ReLU(True)] # 卷积层后面会进入激活函数,这里当做一个整体放入一层 in_channels = v # 通道数(深度)变为卷积核的数量 return nn.Sequential(*layers) # *表示非关键字传入参数(Sequential见AlexNet网络学习笔记) # 默认使用 Vgg16,用户可通过传参改变网络配置 def vgg(model_name="vgg16", **kwargs): # **kwargs:可变长度字典 try: cfg = configs[model_name] except: print("Warning: Model number {} not in configs dict!".format(model_name)) exit(-1) model = VGG(make_features(cfg), **kwargs) return model train.py ( 训练网络 ) import os import json import torch import torch.nn as nn from torchvision import transforms, datasets import torch.optim as optim from tqdm import tqdm from model import vgg def main(): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print("using {} device.".format(device)) # 数据预处理 data_transform = { "train": transforms.Compose([transforms.RandomResizedCrop(224), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]), "val": transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])} # 获取数据集 data_root = os.path.abspath(os.path.join(os.getcwd(), "../..")) # get data root path image_path = os.path.join(data_root, "data_set", "flower_data") # flower data set path assert os.path.exists(image_path), "{} path does not exist.".format(image_path) train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"), transform=data_transform["train"]) train_num = len(train_dataset) # {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4} flower_list = train_dataset.class_to_idx cla_dict = dict((val, key) for key, val in flower_list.items()) # write dict into json file json_str = json.dumps(cla_dict, indent=4) with open('class_indices.json', 'w') as json_file: json_file.write(json_str) batch_size = 32 nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8]) # number of workers print('Using {} dataLoader workers every process'.format(nw)) train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=nw) validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"), transform=data_transform["val"]) val_num = len(validate_dataset) validate_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size, shuffle=False, num_workers=nw) print("using {} images for training, {} images for validation.".format(train_num, val_num)) # test_data_iter = iter(validate_loader) # test_image, test_label = test_data_iter.next() model_name = "vgg16" net = vgg(model_name=model_name, num_classes=5, init_weights=True) net.to(device) loss_function = nn.CrossEntropyLoss() optimizer = optim.Adam(net.parameters(), lr=0.0001) epochs = 30 best_acc = 0.0 save_path = './{}Net.pth'.format(model_name) train_steps = len(train_loader) for epoch in range(epochs): # train net.train() running_loss = 0.0 train_bar = tqdm(train_loader) for step, data in enumerate(train_bar): images, labels = data optimizer.zero_grad() outputs = net(images.to(device)) loss = loss_function(outputs, labels.to(device)) loss.backward() optimizer.step() # print statistics running_loss += loss.item() train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1, epochs, loss) # validate net.eval() acc = 0.0 # accumulate accurate number / epoch with torch.no_grad(): val_bar = tqdm(validate_loader) for val_data in val_bar: val_images, val_labels = val_data outputs = net(val_images.to(device)) predict_y = torch.max(outputs, dim=1)[1] acc += torch.eq(predict_y, val_labels.to(device)).sum().item() val_accurate = acc / val_num print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' % (epoch + 1, running_loss / train_steps, val_accurate)) if val_accurate > best_acc: best_acc = val_accurate torch.save(net.state_dict(), save_path) print('Finished Training') if __name__ == '__main__': main() predict.py ( 使用训练好的模型网络对图像分类 ) import os import json import torch from PIL import Image from torchvision import transforms import matplotlib.pyplot as plt from model import vgg def main(): device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") data_transform = transforms.Compose( [transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # load image img_path = "../tulip.jpg" assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path) img = Image.open(img_path) plt.imshow(img) # [N, C, H, W] img = data_transform(img) # expand batch dimension img = torch.unsqueeze(img, dim=0) # read class_indict json_path = './class_indices.json' assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path) json_file = open(json_path, "r") class_indict = json.load(json_file) # create model model = vgg(model_name="vgg16", num_classes=5).to(device) # load model weights weights_path = "./vgg16Net.pth" assert os.path.exists(weights_path), "file: '{}' dose not exist.".format(weights_path) model.load_state_dict(torch.load(weights_path, map_location=device)) model.eval() with torch.no_grad(): # predict class output = torch.squeeze(model(img.to(device))).cpu() predict = torch.softmax(output, dim=0) predict_cla = torch.argmax(predict).numpy() print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)], predict[predict_cla].numpy()) plt.title(print_res) print(print_res) plt.show() if __name__ == '__main__': main()

代码连接 https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_classification/Test3_vggnet



【本文地址】

公司简介

联系我们

今日新闻


点击排行

实验室常用的仪器、试剂和
说到实验室常用到的东西,主要就分为仪器、试剂和耗
不用再找了,全球10大实验
01、赛默飞世尔科技(热电)Thermo Fisher Scientif
三代水柜的量产巅峰T-72坦
作者:寞寒最近,西边闹腾挺大,本来小寞以为忙完这
通风柜跟实验室通风系统有
说到通风柜跟实验室通风,不少人都纠结二者到底是不
集消毒杀菌、烘干收纳为一
厨房是家里细菌较多的地方,潮湿的环境、没有完全密
实验室设备之全钢实验台如
全钢实验台是实验室家具中较为重要的家具之一,很多

推荐新闻


图片新闻

实验室药品柜的特性有哪些
实验室药品柜是实验室家具的重要组成部分之一,主要
小学科学实验中有哪些教学
计算机 计算器 一般 打孔器 打气筒 仪器车 显微镜
实验室各种仪器原理动图讲
1.紫外分光光谱UV分析原理:吸收紫外光能量,引起分
高中化学常见仪器及实验装
1、可加热仪器:2、计量仪器:(1)仪器A的名称:量
微生物操作主要设备和器具
今天盘点一下微生物操作主要设备和器具,别嫌我啰嗦
浅谈通风柜使用基本常识
 众所周知,通风柜功能中最主要的就是排气功能。在

专题文章

    CopyRight 2018-2019 实验室设备网 版权所有 win10的实时保护怎么永久关闭