【GDNet】关于gt的解释与说明
·
对于 g t g_t gt的形状
在 modeling_qwen3_5.py 的 Qwen3_5GatedDeltaNet.forward 中:
# in_proj_a: Linear(hidden_size -> num_v_heads)
a = self.in_proj_a(hidden_states) # shape: [batch, seq_len, num_v_heads]
# A_log: Parameter(num_v_heads,)
# dt_bias: Parameter(num_v_heads,)
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
# g shape: [batch, seq_len, num_v_heads]
所以对于序列中的某个 token (t):
g t ∈ R H g_t \in \mathbb{R}^{H} gt∈RH
其中 (H) = num_v_heads(Qwen3.5-9B 中是 32)。
(g_t) 是一个向量,每个 head 有自己独立的衰减值。
这意味着什么
记忆更新公式中:
S ← exp ( g t ) ⋅ S + k t ⋅ Δ T S \leftarrow \exp(g_t) \cdot S + k_t \cdot \Delta^T S←exp(gt)⋅S+kt⋅ΔT
这里 (S) 的 shape 是 [batch, num_heads, d_k, d_v],而 (\exp(g_t)) 的 shape 是 [batch, num_heads]。广播后,每个 head 的记忆矩阵有自己独立的衰减率。
直觉上:
- 某个 head 的 (g_t) 很负 → (\exp(g_t) \approx 0) → 这个 head 几乎忘掉所有旧记忆,只保留当前 token 的信息
- 某个 head 的 (g_t) 接近 0 → (\exp(g_t) \approx 1) → 这个 head 完整保留旧记忆,只叠加增量
不同 head 可以学到不同的"记忆策略"——有的 head 专注短期记忆(衰减快),有的 head 维护长期记忆(衰减慢)。这和 Mamba 中 (A_t) 的作用是类似的。
为什么 (g_t) 恒为负数
注意计算公式:
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
A_log.exp()→ 恒正F.softplus(...)→ 恒正(softplus 的值域是 ((0, +\infty)))- 前面有负号 → (g_t) 恒为负数
所以 (\exp(g_t) \in (0, 1)),保证记忆矩阵是衰减的,不会爆炸。这是一个精心设计的数值稳定性保证。
更多推荐

所有评论(0)