Torch 您所在的位置:网站首页 tp通道数怎么定义 Torch

Torch

2023-11-21 22:02| 来源: 网络整理| 查看: 265

Torch-Pruning 通道剪枝网络实现加速的工作。 image Torch pruning是进行结构剪枝的pytorch工具箱,和pytorch官方提供的基于mask的非结构化剪枝不同,工具箱移除整个通道剪枝,自动发现层与层剪枝的依赖关系,可以处理Densenet、ResNet和DeepLab

特性

卷积网络通道剪枝 CNNs (e.g. ResNet, DenseNet, Deeplab) 和 Transformers (即 Bert, @horseee贡献代码)

网络图跟踪以及依赖关系. 支持网络层: Conv, Linear, BatchNorm, LayerNorm, Transposed Conv, PReLU, Embedding 和 扩展层. 支持操作: split, concatenation, skip connection, flatten, 等等. 剪枝策略: Random, L1, L2, 等等. 它是怎样工作的

Torch-Pruning 使用 fake inputs输入网络和torch.jit一样收集网络信息.

dependency graph 用来表示计算图和层之间的关系. 由于裁剪一层会影响若干层 , dependecy会自动传播剪枝到其他层并且保存在PruningPlan.

如果模型中有 torch.split或者torch.cat,所有剪枝的indices都会做一些变换的

Conv-Conv:\(n_{i+1}\) oc中减少1个通道,下一个卷积每个通oc通道中ic通道\(n_{i+1}\)少一个 Skip Connection: 需要考虑ic和上一层的oc互相关联,所以这里shortcut和add都需要传递这种关联。

依赖关系 可视化 例子 Conv-Conv image AlexNet Conv-FC(Global Pooling or Flatten) image ResNet,VGG Skip Connection image ResNet Concatenation image DenseNet, ASPP Split image torch.chunk 一个例子

先来看下torchpruning 的流程图: image

# 1. setup strategy (L1 Norm) strategy = tp.strategy.L1Strategy() # or tp.strategy.RandomStrategy() # 2. build layer dependency for resnet18 DG = tp.DependencyGraph() DG.build_dependency(model, example_inputs=torch.randn(1,3,224,224)) # 3. get a pruning plan from the dependency graph. pruning_idxs = strategy(model.conv1.weight, amount=0.4, round_to=16) # or manually selected pruning_idxs=[2, 6, 9, ...] pruning_plan = DG.get_pruning_plan( model.conv1, tp.prune_conv, idxs=pruning_idxs ) print(pruning_plan) # 4. execute this plan (prune the model) pruning_plan.exec() print(model)

image

pruning_plan = DG.get_pruning_plan( pruning_idxs ): image

底层剪枝函数

使用一层一层的固定剪枝和上面是等价的

tp.prune_conv( model.conv1, idxs=[2,6,9] ) # fix the broken dependencies manually tp.prune_batchnorm( model.bn1, idxs=[2,6,9] ) tp.prune_related_conv( model.layer2[0].conv1, idxs=[2,6,9] )

运行结果:

(Conv2d(36, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False), 3456) 对设备友好的通道对齐剪枝

可以通过设置round_to参数,下例可以使得通道对16取整(即,16,32,48,64)

strategy = tp.strategy.L1Strategy() pruning_idxs = strategy(model.conv1.weight, amount=0.2, round_to=16)

image

image

本文暂时没有对torch pruning源码进行分析,先学会使用,后续如果有需要、有时间会再进行源码分析



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有