RL-04-02-PyTorch实现要点

← 上级:RL-04.实现框架与实践 · 循环:RL-04-01-训练循环与接口约定

深度 RL 用 PyTorch 时,与监督学习的差异集中在 Bootstrap 目标双网络非 i.i.d. 数据。本文给出可复用的工程惯例。


一、推荐模块划分

1
2
3
4
5
6
7
8
9
10
rl_project/
├── config.yaml # 超参
├── envs.py # make_env + wrappers
├── buffers.py # Replay / Rollout
├── networks.py # QNet, ActorCritic
├── agents/
│ ├── dqn.py
│ └── ppo.py
├── train.py
└── utils.py # seed, log, soft_update

算法篇与 RL-05 数据结构 分文件,便于单测 Buffer。


二、设备与张量

1
2
3
4
5
6
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def to_tensor(x, dtype=torch.float32):
if isinstance(x, np.ndarray):
return torch.as_tensor(x, dtype=dtype, device=device)
return torch.as_tensor(x, dtype=dtype, device=device)
实践 说明
环境在 CPU step 快,避免 GPU 等 CPU
批量上 GPU obs_batch.to(device) 一次前向
float32 RL 默认足够

三、DQN 网络与 gather

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch.nn as nn

class QNet(nn.Module):
def __init__(self, obs_dim, n_actions, hidden=128):
super().__init__()
self.net = nn.Sequential(
nn.Linear(obs_dim, hidden), nn.ReLU(),
nn.Linear(hidden, hidden), nn.ReLU(),
nn.Linear(hidden, n_actions),
)

def forward(self, x):
return self.net(x)

# 训练时取 Q(s,a)
q_all = online_net(obs_batch) # (B, A)
q_sa = q_all.gather(1, act_batch.long().unsqueeze(1)).squeeze(1)

离散动作:输出每层 Q 值,用 gather 取执行动作的 Q。


四、目标网络同步

1
2
3
4
5
6
def hard_update(target, source):
target.load_state_dict(source.state_dict())

def soft_update(target, source, tau=0.005):
for tp, sp in zip(target.parameters(), source.parameters()):
tp.data.copy_(tau * sp.data + (1 - tau) * tp.data)
方式 用法
Hard 每 $C$ 步 hard_update(target, online)
Soft 每步 soft_update,SAC/TD3 常见

五、Bootstrap 目标(无梯度)

1
2
3
4
5
6
7
8
9
with torch.no_grad():
q_next = target_net(next_obs_batch).max(dim=1).values
target = rew_batch + gamma * (1.0 - done_batch) * q_next

loss = F.mse_loss(q_sa, target)
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(online_net.parameters(), max_norm=10.0)
optimizer.step()

关键:TD 目标一侧 detach,避免目标随 online 网络一起动。


六、PPO:分布与 log_prob

1
2
3
4
5
6
7
from torch.distributions import Categorical

logits = actor(obs)
dist = Categorical(logits=logits)
action = dist.sample()
log_prob = dist.log_prob(action)
entropy = dist.entropy()

连续动作用 Normal,动作经 tanh squash 时需修正 log_prob(SAC 做法)。


七、优化器与学习率

算法 常见设置
DQN Adam, lr=2.5e-4 或 RMSprop
PPO Adam, lr=3e-4, 常数或线性衰减
梯度裁剪 clip_grad_norm_(..., 0.5~10)

八、日志与 checkpoint

1
2
3
4
5
6
7
8
9
10
11
12
log_dict = {
"train/loss": loss.item(),
"train/epsilon": eps,
"train/ep_return": ep_return,
}
# tensorboard: writer.add_scalar(...)
torch.save({
"online": online_net.state_dict(),
"target": target_net.state_dict(),
"optimizer": optimizer.state_dict(),
"step": global_step,
}, "checkpoint.pt")

九、常见错误

错误 后果
目标未 no_grad 训练发散
done 未乘 $(1-d)$ 终止态错误 Bootstrap
忘记 optimizer.zero_grad 梯度累积异常
obs 未 float / 未归一化 学习极慢

十、小结

  • 模块:env / buffer / net / agent / train
  • DQN:gather + target + no_grad;PPO:dist.log_prob
  • 下一篇:表格型算法实现
-------------本文结束感谢您的阅读-------------