PyTorch 预训练模型,保存,读取和更新模型参数以及多 GPU 训练模型

您所在的位置:网站首页 mastercam91设置的参数怎么保存 PyTorch 预训练模型,保存,读取和更新模型参数以及多 GPU 训练模型

PyTorch 预训练模型,保存,读取和更新模型参数以及多 GPU 训练模型

2024-06-30 20:48:07| 来源: 网络整理| 查看: 265

本文用于记录如何进行 PyTorch 所提供的预训练模型应如何加载,所训练模型的参数应如何保存与读取,如何冻结模型部分参数以方便进行 fine-tuning 以及如何利用多 GPU 训练模型。

1. PyTorch 预训练模型

Pytorch 提供了许多 Pre-Trained Model on ImageNet,仅需调用 torchvision.models 即可,具体细节可查看官方文档。

往往我们需要对 Pre-Trained Model 进行相应的修改,以适应我们的任务。这种情况下,我们可以先输出 Pre-Trained Model 的结构,确定好对哪些层修改,或者添加哪些层,接着,再将其修改即可。

比如,我需要将 ResNet-50 的 Layer 3 后的所有层去掉,在分别连接十个分类器,分类器由 ResNet-50.layer4 和 AvgPool Layer 和 FC Layer 构成。这里就需要用到 torch.nn.ModuleList 了,比如:self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])。

代码中的 [nn.Linear(10, 10) for i in range(10)] 是一个python列表,必须要把它转换成一个Module Llist列表才可以被 PyTorch 使用,否则在运行的时候会报错: RuntimeError: Input type (CUDAFloatTensor) and weight type (CPUFloatTensor) should be the same

2. 保存模型参数

PyTorch 中保存模型的方式有许多种:

# 保存整个网络 torch.save(model, PATH) # 保存网络中的参数, 速度快,占空间少 torch.save(model.state_dict(),PATH) # 选择保存网络中的一部分参数或者额外保存其余的参数 torch.save({'state_dict': model.state_dict(), 'fc_dict':model.fc.state_dict(), 'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma}, PATH) 3. 读取模型参数

同样的,PyTorch 中读取模型参数的方式也有许多种:

# 读取整个网络 model = torch.load(PATH) # 读取 Checkpoint 中的网络参数 model.load_state_dict(torch.load(PATH)) # 若 Checkpoint 中的网络参数与当前网络参数有部分不同,有以下两种方式进行加载: # 1. 利用字典的 update 方法进行加载 Checkpoint = torch.load(Path) model_dict = model.state_dict() model_dict.update(Checkpoint) model.load_state_dict(model_dict) # 2. 利用 load_state_dict() 的 strict 参数进行部分加载 model.load_state_dict(torch.load(PATH), strict=False) 4. 冻结部分模型参数,进行 fine-tuning

加载完 Pre-Trained Model 后,我们需要对其进行 Finetune。但是在此之前,我们往往需要冻结一部分的模型参数:

# 第一种方式 for p in freeze.parameters(): # 将需要冻结的参数的 requires_grad 设置为 False p.requires_grad = False for p in no_freeze.parameters(): # 将fine-tuning 的参数的 requires_grad 设置为 True p.requires_grad = True # 将需要 fine-tuning 的参数放入optimizer 中 optimizer.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3) # 第二种方式 optim_param = [] for p in freeze.parameters(): # 将需要冻结的参数的 requires_grad 设置为 False p.requires_grad = False for p in no_freeze.parameters(): # 将fine-tuning 的参数的 requires_grad 设置为 True p.requires_grad = True optim_param.append(p) optimizer.SGD(optim_param, lr=1e-3) # 将需要 fine-tuning 的参数放入optimizer 中 5. 模型训练与测试的设置

训练时,应调用 model.train() ;测试时,应调用 model.eval(),以及 with torch.no_grad():

model.train():使 model 变成训练模式,此时 dropout 和 batch normalization 的操作在训练起到防止网络过拟合的问题。

model.eval():PyTorch会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。不然的话,一旦测试集的 Batch Size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。

with torch.no_grad():PyTorch 将不再计算梯度,这将使得模型 forward 的时候,显存的需求大幅减少,速度大幅提高。

注意:若模型中具有 Batch Normalization 操作,想固定该操作进行训练时,需调用对应的 module 的 eval() 函数。这是因为 BN Module 除了参数以外,还会对输入的数据进行统计,若不调用 eval(),统计量将发生改变!具体代码可以这样写:

for module in model.modules(): module.eval()

在其他地方看到的解释:

model.eval() will notify all your layers that you are in eval mode, that way, batchnorm or dropout layers will work in eval model instead of training mode.torch.no_grad() impacts the autograd engine and deactivate it. It will reduce memory usage and speed up computations but you won’t be able to backprop (which you don’t want in an eval script). 6. 利用 torch.nn.DataParallel 进行多 GPU 训练 import torch import torch.nn as nn import torchvision.models as models # 生成模型 # 利用 torch.nn.DataParallel 进行载入模型,默认使用所有GPU(可以用 CUDA_VISIBLE_DEVICES 设置所使用的 GPU) model = nn.DataParallel(models.resnet18()) # 冻结参数 for param in model.module.layer4.parameters(): param.requires_grad = False param_optim = filter(lambda p:p.requires_grad, model.parameters()) # 设置测试模式 model.module.layer4.eval() # 保存模型参数(读取所保存模型参数后,再进行并行化操作,否则无法利用之前的代码进行读取) torch.save(model.module.state_dict(),'./CheckPoint.pkl') 学习交流

目前开通了技术交流群,群友超过500人,添加时最好备注形式为:来源+兴趣方向,方便找到志同道合的朋友

方式1、发送如下图片至微信,长按识别,关注后台回复:加群;方式2、微信搜索公众号:机器学习社区,关注后台回复:加群;

扫描关注



【本文地址】

公司简介

联系我们

今日新闻


点击排行

实验室常用的仪器、试剂和
说到实验室常用到的东西,主要就分为仪器、试剂和耗
不用再找了,全球10大实验
01、赛默飞世尔科技(热电)Thermo Fisher Scientif
三代水柜的量产巅峰T-72坦
作者:寞寒最近,西边闹腾挺大,本来小寞以为忙完这
通风柜跟实验室通风系统有
说到通风柜跟实验室通风,不少人都纠结二者到底是不
集消毒杀菌、烘干收纳为一
厨房是家里细菌较多的地方,潮湿的环境、没有完全密
实验室设备之全钢实验台如
全钢实验台是实验室家具中较为重要的家具之一,很多

推荐新闻


图片新闻

实验室药品柜的特性有哪些
实验室药品柜是实验室家具的重要组成部分之一,主要
小学科学实验中有哪些教学
计算机 计算器 一般 打孔器 打气筒 仪器车 显微镜
实验室各种仪器原理动图讲
1.紫外分光光谱UV分析原理:吸收紫外光能量,引起分
高中化学常见仪器及实验装
1、可加热仪器:2、计量仪器:(1)仪器A的名称:量
微生物操作主要设备和器具
今天盘点一下微生物操作主要设备和器具,别嫌我啰嗦
浅谈通风柜使用基本常识
 众所周知,通风柜功能中最主要的就是排气功能。在

专题文章

    CopyRight 2018-2019 实验室设备网 版权所有 win10的实时保护怎么永久关闭