Pytorch Lightning框架:使用笔记【LightningModule、LightningDataModule、Trainer、ModelCheckpoint】 | 您所在的位置:网站首页 › pytorch半精度训练 › Pytorch Lightning框架:使用笔记【LightningModule、LightningDataModule、Trainer、ModelCheckpoint】 |
Pytorch Lightning官方手册 Pytorch Lightning源码:GitHub地址 Pytorch Lightning使用案例:Pytorch-Lightning-Template项目 GitHub地址 一、Pytorch Lightning 的流程Pytorch Lightning框架应用的流程很简单,生产流水线,有一个固定的顺序: 初始化 def init(self)训练training_step(self, batch, batch_idx)校验validation_step(self, batch, batch_idx)测试 test_step(self, batch, batch_idx)就完事了,总统是实现这三个函数的重写。 当然,除了这三个主要的,还有一些其他的函数,为了方便我们实现其他的一些功能,因此更为完整的流程是: 在training_step 后面都紧跟着其相应的 training_step_end(self,batch_parts)和training_epoch_end(self, training_step_outputs) 函数;validation_step 后面都紧跟着其相应的 validation_step_end(self,batch_parts)和validation_epoch_end(self, training_step_outputs) 函数;test_step 后面都紧跟着其相应的 test_step_end(self,batch_parts)和 test_epoch_end(self, training_step_outputs) 函数;这里以训练为例: def training_step(self, batch, batch_idx): x, y = batch y_hat = self.model(x) loss = F.cross_entropy(y_hat, y) pred = ... return {'loss': loss, 'pred': pred} def training_step_end(self, batch_parts): ''' 当gpus=0 or 1时,这里的batch_parts即为traing_step的返回值(已验证) 当gpus>1时,这里的batch_parts为list,list中每个为training_step返回值,list[i]为i号gpu的返回值(这里未验证) ''' gpu_0_prediction = batch_parts[0]['pred'] gpu_1_prediction = batch_parts[1]['pred'] # do something with both outputs return (batch_parts[0]['loss'] + batch_parts[1]['loss']) / 2 def training_epoch_end(self, training_step_outputs): ''' 当gpu=0 or 1时,training_step_outputs为list,长度为steps的数量(不包括validation的步数,当你训练时,你会发现返回list |
CopyRight 2018-2019 实验室设备网 版权所有 |