Amazon AlphaZero: 神经网络架构深度解析
从棋盘特征到全局策略的完整信息流转
Stage 1: 原始棋盘与标准化 (Canonical Form)
一切始于游戏引擎。在 MCTS 搜索中,棋盘是一个基础的 8x8 整数矩阵。为了让神经网络更容易学习,我们必须将所有玩家的视角统一。
Canonical Board (8x8 int32)
数值编码: 1=己方(白), -1=敌方(黑), 2=障碍/箭
def getCanonicalForm(self, board, player):
if player == 1:
return np.rot90(board, 2)
return board
N-to-1 架构基石: 这些标准化的棋盘数据不会直接送入 GPU,而是通过高效的 IPC 机制汇聚到 Dispatcher,这是实现高吞吐量的第一步。
Next: 特征工程 >>
Stage 2: 7通道特征提取 (Feature Extraction)
在 GpuWorker 中,原始的整数矩阵被“翻译”成神经网络能理解的语言。我们提取了 7 个关键特征平面,显式地告诉网络棋盘上的局势。
Input Tensor: [Batch, 7, 8, 8]
board_7ch = np.zeros((B, 7, 8, 8), dtype=np.float32)
board_7ch[i] = teacher_bridge.compute_7ch_features(my, op, arr)
这个步骤将复杂的规则逻辑转化为几何特征,极大地降低了神经网络的学习难度。
Next: 深层骨干网络 >>
Stage 3: 深层 ResNet-40 骨干网络
数据进入 AmazonsPytorch.py。7通道的输入首先被投影到高维空间,然后通过一个深达 40 层的残差网络 (ResNet)。这是 AI 的“躯干”,负责理解复杂的棋形和局势。
Input Conv
7 → 256 Channels
升维投影
→
ResBlock x 20
ResNet-40 Backbone
Conv3x3 → BN → ReLU
Conv3x3 → BN → + → ReLU
class AmazonsNNet(nn.Module):
def __init__(self, game, args):
self.conv1 = nn.Conv2d(7, 256, kernel_size=3, ...)
self.res_layers = nn.ModuleList([ResBlock(256) for _ in range(20)])
def forward(self, s):
s = F.relu(self.bn1(self.conv1(s)))
for layer in self.res_layers:
s = layer(s)
return s
经过这一阶段,网络已经提取了丰富的局部特征,但它对棋盘的理解仍然局限于卷积核覆盖的区域。
Next: 关键技术 - 全局注意力 >>
⚡ Stage 4: 全局自注意力机制 (Global Self-Attention)
这是网络中最关键的组件之一。 传统的卷积网络 (CNN) 受限于局部感受野,难以捕捉远距离的依赖关系。而 Self-Attention 允许网络在一步之内建立棋盘上任意两个位置之间的联系。
交互演示:捕捉远距离依赖 (Hover to Interact)
将鼠标悬停在网格上,体验一个位置 (Query ) 如何“瞬间”关注并聚合棋盘上其他关键位置 (Key/Value ) 的信息。
示意图:鼠标所在位置的特征直接关注并聚合了棋盘其他区域的信息,打破了卷积的局部限制。
class SelfAttentionBlock(nn.Module):
def forward(self, x):
proj_query = self.query(x).view(B, -1, N).permute(0, 2, 1)
proj_key = self.key(x).view(B, -1, N)
energy = torch.bmm(proj_query, proj_key)
attention = F.softmax(energy, dim=-1)
proj_value = self.value(x).view(B, -1, N)
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
return self.gamma * out + x
通过这一步,特征图中的每一个点都包含了全局的上下文信息,为接下来的复杂决策奠定了基础。
Next: 动态策略头 >>
Stage 5: 动态特征更新头 (Dynamic Feature Update Head)
这是针对亚马逊棋规则设计的特殊结构。由于一步棋包含“移动”和“射箭”两个动作,我们采用了一种因果推理 机制:网络在特征层面“想象”棋子移动后的样子,再据此决定射箭位置。
A. Move Decision
Query: Source
Key: Destination
→
B. Latent Simulation
Update Gate
将目的地(Dst)信息 注入全局特征
→
C. Arrow Decision
Query: Shooter (New Pos)
Key: Target
输出 1: P(Src, Dst)
输出 2: P(Arr | Dst)
move_logits = torch.bmm(src_q, dst_k) ...
combined = torch.cat([x, dst_k_map], dim=1)
x_next = self.update_gate(combined) + x
arr_q = self.conv_arrow_query(x_next)...
arrow_logits = torch.bmm(arr_q, arr_k) ...
这种设计确保了射箭的决策是基于移动之后的最新局势做出的,符合游戏规则的因果逻辑。
Next: 最终输出 >>
Stage 6: 最终输出与 MCTS 集成
网络最终输出三个核心张量,它们被回传给 Dispatcher,最终由 MCTS 用于指导搜索和决策。
1. Policy Move (Logits)
P(Source, Destination) 的联合概率分布
Shape: [Batch, 4096]
2. Policy Arrow (Logits)
给定目的地 Dst 后的条件射箭概率 P(Arrow | Dst)
Shape: [Batch, 64, 64]
3. Value (v)
对当前局面胜率的标量估计 (-1 到 1)
Shape: [Batch, 1]
在 MCTS 搜索中,这两个概率矩阵被结合起来计算一个完整动作的先验概率:
prob = prob_move_part * prob_arrow_part
重新开始演示