用 Rust 手写 Transformer —— Day 2:反向传播与softmax的理论基础

2026-04-27

预测很烂,Loss 很大

最简单的模型:

y_pred = w * x

初始 w = 0.5,输入 x = 10,真实答案 y_true = 100

y_pred = 0.5 * 10 = 5

预测是 5,真实是 100,相差甚远。

用平方误差量化Loss:

Loss = (y_pred - y_true)²
     = (5 - 100)²
     = 9025

w怎么调能让Loss减小?


数学给出答案: 看 Loss 对 w 的导数

导数 > 0 -> w 应该减小 导数 < 0 -> w 应该增大

Loss = (w*x - y_true)²

dLoss/dw = 2 * (w*x - y_true) * x
         = 2 * (5 - 100) * 10
         = -1900

导数是 -1900

w 应增大

用学习率 lr = 0.001 更新:

w_new = w - lr * dLoss/dw
      = 0.5 - 0.001 * (-1900)
      = 0.5 + 1.9
      = 2.4

一轮一轮跑

轮次wy_predLoss梯度
初始0.505.09025.0
第 1 轮2.4024.05776.0-1900
第 2 轮4.0640.63494.4-1520
第 3 轮5.4354.32088.5-1188
第 4 轮6.5465.41193.6-891.4
第 5 轮7.4374.3655.3-651.2
第 10 轮9.2692.654.8-148
第 15 轮9.8398.32.9-34
第 20 轮9.9799.70.09

w 从 0.5 一步步逼近 10(y_true/x = 100/10 = 10),Loss 从 9025 掉到接近 0。

这就是训练的全部本质:前向算 Loss,反向算梯度,更新 w,循环。

神经网络只是把这个套娃。


反向传播没有新数学

上面那个 dLoss/dw = -1900 怎么算出来的?

两步链式法则,仅此而已。

u = w*x - y_true,则 Loss = u²

dLoss/du = 2u = 2 * (5 - 100) = -190
du/dw    = x  = 10

dLoss/dw = dLoss/du * du/dw = -190 * 10 = -1900

链式法则就是高中微积分。反向传播只是给它起了个名字,然后在神经网络的计算图上系统地用一遍。

层数再深,也是同一件事:从 Loss 往回,一层一层用链式法则。


从标量到矩阵

现实中一次处理多个样本,每个样本多个特征,输出多个神经元——标量变矩阵。

X = [[1, 2],      # 样本0:特征1=1, 特征2=2
     [3, 4]]      # 样本1:特征1=3, 特征2=4

W = [[0.1, 0.5],  # 特征0 → 输出0权重=0.1, 输出1权重=0.5
     [0.2, 0.6]]  # 特征1 → 输出0权重=0.2, 输出1权重=0.6

Y_pred = X @ W

算出来:

Y_pred[0,0] = 1*0.1 + 2*0.2 = 0.5
Y_pred[0,1] = 1*0.5 + 2*0.6 = 1.7
Y_pred[1,0] = 3*0.1 + 4*0.2 = 1.1
Y_pred[1,1] = 3*0.5 + 4*0.6 = 3.9

真实标签:

Y_true = [[2.0, 3.0],
          [5.0, 7.0]]

Loss = sum((Y_pred - Y_true)²)

dout 是 Loss 对 Y_pred 每个元素的导数,跟标量时一样,d(u²)/du = 2u

dout = 2 * (Y_pred - Y_true)
     = [[-3.0, -2.6],
        [-7.8, -6.2]]

dW 怎么算

W[0,0] = 0.1 参与了哪些 Y 的计算?

Y_pred[0,0] = X[0,0]*W[0,0] + ...   # 样本0
Y_pred[1,0] = X[1,0]*W[0,0] + ...   # 样本1

链式法则,两个样本的贡献加起来:

dLoss/dW[0,0] = dout[0,0] * X[0,0] + dout[1,0] * X[1,0]
              = (-3.0)*1 + (-7.8)*3
              = -26.4

对所有元素这样做,整理成矩阵:

dW = X.T @ dout

形状:(2,2).T @ (2,2) = (2,2)——和 W 同形状,正确。

dX 怎么算

X[0,0] 参与了哪些 Y 的计算?

Y_pred[0,0] = X[0,0]*W[0,0] + ...   # 输出神经元0
Y_pred[0,1] = X[0,0]*W[0,1] + ...   # 输出神经元1
dLoss/dX[0,0] = dout[0,0]*W[0,0] + dout[0,1]*W[0,1]
              = (-3.0)*0.1 + (-2.6)*0.5
              = -1.6

整理成矩阵:

dX = dout @ W.T

矩阵公式不是魔法

dW = X.T @ doutdX = dout @ W.T——这两行公式不是背出来的。

每一个,都是"对某个元素写链式法则,然后发现可以整理成矩阵乘法"。

验证正确性最简单的方式:形状对得上就基本对了

dW: (C, B) @ (B, D) = (C, D)    ← 和 W 同形状 ✓
dX: (B, D) @ (D, C) = (B, C)    ← 和 X 同形状 ✓

Softmax:logits 变概率

矩阵乘法输出任意实数(logits),无法直接解释为概率:

logits = [2.1, -0.5, 3.8]

Softmax 把它变成合法概率分布(每项 > 0,加起来 = 1):

softmax(x[i]) = exp(x[i]) / Σ exp(x[j])

计算:

exp(2.1)  = 8.17
exp(-0.5) = 0.61
exp(3.8)  = 44.70
sum       = 53.48

probs = [0.153, 0.011, 0.836]   # 加起来 = 1.0 ✓

exp 的作用:把负数变正数,同时放大差距(3.8 比 2.1 大 1.7 倍,但 exp 之后大 5.5 倍)。


Softmax 反向传播

设 softmax 输出为 s,上层传来的梯度为 grad_out

grad_in[i] = s[i] * (grad_out[i] - Σ(grad_out * s))

Σ(grad_out * s) 是梯度的加权平均(权重是概率)。每个位置减去这个"基准线",再乘以自身概率。

为什么有这个基准线?softmax 是归一化操作,一个位置概率升高,其他位置必然降低——梯度必须体现这个相互制约。


Rust 实现

fn matmul_backward(
    dout: &Array3<f32>,   // (B, T, D)
    x: &Array3<f32>,      // (B, T, C)
    w: &Array2<f32>,      // (C, D)
) -> (Array3<f32>, Array2<f32>) {
    let (b, t, c) = x.dim();
    let d = w.shape()[1];

    let dx = dout.dot(&w.t());                                    // dX = dout @ W.T

    let x_2d   = x.view().into_shape((b * t, c)).unwrap();
    let dout_2d = dout.view().into_shape((b * t, d)).unwrap();
    let dw = x_2d.t().dot(&dout_2d);                             // dW = X.T @ dout

    (dx, dw)
}

pub fn softmax(x: &Array2<f32>) -> Array2<f32> {
    let max = x.map_axis(Axis(1), |r| r.fold(f32::NEG_INFINITY, |a, &b| a.max(b)));
    let exp = (x - &max.insert_axis(Axis(1))).mapv(f32::exp);
    &exp / &exp.sum_axis(Axis(1)).insert_axis(Axis(1))
}

pub fn softmax_backward(s: &Array2<f32>, grad_out: &Array2<f32>) -> Array2<f32> {
    let dot = (s * grad_out).sum_axis(Axis(1)).insert_axis(Axis(1));
    s * (grad_out - &dot)
}

从标量求导,到链式法则,到矩阵整理,到 ndarray 实现——每一步都是同一件事的不同规模。