RL-04-01: Training Loop and Interface Conventions


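For any in-house RL code, settle on a single environment-interaction contract first and layer the algorithms on top afterwards. This note fixes an interface aligned with Gymnasium and a skeleton training loop.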


1. Minimal interaction template

```python
import gymnasium as gym

env = gym.make("CartPole-v1")
obs, info = env.reset(seed=42)

terminated = truncated = False
while not (terminated or truncated):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)

env.close()
```
| Call | Returns | Meaning |
|---|---|---|
| `reset(seed=...)` | `obs, info` | Start of a new episode |
| `step(action)` | `obs, reward, terminated, truncated, info` | One transition |

The old `done` flag is split in two (Gymnasium ≥ 0.26):

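| Flag | Meaning |
|---|---|
| `terminated` | The MDP ended naturally (failure or success) |
| `truncated` | External cut-off such as a time-step limit |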

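When bootstrapping, the target $y = r + \gamma Q(s')$ uses $Q(s')$ only if the transition is not terminated; on a terminated transition the target is just $y = r$. On a truncated transition you usually still bootstrap, because the task could have continued from $s'$.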
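The rule above can be sketched as a small helper. The function name `td_target` and the sample numbers are illustrative assumptions, not part of the original note; `next_q` stands for whatever estimate of $Q(s')$ the agent uses (for example $\max_{a'} Q(s', a')$ in Q-learning).

```python
def td_target(reward, next_q, terminated, gamma=0.99):
    """One-step TD target y = r + gamma * (1 - terminated) * Q(s').

    `truncated` is deliberately not an argument: a time-limit cut does not
    make s' terminal, so the bootstrap term is kept in that case.
    """
    bootstrap = 0.0 if terminated else gamma * next_q
    return reward + bootstrap


# Terminated transition: no bootstrap, the target is just the reward.
y_term = td_target(reward=1.0, next_q=10.0, terminated=True)

# Truncated (or ordinary) transition: bootstrap from Q(s') as usual.
y_trunc = td_target(reward=1.0, next_q=10.0, terminated=False)

print(y_term, y_trunc)
```

Keeping `truncated` out of the signature makes it hard to mask the bootstrap by accident when a caller only has a merged `done` flag at hand.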


2. Space types

```python
env.observation_space  # Box / Discrete / Dict / Tuple
env.action_space
```
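| Type | Example | Agent-side handling |
|---|---|---|
| `Discrete(n)` | CartPole actions 0/1 | Use directly as an index |
| `Box(low, high, shape)` | Continuous observations or actions | Normalize, clip |
| `Dict` | Multimodal observations | Encode each entry separately, then concatenate |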
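```python
import gymnasium as gym

env = gym.make("CartPole-v1")
obs, _ = env.reset(seed=0)
assert env.observation_space.contains(obs)  # obs lies inside the declared space
a = env.action_space.sample()               # random valid action
```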
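The agent-side handling listed in the table can be sketched with plain NumPy. The helper names (`scale_action`, `one_hot`, `encode_dict`) and the dictionary keys are hypothetical, and the sketch assumes the policy emits continuous actions in $[-1, 1]$ (for example through a tanh output), which the note itself does not specify.

```python
import numpy as np


def scale_action(raw, low, high):
    """Clip a policy output to [-1, 1] and map it linearly onto [low, high]."""
    raw = np.clip(np.asarray(raw, dtype=np.float32), -1.0, 1.0)
    low = np.asarray(low, dtype=np.float32)
    high = np.asarray(high, dtype=np.float32)
    return low + 0.5 * (raw + 1.0) * (high - low)


def one_hot(index, n):
    """Encode a Discrete(n) value as a length-n vector."""
    vec = np.zeros(n, dtype=np.float32)
    vec[index] = 1.0
    return vec


def encode_dict(obs, encoders):
    """Encode each key of a Dict observation separately, then concatenate."""
    return np.concatenate([encoders[key](obs[key]) for key in sorted(obs)])


# Out-of-range output 2.0 is clipped to 1.0 before scaling.
action = scale_action([0.0, 2.0], low=[-2.0, 0.0], high=[2.0, 4.0])

features = encode_dict(
    {"gear": 2, "speed": np.array([0.5], dtype=np.float32)},
    {"gear": lambda g: one_hot(g, 3), "speed": lambda s: s},
)
```

Sorting the keys in `encode_dict` keeps the feature layout stable regardless of the order in which the environment builds its dictionary.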

3. Training-loop abstraction

```python
def train_loop(env, agent, total_steps: int):
    obs, _ = env.reset()
    ep_return, ep_len = 0.0, 0

    for step in range(1, total_steps + 1):
        action = agent.select_action(obs)
        next_obs, reward, term, trunc, info = env.step(action)
        done = term or trunc  # episode boundary, used only for resetting

        # Store `term` as the done flag so truncated steps still bootstrap.
        agent.store(obs, action, reward, next_obs, term)
        agent.maybe_update()

        ep_return += reward
        ep_len += 1
        obs = next_obs

        if done:
            agent.on_episode_end(ep_return, ep_len)
            obs, _ = env.reset()
            ep_return, ep_len = 0.0, 0
```
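| Hook | Responsibility |
|---|---|
| `select_action` | $\varepsilon$-greedy, or sampling from $\pi_\theta$ |
| `store` | Write to the Q-table, replay buffer, or rollout buffer |
| `maybe_update` | Update once the warmup and batch-size conditions are met |
| `on_episode_end` | Logging, $\varepsilon$ decay |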
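To make the four hooks concrete, here is a minimal sketch of a tabular Q-learning agent driven by the same loop shape. `ChainEnv` and `TabularQAgent` are hypothetical names invented for this example; the toy environment only mimics the Gymnasium `reset` / `step` signature so that the snippet runs without Gymnasium installed, and the hyperparameters are arbitrary.

```python
import random


class ChainEnv:
    """Hypothetical 5-state chain with a Gymnasium-style reset/step signature.

    Action 1 moves right, action 0 moves left. Reaching the last state
    terminates the episode with reward 1; otherwise the episode is truncated
    after `max_steps` steps.
    """

    def __init__(self, n_states=5, max_steps=20):
        self.n_states, self.max_steps = n_states, max_steps

    def reset(self, seed=None):
        self.state, self.t = 0, 0
        return self.state, {}

    def step(self, action):
        self.t += 1
        move = 1 if action == 1 else -1
        self.state = max(0, min(self.n_states - 1, self.state + move))
        terminated = self.state == self.n_states - 1
        truncated = not terminated and self.t >= self.max_steps
        return self.state, float(terminated), terminated, truncated, {}


class TabularQAgent:
    """Epsilon-greedy Q-learning agent exposing the four hooks."""

    def __init__(self, n_states, n_actions, lr=0.5, gamma=0.9, eps=1.0, seed=0):
        self.q = [[0.0] * n_actions for _ in range(n_states)]
        self.lr, self.gamma, self.eps = lr, gamma, eps
        self.rng = random.Random(seed)
        self.last = None      # most recent transition, consumed by maybe_update
        self.returns = []     # per-episode returns, filled by on_episode_end

    def select_action(self, obs):
        if self.rng.random() < self.eps:
            return self.rng.randrange(len(self.q[obs]))
        return max(range(len(self.q[obs])), key=lambda a: self.q[obs][a])

    def store(self, obs, action, reward, next_obs, done):
        self.last = (obs, action, reward, next_obs, done)

    def maybe_update(self):
        if self.last is None:
            return
        s, a, r, s2, done = self.last
        target = r if done else r + self.gamma * max(self.q[s2])
        self.q[s][a] += self.lr * (target - self.q[s][a])
        self.last = None

    def on_episode_end(self, ep_return, ep_len):
        self.returns.append(ep_return)
        self.eps = max(0.05, self.eps * 0.95)  # decay exploration per episode


env, agent = ChainEnv(), TabularQAgent(n_states=5, n_actions=2)
obs, _ = env.reset()
ep_return, ep_len = 0.0, 0
for _ in range(5000):
    action = agent.select_action(obs)
    next_obs, reward, term, trunc, _ = env.step(action)
    # Only `term` is stored, so a truncated step keeps its bootstrap term.
    agent.store(obs, action, reward, next_obs, term)
    agent.maybe_update()
    ep_return, ep_len, obs = ep_return + reward, ep_len + 1, next_obs
    if term or trunc:
        agent.on_episode_end(ep_return, ep_len)
        obs, _ = env.reset()
        ep_return, ep_len = 0.0, 0
```

A replay-based or rollout-based agent would keep the same four method names and change only what `store` accumulates and when `maybe_update` fires.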

4. Common wrapper chain

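```python
import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics

env = gym.make("CartPole-v1")
# `buffer_length` is the Gymnasium 1.x name; 0.2x releases call it `deque_size`.
env = RecordEpisodeStatistics(env, buffer_length=100)
```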
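| Wrapper | Effect |
|---|---|
| `RecordEpisodeStatistics` | Adds `info["episode"]` with the episode return and length |
| `NormalizeObservation` | Normalizes observations with a running mean and variance |
| `NormalizeReward` | Rescales rewards |
| `FrameStack` | Stacks 4 frames for Atari (named `FrameStackObservation` in Gymnasium 1.x) |
| `TimeLimit` | Step limit (built in) |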

A custom wrapper subclasses `gymnasium.Wrapper` and overrides `step` and/or `reset`.


5. Vectorized environments

```python
import gymnasium as gym
from gymnasium.vector import SyncVectorEnv

def make_env():
    return gym.make("CartPole-v1")

num_envs = 4
env = SyncVectorEnv([make_env for _ in range(num_envs)])
obs, _ = env.reset()  # shape (num_envs, obs_dim)
actions = env.action_space.sample()
obs, rewards, term, trunc, infos = env.step(actions)
```
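| Class | When to use |
|---|---|
| `SyncVectorEnv` | Several environments in one process; easy to debug |
| `AsyncVectorEnv` | One process per environment; higher sampling throughput |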

PPO and A2C typically collect rollouts from $N$ parallel environments.


6. Random seeds

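```python
import random

import numpy as np
import torch

def set_seed(env, seed: int):
    """Seed Python, NumPy, PyTorch, and the environment's own RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    obs, info = env.reset(seed=seed)  # seeds the env and starts a fresh episode
    env.action_space.seed(seed)       # makes action_space.sample() reproducible
    return obs, info
```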

Record the seed in the run config so that results can be reproduced in RL-06 (evaluation environments and toolchain).
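One way to record the seed is to store it with the rest of the run configuration and read it back on the evaluation side. The sketch below uses only the standard library; `RunConfig`, its fields, and `rollout_signature` are hypothetical stand-ins, not an interface defined by this series.

```python
import json
import random
from dataclasses import asdict, dataclass


@dataclass
class RunConfig:
    """Hypothetical run configuration; the field names are illustrative."""
    env_id: str = "CartPole-v1"
    seed: int = 42
    total_steps: int = 10_000


def rollout_signature(seed, n=5):
    """Stand-in for a seeded run: the same seed yields the same draws."""
    rng = random.Random(seed)
    return [rng.randrange(2) for _ in range(n)]


cfg = RunConfig(seed=123)
saved = json.dumps(asdict(cfg), sort_keys=True)  # would be written next to the logs
restored = RunConfig(**json.loads(saved))        # what the evaluation side reads back

# The restored seed drives the same sequence of random draws as the original.
same_run = rollout_signature(restored.seed) == rollout_signature(cfg.seed)
```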


7. Interface to the algorithm modules

Suggested minimal interface for the Agent base class:

```python
class Agent:
    def select_action(self, obs): ...
    def store(self, obs, action, reward, next_obs, done): ...
    def maybe_update(self): ...
    def on_episode_end(self, ep_return, ep_len): ...
    def save(self, path): ...
    def load(self, path): ...
```

Off-policy agents (DQN) can update on every step right after `store`; on-policy agents (PPO) update in one batch once the rollout has reached $T$ steps.
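The two update schedules differ only in the condition inside `maybe_update`. The sketch below counts how often an update would fire and leaves out the learning itself; the class names, `warmup=100`, `batch_size=32`, and `rollout_len=128` are illustrative assumptions.

```python
class OffPolicySchedule:
    """DQN-style: update on every step once the replay buffer has warmed up."""

    def __init__(self, warmup=100, batch_size=32):
        self.buffer, self.warmup, self.batch_size = [], warmup, batch_size
        self.num_updates = 0

    def store(self, transition):
        self.buffer.append(transition)

    def maybe_update(self):
        if len(self.buffer) >= max(self.warmup, self.batch_size):
            self.num_updates += 1  # a real agent would sample a batch here


class OnPolicySchedule:
    """PPO-style: update once per T collected steps, then drop the rollout."""

    def __init__(self, rollout_len=128):
        self.rollout, self.rollout_len = [], rollout_len
        self.num_updates = 0

    def store(self, transition):
        self.rollout.append(transition)

    def maybe_update(self):
        if len(self.rollout) >= self.rollout_len:
            self.num_updates += 1  # a real agent would run its epochs here
            self.rollout.clear()   # on-policy data is stale after the update


off, on = OffPolicySchedule(), OnPolicySchedule()
for step in range(1000):
    for agent in (off, on):
        agent.store(step)
        agent.maybe_update()

print(off.num_updates, on.num_updates)
```

Because both classes expose the same `store` / `maybe_update` pair, the training loop does not need to know which kind of agent it is driving.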


8. Summary

  • Unify the reset / step / done semantics before writing any algorithm.
  • terminated vs. truncated decides whether the target bootstraps.
  • Next: PyTorch implementation notes.