RL-05-03-Structure-Replay-Buffer

← Parent: RL-05. Specialized Data Structures · Implementation: RL-04-04-DQN Implementation

A replay buffer breaks the temporal correlation between consecutive transitions and underpins off-policy algorithms such as DQN, DDPG, and SAC.


1. deque version (for teaching)

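    from collections import deque
    import random

    import numpy as np


    class ReplayBufferSimple:
        def __init__(self, capacity):
            # A deque with maxlen drops the oldest transition once it is full
            self.buf = deque(maxlen=capacity)

        def push(self, *args):
            # Store one transition, e.g. (obs, action, reward, next_obs, done)
            self.buf.append(args)

        def sample(self, batch_size):
            # Uniform sampling without replacement
            batch = random.sample(self.buf, batch_size)
            # Regroup per-transition tuples into one stacked array per field
            return tuple(map(np.stack, zip(*batch)))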
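
To see what the deque version keeps and returns, here is a minimal usage sketch. The transition layout (obs, action, reward, next_obs, done), the 4-dimensional observation, and the capacity of 3 are assumptions for illustration; the buffer logic is restated inline so the snippet runs on its own.

```python
import random
from collections import deque

import numpy as np

# A deque with maxlen silently drops the oldest item once it is full.
buf = deque(maxlen=3)
for t in range(5):
    obs = np.full(4, t, dtype=np.float32)           # hypothetical 4-dim observation
    next_obs = np.full(4, t + 1, dtype=np.float32)
    buf.append((obs, t % 2, 1.0, next_obs, False))  # (obs, action, reward, next_obs, done)

# Only the 3 most recent transitions (t = 2, 3, 4) are left.
kept_steps = [int(tr[0][0]) for tr in buf]

# zip(*batch) regroups per-transition tuples into per-field tuples,
# and np.stack turns each field into one batched array.
batch = random.sample(list(buf), 2)
obs_b, act_b, rew_b, next_obs_b, done_b = map(np.stack, zip(*batch))
print(kept_steps, obs_b.shape, act_b.shape)
```

Each returned array has the batch as its leading dimension, which is the shape a network forward pass typically expects.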

2. Preallocated circular array (efficient)

    import numpy as np


    class ReplayBuffer:
        def __init__(self, capacity, obs_shape, act_dim=1):
            self.cap = capacity
            self.ptr = 0    # next slot to write
            self.size = 0   # number of valid transitions stored
            # Preallocate every field once so add() never allocates
            self.obs = np.zeros((capacity, *obs_shape), np.float32)
            self.next_obs = np.zeros((capacity, *obs_shape), np.float32)
            # int64 suits discrete actions; continuous control would need float32
            self.actions = np.zeros((capacity, act_dim), np.int64)
            self.rewards = np.zeros(capacity, np.float32)
            self.dones = np.zeros(capacity, np.float32)

        def add(self, obs, action, reward, next_obs, done):
            i = self.ptr
            self.obs[i] = obs
            self.next_obs[i] = next_obs
            self.actions[i] = action
            self.rewards[i] = reward
            self.dones[i] = done
            # Advance the write pointer, wrapping around at capacity
            self.ptr = (self.ptr + 1) % self.cap
            self.size = min(self.size + 1, self.cap)

        def sample(self, batch_size):
            # Uniform indices over the filled part, drawn with replacement
            idx = np.random.randint(0, self.size, batch_size)
            return (
                self.obs[idx], self.actions[idx], self.rewards[idx],
                self.next_obs[idx], self.dones[idx],
            )

FIFO: ptr wraps around and overwrites the oldest data once the buffer is full.
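
The wrap-around behaviour can be checked on a tiny array. This is a stand-alone sketch: the capacity of 3 and the scalar values standing in for transitions are assumptions for illustration, while the ptr/size update is the same as in the add method above.

```python
import numpy as np

cap = 3
data = np.zeros(cap, np.float32)  # stands in for one preallocated field
ptr, size = 0, 0

for value in [10, 11, 12, 13, 14]:  # five writes into three slots
    data[ptr] = value
    ptr = (ptr + 1) % cap           # advance and wrap around at the end
    size = min(size + 1, cap)       # size saturates at capacity

# Writes 13 and 14 landed on slots 0 and 1, replacing the oldest values 10 and 11.
print(data, ptr, size)
```

After the loop, slot 2 still holds 12, so it is now the oldest entry and the next one to be overwritten.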


3. Usage notes

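| Item     | Note                                                                              |
|----------|-----------------------------------------------------------------------------------|
| warmup   | Start training only once size >= batch size                                       |
| Capacity | Too small: old experience is forgotten too early; too large: memory cost is high  |
| dtype    | Keep float32, consistent with the network                                         |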
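
The warmup and capacity notes can be made concrete with a small sketch. The helper names can_train and buffer_nbytes, the optional warmup_steps threshold, and the example sizes (100,000 transitions of 4x84x84 float32 observations) are assumptions for illustration, not part of the buffer above.

```python
import numpy as np


def can_train(size, batch_size, warmup_steps=0):
    """Hypothetical gate: train only after enough transitions are stored."""
    return size >= max(batch_size, warmup_steps)


def buffer_nbytes(capacity, obs_shape, act_dim=1):
    """Estimate the memory of the preallocated arrays
    (obs, next_obs, actions, rewards, dones)."""
    obs_items = int(np.prod(obs_shape))
    per_transition = (
        2 * obs_items * np.dtype(np.float32).itemsize  # obs + next_obs
        + act_dim * np.dtype(np.int64).itemsize        # actions
        + 2 * np.dtype(np.float32).itemsize            # reward + done
    )
    return capacity * per_transition


gib = buffer_nbytes(100_000, (4, 84, 84)) / 2**30
print(f"{gib:.1f} GiB")  # roughly 21 GiB for these assumed sizes
```

Running the estimate before choosing a capacity shows quickly whether the buffer fits in RAM; observations dominate the total, so obs_shape and capacity are the two numbers that matter.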

4. n-step / frame stacking

  • n-step: store the reward $R$ accumulated over $n$ steps together with $s_{t+n}$
  • FrameStack: stack 4 frames along the channel dimension of obs
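
Both ideas can be sketched in a few lines. The function n_step_transition, the class FrameStack, the discount gamma = 0.99, and the 84x84 frame size are hypothetical names and values chosen for illustration. The n-step helper assumes $R = \sum_{k=0}^{n-1} \gamma^k r_{t+k}$ and stops early at an episode boundary.

```python
from collections import deque

import numpy as np


def n_step_transition(steps, gamma=0.99):
    """Collapse up to n consecutive steps into one transition.

    steps: list of (obs, action, reward, next_obs, done) in time order.
    Returns (s_t, a_t, R, s_{t+n}, done) with R the discounted reward sum.
    """
    obs, action = steps[0][0], steps[0][1]
    R, next_obs, done = 0.0, steps[0][3], False
    for k, (_, _, reward, nxt, d) in enumerate(steps):
        R += (gamma ** k) * reward
        next_obs, done = nxt, d
        if d:  # do not accumulate across an episode boundary
            break
    return obs, action, R, next_obs, done


class FrameStack:
    """Keep the last k frames and expose them stacked on the channel axis."""

    def __init__(self, k=4):
        self.frames = deque(maxlen=k)

    def reset(self, frame):
        # Fill the stack with the first frame of the episode
        for _ in range(self.frames.maxlen):
            self.frames.append(frame)
        return self.observation()

    def step(self, frame):
        self.frames.append(frame)  # the oldest frame is dropped
        return self.observation()

    def observation(self):
        return np.stack(list(self.frames), axis=0)  # shape (k, H, W)


# Three steps with reward 1.0 each: R = 1 + 0.99 + 0.99**2
steps = [(0, 1, 1.0, 1, False), (1, 0, 1.0, 2, False), (2, 1, 1.0, 3, False)]
s, a, R, s_n, done = n_step_transition(steps)

fs = FrameStack(k=4)
stacked = fs.reset(np.zeros((84, 84), np.float32))
stacked = fs.step(np.ones((84, 84), np.float32))
print(R, s_n, stacked.shape)
```

When a stored n-step transition is used for a bootstrapped target, the value of $s_{t+n}$ is discounted by $\gamma^n$ rather than $\gamma$.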

5. Summary
