← 上级:RL-04.实现框架与实践 · 循环:RL-04-01-训练循环与接口约定
深度 RL 用 PyTorch 时,与监督学习的差异集中在 Bootstrap 目标、双网络、非 i.i.d. 数据。本文给出可复用的工程惯例。
一、推荐模块划分
1 | rl_project/ |
算法篇与 RL-05 数据结构 分文件,便于单测 Buffer。
二、设备与张量
1 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 实践 | 说明 |
|---|---|
| 环境在 CPU | step 快,避免 GPU 等 CPU |
| 批量上 GPU | obs_batch.to(device) 一次前向 |
float32 |
RL 默认足够 |
三、DQN 网络与 gather
1 | import torch.nn as nn |
离散动作:输出每层 Q 值,用 gather 取执行动作的 Q。
四、目标网络同步
1 | def hard_update(target, source): |
| 方式 | 用法 |
|---|---|
| Hard | 每 $C$ 步 hard_update(target, online) |
| Soft | 每步 soft_update,SAC/TD3 常见 |
五、Bootstrap 目标(无梯度)
1 | with torch.no_grad(): |
关键:TD 目标一侧 detach,避免目标随 online 网络一起动。
六、PPO:分布与 log_prob
1 | from torch.distributions import Categorical |
连续动作用 Normal,动作经 tanh squash 时需修正 log_prob(SAC 做法)。
七、优化器与学习率
| 算法 | 常见设置 |
|---|---|
| DQN | Adam, lr=2.5e-4 或 RMSprop |
| PPO | Adam, lr=3e-4, 常数或线性衰减 |
| 梯度裁剪 | clip_grad_norm_(..., 0.5~10) |
八、日志与 checkpoint
1 | log_dict = { |
九、常见错误
| 错误 | 后果 |
|---|---|
目标未 no_grad |
训练发散 |
done 未乘 $(1-d)$ |
终止态错误 Bootstrap |
忘记 optimizer.zero_grad |
梯度累积异常 |
| obs 未 float / 未归一化 | 学习极慢 |
十、小结
- 模块:env / buffer / net / agent / train。
- DQN:gather + target + no_grad;PPO:dist.log_prob。
- 下一篇:表格型算法实现