【一文学会】Gumbel 您所在的位置:网站首页 noises什么 【一文学会】Gumbel

【一文学会】Gumbel

2023-08-14 11:43| 来源: 网络整理| 查看: 265

目录

基于softmax的采样

基于gumbel-max的采样

基于gumbel-softmax的采样

基于ST-gumbel-softmax的采样

Gumbel分布

回答问题一

回答问题二

回答问题三

附录

 

以强化学习为例,假设网络输出的三维向量代表三个动作(前进、停留、后退)在下一步的收益,value=[-10,10,15],那么下一步我们就会选择收益最大的动作(后退)继续执行,于是输出动作[0,0,1]。选择值最大的作为输出动作,这样做本身没问题,但是在网络中这种取法有个问题是不能计算梯度,也就不能更新网络。

基于softmax的采样

这时通常的做法是加上softmax函数,把向量归一化,这样既能计算梯度,同时值的大小还能表示概率的含义(多项分布)。

                                                    \fn_phv \large \pi_k = \frac{e^{x_k}}{\sum_{i=1}^{K} e^{x_{i}}}

于是value=[-10,10,15]通过softmax函数后有σ(value)=[0,0.007,0.993],这样做不会改变动作或者说类别的选取,同时softmax倾向于让最大值的概率显著大于其他值,比如这里15和10经过softmax放缩之后变成了0.993和0.007,这有利于把网络训成一个one-hot输出的形式,这种方式在分类问题中是常用方法。

但这样就不会体现概率的含义了,因为σ(value)=[0,0.007,0.993]与σ(value)=[0.3,0.2,0.5]在类别选取的结果看来没有任何差别,都是选择第三个类别,但是从概率意义上讲差别是巨大的。

很直接的方法是依概率采样完事了,比如直接用np.random.choice函数依照概率生成样本值,这样概率就有意义了。所以,经典的采样方法就是用softmax函数加上轮盘赌方法(np.random.choice)。但这样还是会有个问题,这种方式怎么计算梯度?不能计算梯度怎么更新网络?

def sample_with_softmax(logits, size): # logits为输入数据 # size为采样数 pro = softmax(logits) return np.random.choice(len(logits), size, p=pro)

 

基于gumbel-max的采样

gumbel分布的具体介绍会放在后文,我们先看看结论。对于K维概率向量\large \alpha,对\large \alpha对应的离散变量x_{i}=log(\alpha _i)添加Gumbel噪声,再取样

                                   \large x=\mathop{argmax}_i(\log(\alpha _i)+G_i)

其中,\fn_phv \large G_i是独立同分布的标准Gumbel分布的随机变量,标准Gumbel分布的CDF为F(x)=e^{-e^{-x}}.所以\large G_i可以通过Gumbel分布求逆从均匀分布生成,即G_i=-\log(-\log(U_i)),U_i\sim U(0,1)x_{i}=log(\alpha _i)代入计算可知,这里的\large \alpha就是上面softmax采样的\large \pi,这样就得到了基于gumbel-max的采样过程:

对于网络输出的一个K维向量v,生成K个服从均匀分布U(0,1)的独立样本ϵ1,...,ϵK;

通过G_i=-\log(-\log(\varepsilon _i))计算得到G_i;

对应相加得到新的值向量v′=[v1+G1,v2+G2,...,vK+GK];

取最大值作为最终的类别

可以证明,gumbel-max 方法的采样效果等效于基于 softmax 的方式(后文也会证明)。由于 Gumbel 随机数可以预先计算好,采样过程也不需要计算 softmax,因此,某些情况下,gumbel-max 方法相比于 softmax,在采样速度上会有优势。当然,可以看到由于这中间有一个argmax操作,这是不可导的,依旧没法用于计算网络梯度。

def sa


【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有