RL-04-05-PPO实现

← 上级:RL-04.实现框架与实践 · 算法:RL-03-11-算法-PPO · Buffer:RL-05-05-结构-Rollout-Buffer

PPO 实现核心:Rollout 缓冲 → GAE → 多 epoch clip 更新。以下为 CartPole 最小可运行结构。


一、Actor-Critic 网络

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
import torch
import torch.nn as nn
from torch.distributions import Categorical

class ActorCritic(nn.Module):
    """Shared-trunk actor-critic network for discrete action spaces.

    A two-layer tanh MLP (``shared``) feeds two linear heads: ``pi`` emits
    unnormalised action logits and ``v`` emits a single state-value estimate.
    """

    def __init__(self, obs_dim, n_actions, hidden=64):
        super().__init__()
        # Layers are created in the same order as before so parameter
        # initialisation and state_dict keys (shared.0 / shared.2) are unchanged.
        trunk_layers = [
            nn.Linear(obs_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
        ]
        self.shared = nn.Sequential(*trunk_layers)
        self.pi = nn.Linear(hidden, n_actions)  # policy head -> logits
        self.v = nn.Linear(hidden, 1)  # value head -> one scalar per input row

    def forward(self, x):
        """Return ``(logits, value)``; the value's trailing size-1 axis is squeezed away."""
        features = self.shared(x)
        logits = self.pi(features)
        state_value = self.v(features).squeeze(-1)
        return logits, state_value

    def act(self, obs):
        """Sample an action from the categorical policy.

        Returns ``(action, log_prob, entropy, value)`` for the given observations.
        """
        logits, state_value = self.forward(obs)
        policy = Categorical(logits=logits)
        sampled = policy.sample()
        return sampled, policy.log_prob(sampled), policy.entropy(), state_value

二、GAE 计算

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import numpy as np

def compute_gae(rewards, values, dones, gamma=0.99, lam=0.95, last_value=0.0):
    """Compute GAE(lambda) advantages and value targets for one rollout.

    Args:
        rewards: per-step rewards r_t, length T.
        values: critic estimates V(s_t) for the same T steps.
        dones: per-step flags (1.0/True when the episode ended after step t);
            a done step does not bootstrap from the following state.
        gamma: discount factor.
        lam: GAE lambda.
        last_value: V(s_T) of the observation that follows the final stored
            step. It is used to bootstrap when the rollout was cut off
            mid-episode. The default 0.0 reproduces the old behaviour of
            not bootstrapping.

    Returns:
        ``(advantages, returns)`` as float32 arrays of length T, where
        ``returns = advantages + values``.

    Raises:
        ValueError: if ``values`` and ``rewards`` differ in length.
    """
    values = np.asarray(values, dtype=np.float32)
    T = len(rewards)
    if len(values) != T:
        raise ValueError(f"values has length {len(values)}, expected {T}")
    advantages = np.zeros(T, dtype=np.float32)
    last_gae = 0.0
    for t in reversed(range(T)):
        # Mask out the next state's value (and the carried GAE) across episode ends.
        next_non_terminal = 1.0 - float(dones[t])
        # Past the end of the buffer, bootstrap from the caller-supplied V(s_T)
        # instead of assuming the episode terminated.
        next_value = values[t + 1] if t + 1 < T else last_value
        delta = rewards[t] + gamma * next_value * next_non_terminal - values[t]
        last_gae = delta + gamma * lam * next_non_terminal * last_gae
        advantages[t] = last_gae
    returns = advantages + values
    return advantages, returns

三、Rollout 收集

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
class RolloutBuffer:
    """Column-oriented, list-backed storage for one on-policy rollout.

    Every ``add`` appends one entry to each of the six parallel lists
    (``obs``, ``actions``, ``log_probs``, ``rewards``, ``dones``, ``values``);
    ``dones`` is stored as a float (1.0 / 0.0).
    """

    def __init__(self):
        self.clear()

    def clear(self):
        """Discard all stored transitions by rebinding every column to a fresh list."""
        self.obs = []
        self.actions = []
        self.log_probs = []
        self.rewards = []
        self.dones = []
        self.values = []

    def add(self, obs, action, log_prob, reward, done, value):
        """Append a single transition to the buffer."""
        leading = (
            (self.obs, obs),
            (self.actions, action),
            (self.log_probs, log_prob),
            (self.rewards, reward),
        )
        for column, item in leading:
            column.append(item)
        self.dones.append(float(done))
        self.values.append(value)

    def __len__(self):
        """Number of stored transitions."""
        return len(self.rewards)

四、PPO 更新

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch.nn.functional as F

def ppo_update(model, optimizer, buf, clip_eps=0.2, vf_coef=0.5, ent_coef=0.01,
               epochs=10, batch_size=64, gamma=0.99, lam=0.95, device="cpu",
               max_grad_norm=0.5):
    """Run the PPO clipped-surrogate update on one collected rollout.

    Advantages and returns are computed once with ``compute_gae`` over the
    whole buffer. The model is then optimised for ``epochs`` passes over
    shuffled mini-batches of (up to) ``batch_size`` transitions.

    Args:
        model: module whose call returns ``(logits, value)`` for a batch of
            observations.
        optimizer: optimizer over ``model``'s parameters.
        buf: rollout storage exposing ``obs``, ``actions``, ``log_probs``,
            ``rewards``, ``dones``, ``values`` lists and ``len()``.
        clip_eps: clipping range for the probability ratio.
        vf_coef: weight of the value-function MSE loss.
        ent_coef: weight of the entropy bonus (subtracted from the loss).
        epochs: number of passes over the rollout.
        batch_size: mini-batch size.
        gamma: discount factor forwarded to ``compute_gae``.
        lam: GAE lambda forwarded to ``compute_gae``.
        device: torch device for all tensors.
        max_grad_norm: global gradient-norm clip applied before each step.

    Note:
        ``compute_gae`` is called without a bootstrap value. If the rollout
        was cut off mid-episode, the caller is expected to have folded
        ``gamma * V(s_T)`` into ``buf.rewards[-1]``.
    """
    n = len(buf)
    if n == 0:
        # Nothing was collected; np.array([]) would also give obs the wrong rank.
        return

    obs = torch.as_tensor(np.array(buf.obs), dtype=torch.float32, device=device)
    actions = torch.as_tensor(buf.actions, dtype=torch.int64, device=device)
    old_log_probs = torch.as_tensor(buf.log_probs, dtype=torch.float32, device=device)
    values = np.array(buf.values, dtype=np.float32)
    adv, ret = compute_gae(buf.rewards, values, buf.dones, gamma, lam)
    adv = torch.as_tensor(adv, dtype=torch.float32, device=device)
    ret = torch.as_tensor(ret, dtype=torch.float32, device=device)
    if n > 1:
        # Tensor.std() is the unbiased estimator, which is NaN for a single
        # sample; normalising a 1-step buffer would push NaN into the weights.
        adv = (adv - adv.mean()) / (adv.std() + 1e-8)

    idx = np.arange(n)
    for _ in range(epochs):
        np.random.shuffle(idx)
        for start in range(0, n, batch_size):
            mb = torch.as_tensor(idx[start:start + batch_size], device=device)
            logits, v = model(obs[mb])
            dist = Categorical(logits=logits)
            log_prob = dist.log_prob(actions[mb])
            entropy = dist.entropy().mean()

            # Probability ratio pi_new / pi_old, computed in log space for stability.
            ratio = torch.exp(log_prob - old_log_probs[mb])
            surr1 = ratio * adv[mb]
            surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * adv[mb]
            # Pessimistic (min) bound of clipped and unclipped objectives, negated for descent.
            pi_loss = -torch.min(surr1, surr2).mean()
            vf_loss = F.mse_loss(v, ret[mb])
            loss = pi_loss + vf_coef * vf_loss - ent_coef * entropy

            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()

五、训练循环

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import gymnasium as gym

def train_ppo(total_steps=80_000, rollout_len=2048, gamma=0.99):
    """Train a PPO agent on ``CartPole-v1`` and return the trained model.

    Alternates between collecting up to ``rollout_len`` on-policy steps and
    calling ``ppo_update`` on them, until ``total_steps`` environment steps
    have been taken.

    Bootstrapping: the GAE helper treats the step after the last buffered one
    as having value 0, and treats every ``done`` step as terminal. For states
    that are *not* truly terminal, ``gamma * V(s')`` is folded into the stored
    reward:
      * time-limit truncations (``trunc and not term``), and
      * the final buffered step when the rollout ends mid-episode.
    ``gamma`` is passed on to ``ppo_update`` so both sides agree.

    Args:
        total_steps: total environment steps to train for.
        rollout_len: maximum steps collected per update; must be >= 1.
        gamma: discount factor used for bootstrapping and GAE.

    Returns:
        The trained ``ActorCritic`` model.

    Raises:
        ValueError: if ``rollout_len < 1`` (the loop would never make progress).
    """
    if rollout_len < 1:
        raise ValueError("rollout_len must be >= 1")

    env = gym.make("CartPole-v1")
    try:
        obs_dim = env.observation_space.shape[0]
        n_actions = env.action_space.n
        device = torch.device("cpu")
        model = ActorCritic(obs_dim, n_actions).to(device)
        opt = torch.optim.Adam(model.parameters(), lr=3e-4)
        buf = RolloutBuffer()

        def _state_value(o):
            # Critic estimate V(o) for a single observation, without tracking gradients.
            with torch.no_grad():
                x = torch.as_tensor(o, dtype=torch.float32, device=device).unsqueeze(0)
                return model(x)[1].item()

        obs, _ = env.reset()
        global_step = 0
        ep_return = 0.0          # undiscounted return of the episode in progress
        last_ep_return = None    # return of the most recently finished episode

        while global_step < total_steps:
            buf.clear()
            for _ in range(rollout_len):
                o = torch.as_tensor(obs, dtype=torch.float32, device=device)
                with torch.no_grad():
                    action, log_prob, _, value = model.act(o.unsqueeze(0))
                action = action.item()
                log_prob = log_prob.item()
                value = value.item()
                next_obs, reward, term, trunc, _ = env.step(action)
                ep_return += reward
                done = term or trunc
                if trunc and not term:
                    # Time-limit cut: s' is not terminal, so keep its value via the reward.
                    reward = reward + gamma * _state_value(next_obs)
                buf.add(obs, action, log_prob, reward, done, value)
                obs = next_obs
                global_step += 1
                if done:
                    last_ep_return = ep_return
                    ep_return = 0.0
                    obs, _ = env.reset()
                if global_step >= total_steps:
                    break

            if len(buf) and not buf.dones[-1]:
                # Rollout ended mid-episode: bootstrap from V(s_T) of the next observation.
                buf.rewards[-1] += gamma * _state_value(obs)
            ppo_update(model, opt, buf, gamma=gamma, device=device)
            shown = last_ep_return if last_ep_return is not None else ep_return
            print(f"step={global_step}, last_ep_reward~={shown:.0f}")
    finally:
        env.close()
    return model

六、要点

| 要点 | 说明 |
| --- | --- |
| On-Policy | 每次 `ppo_update` 之后清空 Rollout,旧策略采集的数据不再复用 |
| 优势标准化 | 稳定 clip 目标的梯度尺度 |
| 多 epoch | 同一批数据重复利用 $K$ 次(代码中 `epochs=10`) |
| 连续动作 | Actor 输出 Normal 分布的 mean/std |

七、小结

  • PPO = Rollout + GAE + clip surrogate + value/entropy loss
  • 下一篇:超参与调优
-------------本文结束感谢您的阅读-------------