VGG16网络的参数和FLOPs计算 您所在的位置:网站首页 模型的参数怎么算 VGG16网络的参数和FLOPs计算

VGG16网络的参数和FLOPs计算

2024-02-29 15:06| 来源: 网络整理| 查看: 265

VGG16网络的参数和FLOPs计算

1、前言

VGG16是一种卷积神经网络(Convolutional Neural Network,CNN),是由Simonyan和Zisserman于2014年提出的用于图像识别任务的其中一种网络结构。VGG16网络包含13个卷积层和3个全连接层,在ImageNet数据集上表现出色,并成为后来CNN结构的基础。下面对VGG16网络的参数和FLOPs进行计算。

在这里插入图片描述

2、计算 1、模型的参数量

1.1、卷积层计算 模型的参数量=[卷积核的长∗卷积核的宽∗卷积核的高(即通道,由上一层的输出通道决定)]∗卷积核的数量+偏置参数(其等于卷积核的数量)以第一个卷积层为例Conv1_1=kernel(Height)*kernel(Width)InPut(Channel)OutPut(Channel)=(333)*64=1728

1.2、全连接层 模型的参数量=上一层输入的长宽高(即通道)本层的长宽*高(即通道)以第一个全连接层为例: FC1[(LayerID21,Output(Height))∗(LayerID21,Output(Width))∗(LayerID21,Output(Channel))]∗[(LayerID22,Output(Height))∗(LayerID22,Output(Width))∗(LayerID22,Output(Channel))]=[(7∗7∗512)]∗[1∗1∗4096]=102,760,448

2、模型FLOPs

2.1、卷积层计算 FLOPS数量=参数量∗该层输出特征图的大小该层输出特征图的大小:以第一个卷积层为例FLOPS=OutPut(Height)OutPut(Weight)Params=2242241728=86704128

2.2、全连接层 由于不存在权值共享,它的FLOPs数目即是该层参数数目: 以第一个全连接层为例 FLOPs=Params=102760448

3、总参数量 138,357,544,总FLOPs=15470314496

3、EXCEL详细结果

| | |
|--|--|
| | |

4、代码 import torch import torch.nn as nn from torchvision.models import vgg16 def count_parameters(model): return sum(p.numel() for p in model.parameters() if p.requires_grad) def count_flops(model, input_size): flops = 0 input = torch.randn(1, *input_size) def conv_hook(module, input, output): nonlocal flops batch_size, input_channels, input_height, input_width = input[0].size() output_channels, output_height, output_width = output[0].size() kernel_height, kernel_width = module.kernel_size flops += batch_size * output_channels * output_height * output_width * ( input_channels * kernel_height * kernel_width + 1) def fc_hook(module, input, output): nonlocal flops batch_size, input_features = input[0].size() output_features = output[0].size(0) # 修改这里 flops += batch_size * input_features * output_features hooks = [] for name, module in model.named_modules(): if isinstance(module, nn.Conv2d): hooks.append(module.register_forward_hook(conv_hook)) elif isinstance(module, nn.Linear): hooks.append(module.register_forward_hook(fc_hook)) model(input) for hook in hooks: hook.remove() return flops model = vgg16() params = count_parameters(model) flops = count_flops(model, (3, 224, 224)) print(f"Parameters: {params}") print(f"FLOPs: {flops}")

运行结果如下 在这里插入图片描述



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有