广告公司网站模板,三合一网站建设,软考网络规划设计师论文,免备案空间网站备案Double DQN算法
问题
DQN 算法通过贪婪法直接获得目标 Q 值#xff0c;贪婪法通过最大化方式使 Q 值快速向可能的优化目标收敛#xff0c;但易导致过估计Q 值的问题#xff0c;使模型具有较大的偏差。 即#xff1a; 对于DQN模型, 损失函数使用的 Q(state) reward Q(ne…Double DQN算法
问题
DQN 算法通过贪婪法直接获得目标 Q 值贪婪法通过最大化方式使 Q 值快速向可能的优化目标收敛但易导致过估计Q 值的问题使模型具有较大的偏差。 即 对于DQN模型, 损失函数使用的 Q(state) reward Q(nextState)max Q(state)由训练网络生成, Q(nextState)max由目标网络生成
这种损失函数会存在问题即当Q(nextState)max总是大于0时那么Q(state)总是在不停的增大同时Q(nextState)max也在不断的增大, 即Q(state)存在被高估的情况。
作者采用 Double DQN 算法解耦动作的选择和目标 Q 值的计算以解决过估计 Q 值的问题。
Double DQN 原理
Double DQN 算法结构如下。在 Double DQN 框架中存在两个神经网络模型分别是训练网络与目标网络。这两个神经网络模型的结构完全相同但是权重参数不同每训练一段之间后训练网络的权重参数才会复制给目标网络。训练时训练网络用于估计当前的 而目标网络用于估计 这样就能保证真实值 的估计不会随着训练网络的不断自更新而变化过快。此外DQN 还是一种支持离线学习的框架即通过构建经验池的方式离线学习过去的经验。将均方误差 MSE(Q_{train}, Q_{target}) 作为训练模型的损失函数通过梯度下降法进行反向传播对训练模型进行更新若干轮经验池采样后再将训练模型的权重赋给目标模型以此进行 Double DQN 框架下的模型自学习。
目标 Q 值的计算公式如下所示 y j r j γ max a ′ Q ( s j 1 , a ′ ; θ ′ ) y_jr_j\gamma \max _{a^{\prime}} Q\left(s_{j1}, a^{\prime} ; \theta^{\prime}\right) yjrjγa′maxQ(sj1,a′;θ′)
Double DQN 算法不直接通过最大化的方式选取目标网络计算的所有可能 Q Q Q 值而是首先通过估计网络选取最大 Q Q Q 值对应的动作公式表示如下: a max argmax a Q ( s t 1 , a ; θ ) a_{\max }\operatorname{argmax}_a Q\left(s_{t1}, a ; \theta\right) amaxargmaxaQ(st1,a;θ)
然后目标网络根据 a max a_{\max } amax 计算目标 Q 值公式表示如下: y j r j γ Q ( s j 1 , a max ; θ ′ ) y_jr_j\gamma Q\left(s_{j1}, a_{\max } ; \theta^{\prime}\right) yjrjγQ(sj1,amax;θ′)
最后将上面两个公式结合目标 Q Q Q 值的最终表示形式如下: y j r j γ Q ( s j 1 , argmax a Q ( s t 1 , a ; θ ) ; θ ′ ) y_jr_j\gamma Q\left(s_{j1}, \operatorname{argmax}_a Q\left(s_{t1, a ; \theta}\right) ; \theta^{\prime}\right) yjrjγQ(sj1,argmaxaQ(st1,a;θ);θ′)
目标是最小化目标函数即最小化估计 Q Q Q 值和目标 Q Q Q 值的差值公式如下: δ ∣ Q ( s t , a t ) − y t ∣ ∣ Q ( s t , a t ; θ ) − ( r t γ Q ( S t 1 , argmax a Q ( s t 1 , a ; θ ) ; θ ′ ) ) ∣ \begin{aligned} \delta\left|Q\left(s_t, a_t\right)-y_t\right|\mid Q\left(s_t, a_t ; \theta\right)-\left(r_t\right. \\ \left.\gamma Q\left(S_{t1}, \operatorname{argmax}_a Q\left(s_{t1}, a ; \theta\right) ; \theta^{\prime}\right)\right) \mid \end{aligned} δ∣Q(st,at)−yt∣∣Q(st,at;θ)−(rtγQ(St1,argmaxaQ(st1,a;θ);θ′))∣
结合目标函数损失函数定义如下: loss { 1 2 δ 2 for ∣ δ ∣ ⩽ 1 ∣ δ ∣ − 1 2 otherwize } \text { loss }\left\{\begin{array}{cl} \frac{1}{2} \delta^2 \text { for }|\delta| \leqslant 1 \\ |\delta|-\frac{1}{2} \text { otherwize } \end{array}\right\} loss {21δ2∣δ∣−21 for ∣δ∣⩽1 otherwize }
代码
游戏环境
import gym#定义环境
class MyWrapper(gym.Wrapper):def __init__(self):env gym.make(CartPole-v1, render_modergb_array)super().__init__(env)self.env envself.step_n 0def reset(self):state, _ self.env.reset()self.step_n 0return statedef step(self, action):state, reward, terminated, truncated, info self.env.step(action)over terminated or truncated#限制最大步数self.step_n 1if self.step_n 200:over True#没坚持到最后,扣分if over and self.step_n 200:reward -1000return state, reward, over#打印游戏图像def show(self):from matplotlib import pyplot as pltplt.figure(figsize(3, 3))plt.imshow(self.env.render())plt.show()env MyWrapper()env.reset()env.show()Q价值函数
import torch#定义模型,评估状态下每个动作的价值
model torch.nn.Sequential(torch.nn.Linear(4, 64),torch.nn.ReLU(),torch.nn.Linear(64, 64),torch.nn.ReLU(),torch.nn.Linear(64, 2),
)#延迟更新的模型,用于计算target
model_delay torch.nn.Sequential(torch.nn.Linear(4, 64),torch.nn.ReLU(),torch.nn.Linear(64, 64),torch.nn.ReLU(),torch.nn.Linear(64, 2),
)#复制参数
model_delay.load_state_dict(model.state_dict())model, model_delay单条轨迹
from IPython import display
import random#玩一局游戏并记录数据
def play(showFalse):data []reward_sum 0state env.reset()over Falsewhile not over:action model(torch.FloatTensor(state).reshape(1, 4)).argmax().item()if random.random() 0.1:action env.action_space.sample()next_state, reward, over env.step(action)data.append((state, action, reward, next_state, over))reward_sum rewardstate next_stateif show:display.clear_output(waitTrue)env.show()return data, reward_sumplay()[-1]经验池
#数据池
class Pool:def __init__(self):self.pool []def __len__(self):return len(self.pool)def __getitem__(self, i):return self.pool[i]#更新动作池def update(self):#每次更新不少于N条新数据old_len len(self.pool)while len(pool) - old_len 200:self.pool.extend(play()[0])#只保留最新的N条数据self.pool self.pool[-2_0000:]#获取一批数据样本def sample(self):data random.sample(self.pool, 64)state torch.FloatTensor([i[0] for i in data]).reshape(-1, 4)action torch.LongTensor([i[1] for i in data]).reshape(-1, 1)reward torch.FloatTensor([i[2] for i in data]).reshape(-1, 1)next_state torch.FloatTensor([i[3] for i in data]).reshape(-1, 4)over torch.LongTensor([i[4] for i in data]).reshape(-1, 1)return state, action, reward, next_state, overpool Pool()
pool.update()
pool.sample()len(pool), pool[0]训练
#训练
def train():model.train()optimizer torch.optim.Adam(model.parameters(), lr2e-4)loss_fn torch.nn.MSELoss()#共更新N轮数据for epoch in range(1000):pool.update()#每次更新数据后,训练N次for i in range(200):#采样N条数据state, action, reward, next_state, over pool.sample()#计算valuevalue model(state).gather(dim1, indexaction)#计算targetwith torch.no_grad():target model_delay(next_state)target target.max(dim1)[0].reshape(-1, 1)target target * 0.99 * (1 - over) rewardloss loss_fn(value, target)loss.backward()optimizer.step()optimizer.zero_grad()#复制参数if (epoch 1) % 5 0:model_delay.load_state_dict(model.state_dict())if epoch % 100 0:test_result sum([play()[-1] for _ in range(20)]) / 20print(epoch, len(pool), test_result)train()