多智能体强化学习调代码有感(三) 您所在的位置:网站首页 argmax和softmax 多智能体强化学习调代码有感(三)

多智能体强化学习调代码有感(三)

2023-03-05 00:34| 来源: 网络整理| 查看: 265

建议提前食用:

下面以https://github.com/shariqiqbal2810/maddpg-pytorch为例进行说明:

gumbel_softmax代码位于:https://github.com/shariqiqbal2810/maddpg-pytorch/blob/master/utils/misc.py

调用gumbel_softmax的行为选择函数step位于:https://github.com/shariqiqbal2810/maddpg-pytorch/blob/master/utils/agents.py

在maddpg中,用gumbel_softmax方法,在原始数据上增加Gumbel噪声,实现随机探索。

# modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb def gumbel_softmax_sample(logits, temperature): """ Draw a sample from the Gumbel-Softmax distribution""" y = logits + sample_gumbel(logits.shape, tens_type=type(logits.data)) #添加噪声 return F.softmax(y / temperature, dim=1) # modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb def gumbel_softmax(logits, temperature=1.0, hard=False): """Sample from the Gumbel-Softmax distribution and optionally discretize. Args: logits: [batch_size, n_class] unnormalized log-probs temperature: non-negative scalar hard: if True, take argmax, but differentiate w.r.t. soft sample y Returns: [batch_size, n_class] sample from the Gumbel-Softmax distribution. If hard=True, then the returned sample will be one-hot, otherwise it will be a probabilitiy distribution that sums to 1 across classes """ y = gumbel_softmax_sample(logits, temperature) #temperature 是在大于零的参数,它控制着 softmax 的 soft 程度。温度越高,生成的分布越平滑;温度越低,生成的分布越接近离散的 one-hot 分布。训练中,可以通过逐渐降低温度,以逐步逼近真实的离散分布。 if hard: y_hard = onehot_from_logits(y) y = (y_hard - y).detach() + y return y

在maddpg中,policy网络输出的值在-1到1之间,所以在进行gumbel-softmax采样的时候,就把网络输出作为行为价值的对数结果(并不是行为概率之类的,而是代表log)输入到gumbel_softmax(hard)中得到onehot形式的行为选择。

def step(self, obs, explore=False): """ Take a step forward in environment for a minibatch of observations Inputs: obs (PyTorch Variable): Observations for this agent explore (boolean): Whether or not to add exploration noise Outputs: action (PyTorch Variable): Actions for this agent """ action = self.policy(obs) #-1到1之间 if self.discrete_action: if explore: action = gumbel_softmax(action, hard=True) #探索 else: action = onehot_from_logits(action) else: # continuous action if explore: action += Variable(Tensor(self.exploration.noise()), requires_grad=False) action = action.clamp(-1, 1) return action



【本文地址】

公司简介

联系我们

今日新闻

    推荐新闻

    专题文章
      CopyRight 2018-2019 实验室设备网 版权所有