
[Optimizers] (6) AdamW Principles & PyTorch Code Analysis


1. Introduction

In the previous post we introduced Adam, the optimizer that combines first-order and second-order momentum. AdamW is essentially Adam with weight-decay regularization added on top. However, as we saw in the previous post, Adam's code already contains a regularization term, so what exactly is the difference between the two?

2. AdamW

In fact, the only difference between AdamW and Adam is how weight decay is applied.

In Adam, weight decay is added directly to the gradient:

g_{t} = g_{t} + \lambda \theta_{t-1}

where g_{t} is the gradient at the current step, \theta_{t-1} is the model weight from the previous step, and \lambda is the regularization coefficient.

In AdamW, the regularization instead becomes:

\theta_{t} = \theta_{t-1} - \gamma \lambda \theta_{t-1}

where \gamma is the learning rate.

The idea behind AdamW is therefore very simple: since the regularization term added to the gradient is ultimately applied to the weights anyway, why add it to the gradient at all? AdamW instead decays the weights directly, and in practice it also converges faster than Adam.
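In code, the contrast comes down to a single line in each optimizer. Below is a rough sketch with illustrative helper names of our own: Adam's line mirrors the regularization step we already saw in the previous article's code, while AdamW's line appears verbatim in the official implementation in section 4.

import torch
from torch import Tensor

# Sketch only: how each optimizer applies the coefficient λ inside its per-parameter loop.

def adam_style_decay(param: Tensor, grad: Tensor, weight_decay: float) -> Tensor:
    # Adam + L2: fold λ·θ into the gradient; the adaptive momentum terms rescale it later.
    return grad.add(param, alpha=weight_decay)

def adamw_style_decay(param: Tensor, lr: float, weight_decay: float) -> Tensor:
    # AdamW: shrink the weights directly; the gradient is left untouched.
    return param.mul_(1 - lr * weight_decay)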

3. Discussion

On closer inspection, however, although both Adam + L2 regularization and AdamW implement weight decay, the details differ. L2 regularization adds a weight penalty to the loss, which amounts to modifying the gradient, whereas AdamW is weight decay in the more literal sense: it shrinks the weights directly.

For SGD, the two are equivalent (here g_{t}^{weight} denotes the gradient with the weight penalty folded in):

\theta_{t} = \theta_{t-1} - \gamma g_{t}^{weight}

= \theta_{t-1} - \gamma (g_{t} + \lambda \theta_{t-1})

= \theta_{t-1} - \gamma \lambda \theta_{t-1} - \gamma g_{t}
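A quick numerical check of this identity (illustrative values only):

import torch

# One SGD step with L2 folded into the gradient equals one step of
# direct weight decay plus the ordinary gradient step.
theta = torch.tensor([1.0, -2.0, 3.0])
grad = torch.tensor([0.5, 0.1, -0.3])
lr, wd = 0.1, 0.01                                 # γ, λ

step_l2 = theta - lr * (grad + wd * theta)         # fold λ·θ into the gradient
step_decay = theta - lr * wd * theta - lr * grad   # decay the weights directly

print(torch.allclose(step_l2, step_decay))         # True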

It is only when first- and second-order momentum enter the picture, as in Adam, that the equations above stop being linear in this way, so the two approaches no longer give the same result. Compared with Adam, then, AdamW not only converges faster; its regularization coefficient is also no longer affected by the momentum terms (in Adam the decay term gets divided by the second-order momentum and is thereby diluted). AdamW thus has the advantages of decoupled hyperparameters and stronger effective regularization, which is why the original paper has "decoupled" in its title.
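This non-equivalence is easy to see directly: starting from identical parameters and gradients, torch.optim.Adam with weight_decay and torch.optim.AdamW with the same weight_decay take slightly different steps. A small illustrative sketch (the helper one_step and the numbers are ours):

import torch

def one_step(opt_cls):
    p = torch.nn.Parameter(torch.tensor([1.0, -2.0, 3.0]))
    opt = opt_cls([p], lr=0.1, weight_decay=0.01)
    p.grad = torch.tensor([0.5, 0.1, -0.3])
    opt.step()
    return p.detach()

print(one_step(torch.optim.Adam))   # L2 folded into the gradient, then rescaled by momentum
print(one_step(torch.optim.AdamW))  # decoupled weight decay; the resulting parameters differ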

4. PyTorch Code

The pseudocode flow of one AdamW step is as follows.
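Below is a minimal Python sketch equivalent to that pseudocode (the function name adamw_step is illustrative, not part of PyTorch; AMSGrad is omitted, scalar hyperparameters and plain tensors are assumed):

import torch
from torch import Tensor

def adamw_step(param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor, step: int,
               lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=1e-2) -> Tensor:
    # 1) decoupled weight decay, applied to the weights themselves
    param.mul_(1 - lr * weight_decay)

    # 2) first- and second-moment running averages
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

    # 3) bias correction
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    # 4) parameter update: θ ← θ - γ · m̂ / (√v̂ + ε)
    denom = (exp_avg_sq / bias_correction2).sqrt().add_(eps)
    param.addcdiv_(exp_avg / bias_correction1, denom, value=-lr)
    return param

The official implementation in the next listing performs the same update, with additional handling for AMSGrad, complex tensors, and CUDA-graph capture.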

The following is the official PyTorch implementation of AdamW (_single_tensor_adamw).

def _single_tensor_adamw(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,
    differentiable: bool,
):

    assert grad_scale is None and found_inf is None

    for i, param in enumerate(params):
        grad = grads[i] if not maximize else -grads[i]
        exp_avg = exp_avgs[i]
        exp_avg_sq = exp_avg_sqs[i]
        step_t = state_steps[i]

        if capturable:
            assert (
                param.is_cuda and step_t.is_cuda
            ), "If capturable=True, params and state_steps must be CUDA tensors."

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            exp_avg = torch.view_as_real(exp_avg)
            exp_avg_sq = torch.view_as_real(exp_avg_sq)
            param = torch.view_as_real(param)

        # update step
        step_t += 1

        # Perform stepweight decay
        param.mul_(1 - lr * weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
        exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

        if capturable or differentiable:
            step = step_t

            # 1 - beta1 ** step can't be captured in a CUDA graph, even if step is a CUDA tensor
            # (incurs "RuntimeError: CUDA error: operation not permitted when stream is capturing")
            bias_correction1 = 1 - torch.pow(beta1, step)
            bias_correction2 = 1 - torch.pow(beta2, step)

            step_size = lr / bias_correction1
            step_size_neg = step_size.neg()

            bias_correction2_sqrt = bias_correction2.sqrt()

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                if differentiable:
                    max_exp_avg_sqs_i = max_exp_avg_sqs[i].clone()
                else:
                    max_exp_avg_sqs_i = max_exp_avg_sqs[i]
                max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sqs_i, exp_avg_sq))

                # Uses the max. for normalizing running avg. of gradient
                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
                denom = (
                    max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)
            else:
                denom = (
                    exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
                ).add_(eps / step_size_neg)

            param.addcdiv_(exp_avg, denom)
        else:
            step = _get_value(step_t)

            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step

            step_size = lr / bias_correction1

            bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])

                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
            else:
                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

            param.addcdiv_(exp_avg, denom, value=-step_size)
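For reference, typical usage looks like the following: a toy model and random data, purely illustrative, with hyperparameters set to common defaults.

import torch
import torch.nn as nn

# Hypothetical tiny model and data, just to show the AdamW call pattern.
model = nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
criterion = nn.MSELoss()

x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(100):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()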

For business cooperation or study exchange, add WeChat: lizhiTechnology


