Pytorch断开后继续训练 或 加载预训练模型继续训练 您所在的位置:网站首页 PyTorch训练怎么暂停 Pytorch断开后继续训练 或 加载预训练模型继续训练

Pytorch断开后继续训练 或 加载预训练模型继续训练

2023-07-13 23:44| 来源: 网络整理| 查看: 265

在训练过程中,往往会遇到中断,如在Colab和Kaggle中,由于网络不稳定,很容易就断开了连接。然而,即使可以稳定训练,但是训练的时长往往是有上限的,此时我们的网络参数训练的可能还未收敛仍然需要训练,所以,应该加载原训练基础上再进行训练是十分很重要的。

比如,要训练1000代才能收敛,但是目前只训练的100代就中断了,所以要加载第100代训练的模型参数,然后训练接下来的900代

pytorch模型的保存机制

👉模型保存的两种机制

修改训练代码

中断的训练代码最简单的修改方式便是复制一份训练的代码,然后在其基础上进行修改,涉及到最重要的部分就是模型的保存与加载

🅰若优化器optimizer不需要随着训练的修改,那么直接加载模型、优化器,之后进行训练即可

🅱若优化器需要训练,那么可以进行一下修改:

if epoch == epochs_g + 1: optimizer_r.load_state_dict(checkpoint_r['optimizer']) optimizer_g.load_state_dict(checkpoint_g['optimizer']) lr_r = checkpoint_r['lr'] lr_g = checkpoint_g['lr'] else: optimizer_r = optim.Adagrad(model_r.parameters(), lr = lr_r, weight_decay = 1e-5) optimizer_g = optim.Adagrad(model_g.parameters(), lr = lr_g, weight_decay = 1e-5) 继续训练的第一次是利用模型保存下来的,而之后则是修改的优化器

如:我的模型每训练50次进行learning rate减半,初始学习率为0.001,而我的模型训练到第40代中断,所以加载第40代模型继续进行训练

python "train_continue.py" --pre_model_r './LapSRN_r_epoch_40.pt' --pre_model_g './LapSRN_g_epoch_40.pt' --nEpochs 60 --cuda --batchSize 1 --dataset "../../DataSet_test/"

可以看看优化器的变化如下:

Namespace(batchSize=1, cuda=True, dataset='../../DataSet_test/', lr=0.001, nEpochs=60, pre_model_g='./LapSRN_g_epoch_40.pt', pre_model_r='./LapSRN_r_epoch_40.pt', save_models='./', save_train_csv='./train.csv', save_val_csv='/val.csv', seed=123, valBatchSize=1) ===> Loading datasets ===> Loading pre_train model and Building model Adagrad ( Parameter Group 0 eps: 1e-10 initial_accumulator_value: 0 lr: 0.001 lr_decay: 0 weight_decay: 1e-05 ) ===> Epoch 41 Complete: Avg. Loss: 0.0381 ===> Avg. PSNR1: 26.2686 dB ===> Avg. PSNR2: 25.1278 dB Adagrad ( Parameter Group 0 eps: 1e-10 initial_accumulator_value: 0 lr: 0.001 lr_decay: 0 weight_decay: 1e-05 ) ===> Epoch 42 Complete: Avg. Loss: 0.0789 ===> Avg. PSNR1: 13.8764 dB ===> Avg. PSNR2: 16.7824 dB .........省略部分.......... Adagrad ( Parameter Group 0 eps: 1e-10 initial_accumulator_value: 0 lr: 0.001 lr_decay: 0 weight_decay: 1e-05 ) ===> Epoch 49 Complete: Avg. Loss: 0.0749 ===> Avg. PSNR1: 25.5121 dB ===> Avg. PSNR2: 25.1218 dB Adagrad ( Parameter Group 0 eps: 1e-10 initial_accumulator_value: 0 lr: 0.001 lr_decay: 0 weight_decay: 1e-05 ) ===> Epoch 50 Complete: Avg. Loss: 0.0877 ===> Avg. PSNR1: 28.2393 dB ===> Avg. PSNR2: 26.6869 dB Checkpoint saved to ./LapSRN_r_epoch_50.pt and ./LapSRN_g_epoch_50.pt Adagrad ( Parameter Group 0 eps: 1e-10 initial_accumulator_value: 0 lr: 0.0005 lr_decay: 0 weight_decay: 1e-05 ) ===> Epoch 51 Complete: Avg. Loss: 0.2914 ===> Avg. PSNR1: 27.3521 dB ===> Avg. PSNR2: 25.3298 dB Adagrad ( Parameter Group 0 eps: 1e-10 initial_accumulator_value: 0 lr: 0.0005 lr_decay: 0 weight_decay: 1e-05 ) ===> Epoch 52 Complete: Avg. Loss: 0.0505 ===> Avg. PSNR1: 21.9110 dB ===> Avg. PSNR2: 21.8041 dB Adagrad ( Parameter Group 0 eps: 1e-10 initial_accumulator_value: 0 lr: 0.0005 lr_decay: 0 weight_decay: 1e-05 ) 样例学习

中断训练train_continue.py代码如下,可供参考学习:



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有