Pytorch深度学习实战教程(四):必知必会的炼丹法宝 您所在的位置:网站首页 pytorch中的loss Pytorch深度学习实战教程(四):必知必会的炼丹法宝

Pytorch深度学习实战教程(四):必知必会的炼丹法宝

2022-03-26 05:49| 来源: 网络整理| 查看: 265

摘要

人手不够,“法宝”来凑。本文就盘点一下,我们可以使用的「炼丹法宝」。

Pytorch深度学习实战教程(四):必知必会的炼丹法宝

一、前言

训练深度学习模型,就像“炼丹”,模型可能需要训练很多天。

我们不可能像「太上老君」那样,拿着浮尘,24 小时全天守在「八卦炉」前,更何况人家还有炼丹童、天兵天将,轮流值守。

Pytorch深度学习实战教程(四):必知必会的炼丹法宝

人手不够,“法宝”来凑。

本文就盘点一下,我们可以使用的「炼丹法宝」。

PS:文中出现的所有代码,均可在我的 Github 上下载:点击查看

二、初级“法宝”,sys.stdout

训练模型,最常看的指标就是 Loss。我们可以根据 Loss 的收敛情况,初步判断模型训练的好坏。

如果,Loss 值突然上升了,那说明训练有问题,需要检查数据和代码。

如果,Loss 值趋于稳定,那说明训练完毕了。

观察 Loss 情况,最直观的方法,就是绘制 Loss 曲线图。

Pytorch深度学习实战教程(四):必知必会的炼丹法宝

通过绘图,我们可以很清晰的看到,左图还有收敛空间,而右图已经完全收敛。

通过 Loss 曲线,我们可以分析模型训练的好坏,模型是否训练完成,起到一个很好的“监控”作用。

绘制 Loss 曲线图,第一步就是需要保存训练过程中的 Loss 值。

一个最简单的方法是使用,sys.stdout 标准输出重定向,简单好用,实乃“炼丹”必备“良宝”。

Python12345678910111213141516171819import osimport sysclass Logger():    def __init__(self, filename="log.txt"):        self.terminal = sys.stdout        self.log = open(filename, "w")     def write(self, message):        self.terminal.write(message)        self.log.write(message)     def flush(self):        pass sys.stdout = Logger() print("Jack Cui")print("https://cuijiahua.com")print("https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA")

代码很简单,创建一个 log.py 文件,自己写一个 Logger 类,并采用 sys.stdout 重定向输出。

在 Terminal 中,不仅可以使用 print 打印结果,同时也会将结果保存到 log.txt 文件中。

Pytorch深度学习实战教程(四):必知必会的炼丹法宝

运行 log.py,打印 print 内容的同时,也将内容写入了 log.txt 文件中。

使用这个代码,就可以在打印 Loss 的同时,将结果保存到指定的 txt 中,比如保存上篇文章训练 UNet 的 Loss。

Pytorch深度学习实战教程(四):必知必会的炼丹法宝

三、中级“法宝”,matplotlib

Matplotlib 是一个 Python 的绘图库,简单好用。

简单几行命令,就可以绘制曲线图、散点图、条形图、直方图、饼图等等。

在深度学习中,一般就是绘制曲线图,比如 Loss 曲线、Acc 曲线。

举一个,简单的例子。

使用 sys.stdout 保存的 train_loss.txt,绘制 Loss 曲线。

Pytorch深度学习实战教程(四):必知必会的炼丹法宝

train_loss.txt 下载地址:点击查看

思路非常简单,读取 txt 内容,解析 txt 内容,使用 Matplotlib 绘制曲线。

Python12345678910111213import matplotlib.pyplot as plt# Jupyter notebook 中开启# %matplotlib inlinewith open('train_loss.txt', 'r') as f:    train_loss = f.readlines()    train_loss = list(map(lambda x:float(x.strip()), train_loss))x = range(len(train_loss))y = train_lossplt.plot(x, y, label='train loss', linewidth=2, color='r', marker='o', markerfacecolor='r', markersize=5)plt.xlabel('Epoch')plt.ylabel('Loss Value')plt.legend()plt.show()

指定 x 和 y 对应的值,就可以绘制。

Pytorch深度学习实战教程(四):必知必会的炼丹法宝

是不是很简单?

关于 Matplotlib 更多的详细教程,可以查看官方手册:点击查看

四、中级“法宝”,Logging

说到保存日志,那不得不提 Python 的内置标准模块 Logging,它主要用于输出运行日志,可以设置输出日志的等级、日志保存路径、日志文件回滚等,同时,我们也可以设置日志的输出格式。

Python12345678910111213141516171819import logging def get_logger(LEVEL, log_file = None):    head = '[%(asctime)-15s] [%(levelname)s] %(message)s'    if LEVEL == 'info':        logging.basicConfig(level=logging.INFO, format=head)    elif LEVEL == 'debug':        logging.basicConfig(level=logging.DEBUG, format=head)    logger = logging.getLogger()    if log_file != None:        fh = logging.FileHandler(log_file)        logger.addHandler(fh)    return logger logger = get_logger('info') logger.info('Jack Cui')logger.info('https://cuijiahua.com')logger.info('https://mp.weixin.qq.com/s/OCWwRVDFNslIuKyiCVUoTA')

只需要几行代码,进行一个简单的封装使用。使用函数 get_logger 创建一个级别为 info 的 logger,如果指定 log_file,则会对日志进行保存。

Pytorch深度学习实战教程(四):必知必会的炼丹法宝

logging 默认支持的日志一共有 5 个等级:

Pytorch深度学习实战教程(四):必知必会的炼丹法宝

日志级别等级 CRITICAL > ERROR > WARNING > INFO > DEBUG。

默认的日志级别设置为 WARNING,也就是说如果不指定日志级别,只会显示大于等于 WARNING 级别的日志。

例如:

Python123456import logginglogging.debug("debug_msg")logging.info("info_msg")logging.warning("warning_msg")logging.error("error_msg")logging.critical("critical_msg")

运行结果:

Python123WARNING:root:warning_msgERROR:root:error_msgCRITICAL:root:critical_msg

可以看到 info 和 debug 级别的日志不会输出,默认的日志格式也比较简单。

Python1默认的日志格式为日志级别:Logger名称:用户输出消息

当然,我们可以通过,logging.basicConfig 的 format 参数,设置日志格式。

Pytorch深度学习实战教程(四):必知必会的炼丹法宝

字段有很多,可谓应有尽有,足以满足我们定制化的需求。

五、高级“法宝”,TensorboardX

上文介绍的“法宝”,并非针对深度学习“炼丹”使用的工具。

而 TensorboardX 则不同,它是专门用于深度学习“炼丹”的高级“法宝”。

早些时候,很多人更喜欢用 Tensorflow 的原因之一,就是 Tensorflow 框架有个一个很好的可视化工具 Tensorboard。

Pytorch 要想使用 Tensorboard 配置起来费劲儿不说,还有很多 Bug。

Pytorch 1.1.0 版本发布后,打破了这个局面,TensorBoard 成为了 Pytorch 的正式可用组件。

在 Pytorch 中,这个可视化工具叫做 TensorBoardX,其实就是针对 Tensorboard 的一个封装,使得 PyTorch 用户也能够调用 Tensorboard。

TensorboardX 安装也非常简单,使用 pip 即可安装,需要注意的是 Pytorch 的版本需要大于 1.1.0。

Shell1pip install tensorboardX

tensorboardX 使用也很简单,编写如下代码。

Python12345678910111213from tensorboardX import SummaryWriter # 创建 writer1 对象# log 会保存到 runs/exp 文件夹中writer1 = SummaryWriter('runs/exp') # 使用默认参数创建 writer2 对象# log 会保存到 runs/日期_用户名 格式的文件夹中writer2 = SummaryWriter() # 使用 commet 参数,创建 writer3 对象# log 会保存到 runs/日期_用户名_resnet 格式的文件中writer3 = SummaryWriter(comment='_resnet')

使用的时候,创建一个 SummaryWriter 对象即可,以上展示了三种初始化 SummaryWriter 的方法:

提供一个路径,将使用该路径来保存日志无参数,默认将使用 runs/日期_用户名 路径来保存日志提供一个 comment 参数,将使用 runs/日期_用户名+comment 路径来保存日志

运行结果:

Pytorch深度学习实战教程(四):必知必会的炼丹法宝

有了 writer 我们就可以往日志里写入数字、图片、甚至声音等数据。

数字 (scalar)

这个是最简单的,使用 add_scalar 方法来记录数字常量。

Python1add_scalar(tag, scalar_value, global_step=None, walltime=None)

总共 4 个参数。

tag (string): 数据名称,不同名称的数据使用不同曲线展示scalar_value (float): 数字常量值global_step (int, optional): 训练的 stepwalltime (float, optional): 记录发生的时间,默认为 time.time()

需要注意,这里的 scalar_value 一定是 float 类型,如果是 PyTorch scalar tensor,则需要调用 .item() 方法获取其数值。我们一般会使用 add_scalar 方法来记录训练过程的 loss、accuracy、learning rate 等数值的变化,直观地监控训练过程。

运行如下代码:

Python123456from tensorboardX import SummaryWriter    writer = SummaryWriter('runs/scalar_example')for i in range(10):    writer.add_scalar('quadratic', i**2, global_step=i)    writer.add_scalar('exponential', 2**i, global_step=i)writer.close()

通过 add_scalar 往日志里写入数字,日志保存到 runs/scalar_example中,writer 用完要记得 close,否则无法保存数据。

在 cmd 中使用如下命令:

Shell1tensorboard --logdir=runs/scalar_example --port=8088

指定日志地址,使用端口号,在浏览器中,就可以使用如下地址,打开 Tensorboad。

Shell1http://localhost:8088/

省去了我们自己写代码可视化的麻烦。

Pytorch深度学习实战教程(四):必知必会的炼丹法宝

图片 (image)

使用 add_image 方法来记录单个图像数据。注意,该方法需要 pillow 库的支持。

Shell1add_image(tag, img_tensor, global_step=None, walltime=None, dataformats='CHW')

参数:

tag (string):数据名称img_tensor (torch.Tensor / numpy.array):图像数据global_step (int, optional):训练的 stepwalltime (float, optional):记录发生的时间,默认为 time.time()dataformats (string, optional):图像数据的格式,默认为 'CHW',即 Channel x Height x Width,还可以是 'CHW'、'HWC' 或 'HW' 等

我们一般会使用 add_image 来实时观察生成式模型的生成效果,或者可视化分割、目标检测的结果,帮助调试模型。

Python123456789101112131415from tensorboardX import SummaryWriterfrom urllib.request import urlretrieveimport cv2 urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/0.png',filename = '1.jpg')urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/1.png',filename = '2.jpg')urlretrieve(url = 'https://raw.githubusercontent.com/Jack-Cherish/Deep-Learning/master/Pytorch-Seg/lesson-2/data/train/label/2.png',filename = '3.jpg') writer = SummaryWriter('runs/image_example')for i in range(1, 4):    writer.add_image('UNet_Seg',                     cv2.cvtColor(cv2.imread('{}.jpg'.format(i)), cv2.COLOR_BGR2RGB),                     global_step=i,                     dataformats='HWC')writer.close()

代码就是下载上篇文章数据集里的三张图片,然后使用 Tensorboard 可视化处理来,使用 8088 端口开打 Tensorboard:

Shell1tensorboard --logdir=runs/image_example --port=8088

运行结果:

Pytorch深度学习实战教程(四):必知必会的炼丹法宝

试想一下,一边训练,一边输出图片结果,是不是很酸爽呢?

Tensorboard 中常用的 Scalar 和 Image,直方图、运行图、嵌入向量等,可以查看官方手册进行学习,方法都是类似的,简单好用。

官方文档:点击查看

六、总结

工欲善其事,必先利其器。

本文讲解了深度学习中,常用的“炼丹法宝”的使用方法,sys.stdout、matplotlib、logging、tensorboardX 你更喜欢哪一款?



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有