Pytorch 您所在的位置:网站首页 pytorch模型参数修改 Pytorch

Pytorch

2023-07-04 14:09| 来源: 网络整理| 查看: 265

我自己改进的模型为model(model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)),原模型为resnet50。

1.查看模型参数

现模型:

1 model_dict = model.state_dict() 2 for k,v in model_dict.items(): 3 print(k)

预训练模型参数

1 pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 2 for k,v in pretrained_dict.items(): 3 print(k) 2.将预训练参数赋给自己改进的模型 改进的模型参数和原模型参数一致时: 1 import torch.utils.model_zoo as model_zoo 2 3 model_urls = { 4 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth' 5 } 6 7 model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)

Tip:如果两个模型参数完全一致的话,strict=True,如果两个模型参数不一致的话,当strict=False预训练模型会把具有相同参数名称的值赋给改进的参数,不相同的则不赋值。

改进的模型参数和原模型参数不一致时,使用部分预训练模型参数初始化网络 : 1 model_dict = model.state_dict() #取出自己模型的网络参数 2 pretrained_dict = model_zoo.load_url(model_urls['resnet50']) 3 4 model_dict['classifiers.3.fc.weight'] = pretrained_dict['fc.weight'][:2] 5 model_dict['classifiers.3.fc.bias'] = pretrained_dict['fc.bias'][:2]


【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有