RL-04-04-DQN实现

← 上级:RL-04.实现框架与实践 · 算法:RL-03-06-算法-DQN · Buffer:RL-05-03-结构-Replay-Buffer

本文给出 CartPole-v1 上可运行的 DQN 最小完整实现(单文件可拆模块),对应 RL-04-02-PyTorch实现要点 中的惯例。


一、Replay Buffer

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import random
from collections import deque
import numpy as np

class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling."""

    def __init__(self, capacity: int):
        # A deque with ``maxlen`` drops its oldest entry once it is full.
        self.buf = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one ``(state, action, reward, next_state, done)`` transition."""
        transition = (state, action, reward, next_state, done)
        self.buf.append(transition)

    def sample(self, batch_size: int):
        """Draw ``batch_size`` distinct transitions and return them column-wise.

        Returns:
            Five NumPy arrays ``(states, actions, rewards, next_states, dones)``;
            ``dones`` is cast to ``float32`` so it can be used as a 0/1 mask.

        Raises:
            ValueError: propagated from ``random.sample`` when ``batch_size``
                exceeds the number of stored transitions.
        """
        picked = random.sample(self.buf, batch_size)
        # Transpose the list of transitions into one array per field.
        states, actions, rewards, next_states, dones = (
            np.array(column) for column in zip(*picked)
        )
        return states, actions, rewards, next_states, dones.astype(np.float32)

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buf)

二、Q 网络与 Agent

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
import torch.nn as nn
import torch.nn.functional as F

class QNet(nn.Module):
    """Two-hidden-layer MLP that maps an observation to one Q-value per action."""

    def __init__(self, obs_dim, n_actions, hidden=128):
        """Build the network.

        Args:
            obs_dim: size of the input observation vector.
            n_actions: number of output Q-values.
            hidden: width of both hidden layers.
        """
        super().__init__()
        layers = [
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the Q-values produced by the MLP for input ``x``."""
        q_values = self.net(x)
        return q_values


class DQNAgent:
    """DQN agent bundling online/target Q-networks, a replay buffer and an
    epsilon-greedy exploration schedule.

    All hyperparameters that used to be hard-coded are exposed as keyword-only
    arguments whose defaults reproduce the previous behaviour.
    """

    def __init__(
        self,
        obs_dim,
        n_actions,
        device,
        gamma=0.99,
        lr=1e-3,
        *,
        buffer_size=50_000,
        eps_start=1.0,
        eps_min=0.05,
        eps_decay=0.995,
        target_update=500,
        warmup=0,
    ):
        """Create the networks, optimizer and replay buffer.

        Args:
            obs_dim: size of the observation vector fed to the Q-network.
            n_actions: number of discrete actions.
            device: torch device on which both networks live.
            gamma: discount factor used in the TD target.
            lr: Adam learning rate for the online network.
            buffer_size: replay buffer capacity.
            eps_start: initial exploration probability.
            eps_min: lower bound for the exploration probability.
            eps_decay: multiplicative factor applied by ``decay_eps``.
            target_update: number of learning steps between target-network syncs.
            warmup: minimum number of stored transitions before ``update``
                starts learning (0 keeps the original "one batch" threshold).
        """
        self.device = device
        self.gamma = gamma
        self.n_actions = n_actions
        self.online = QNet(obs_dim, n_actions).to(device)
        self.target = QNet(obs_dim, n_actions).to(device)
        # The target network starts as an exact copy of the online network.
        self.target.load_state_dict(self.online.state_dict())
        self.optimizer = torch.optim.Adam(self.online.parameters(), lr=lr)
        self.buffer = ReplayBuffer(buffer_size)
        self.eps = eps_start
        self.eps_min = eps_min
        self.eps_decay = eps_decay
        self.learn_step = 0
        self.target_update = target_update
        self.warmup = warmup

    def select_action(self, obs):
        """Pick an action epsilon-greedily for a single observation.

        With probability ``self.eps`` a uniformly random action is returned;
        otherwise the argmax of the online network's Q-values.
        """
        if np.random.random() < self.eps:
            return np.random.randint(self.n_actions)
        with torch.no_grad():
            # Add a leading batch dimension so the network sees a batch of one.
            x = torch.as_tensor(
                obs, dtype=torch.float32, device=self.device
            ).unsqueeze(0)
            return int(self.online(x).argmax(dim=1).item())

    def store(self, s, a, r, s2, done):
        """Push one transition into the replay buffer."""
        self.buffer.push(s, a, r, s2, done)

    def update(self, batch_size=64):
        """Run one gradient step on a sampled mini-batch.

        Returns:
            The scalar loss as a Python float, or ``None`` when the buffer holds
            fewer than ``max(batch_size, self.warmup)`` transitions.
        """
        if len(self.buffer) < max(batch_size, self.warmup):
            return None

        s, a, r, s2, d = self.buffer.sample(batch_size)
        s = torch.as_tensor(s, dtype=torch.float32, device=self.device)
        a = torch.as_tensor(a, dtype=torch.int64, device=self.device)
        r = torch.as_tensor(r, dtype=torch.float32, device=self.device)
        s2 = torch.as_tensor(s2, dtype=torch.float32, device=self.device)
        d = torch.as_tensor(d, dtype=torch.float32, device=self.device)

        # Q(s, a) for the actions that were actually taken.
        q = self.online(s).gather(1, a.unsqueeze(1)).squeeze(1)
        with torch.no_grad():
            # Bootstrap from the target network; (1 - d) removes the
            # bootstrap term for transitions flagged as done.
            q_next = self.target(s2).max(dim=1).values
            td_target = r + self.gamma * (1.0 - d) * q_next

        loss = F.mse_loss(q, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        self.learn_step += 1
        # Periodic hard sync of the target network.
        if self.learn_step % self.target_update == 0:
            self.target.load_state_dict(self.online.state_dict())
        return loss.item()

    def decay_eps(self):
        """Multiply epsilon by ``eps_decay``, never going below ``eps_min``."""
        self.eps = max(self.eps_min, self.eps * self.eps_decay)

三、训练主循环

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import gymnasium as gym

def train_dqn(total_steps=50_000):
    """Train a DQN agent on CartPole-v1 and return it.

    Args:
        total_steps: number of environment steps to run.

    Returns:
        The trained ``DQNAgent``.
    """
    env = gym.make("CartPole-v1")
    try:
        obs_dim = env.observation_space.shape[0]
        n_actions = env.action_space.n
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        agent = DQNAgent(obs_dim, n_actions, device)
        obs, _ = env.reset()
        ep_ret = 0.0
        returns = []

        for step in range(1, total_steps + 1):
            action = agent.select_action(obs)
            next_obs, reward, term, trunc, _ = env.step(action)

            # Store only the true termination flag: a time-limit truncation is
            # not a terminal state, so the TD target must still bootstrap from
            # next_obs. Storing `term or trunc` would zero out that bootstrap.
            agent.store(obs, action, reward, next_obs, term)
            agent.update()

            ep_ret += reward
            obs = next_obs

            # The episode ends (and the env is reset) on either flag.
            if term or trunc:
                returns.append(ep_ret)
                agent.decay_eps()
                obs, _ = env.reset()
                ep_ret = 0.0

            if step % 5000 == 0 and returns:
                print(f"step={step}, avg_return={np.mean(returns[-20:]):.1f}, eps={agent.eps:.3f}")
    finally:
        # Release the environment even if training raises.
        env.close()
    return agent

# Script entry point: run training with the default step budget.
if __name__ == "__main__":
    train_dqn()

四、改进挂钩

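| 改进 | 改法 |
| --- | --- |
| Double DQN | a* = online(s2).argmax; target = r + γ * (1 - d) * target(s2)[a*] |
| Dueling | 网络拆成 $V$ 与 $A$ 两个分支,$Q = V + (A - \bar{A})$ |
| PER | 用 Prioritized Replay Buffer 替换均匀采样的 Buffer |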

五、超参(CartPole)

| 超参 | 取值 |
| --- | --- |
| buffer | 50000 |
| batch | 64 |
| lr | 1e-3 |
| $\gamma$ | 0.99 |
| target_update | 500 |
| warmup | 1000 步后再 update(可加) |

六、小结

-------------本文结束感谢您的阅读-------------