对于 g t g_t gt的形状

modeling_qwen3_5.pyQwen3_5GatedDeltaNet.forward 中:

# in_proj_a: Linear(hidden_size -> num_v_heads)
a = self.in_proj_a(hidden_states)          # shape: [batch, seq_len, num_v_heads]

# A_log: Parameter(num_v_heads,)
# dt_bias: Parameter(num_v_heads,)
g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
# g shape: [batch, seq_len, num_v_heads]

所以对于序列中的某个 token (t):

g t ∈ R H g_t \in \mathbb{R}^{H} gtRH

其中 (H) = num_v_heads(Qwen3.5-9B 中是 32)。

(g_t) 是一个向量,每个 head 有自己独立的衰减值。


这意味着什么

记忆更新公式中:

S ← exp ⁡ ( g t ) ⋅ S + k t ⋅ Δ T S \leftarrow \exp(g_t) \cdot S + k_t \cdot \Delta^T Sexp(gt)S+ktΔT

这里 (S) 的 shape 是 [batch, num_heads, d_k, d_v],而 (\exp(g_t)) 的 shape 是 [batch, num_heads]。广播后,每个 head 的记忆矩阵有自己独立的衰减率

直觉上:

  • 某个 head 的 (g_t) 很负 → (\exp(g_t) \approx 0) → 这个 head 几乎忘掉所有旧记忆,只保留当前 token 的信息
  • 某个 head 的 (g_t) 接近 0 → (\exp(g_t) \approx 1) → 这个 head 完整保留旧记忆,只叠加增量

不同 head 可以学到不同的"记忆策略"——有的 head 专注短期记忆(衰减快),有的 head 维护长期记忆(衰减慢)。这和 Mamba 中 (A_t) 的作用是类似的。


为什么 (g_t) 恒为负数

注意计算公式:

g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
  • A_log.exp() → 恒正
  • F.softplus(...) → 恒正(softplus 的值域是 ((0, +\infty)))
  • 前面有负号 → (g_t) 恒为负数

所以 (\exp(g_t) \in (0, 1)),保证记忆矩阵是衰减的,不会爆炸。这是一个精心设计的数值稳定性保证。

Logo

这里是“一人公司”的成长家园。我们提供从产品曝光、技术变现到法律财税的全栈内容,并连接云服务、办公空间等稀缺资源,助你专注创造,无忧运营。

更多推荐