Reading the llm.c Source Code

A while ago I came across Karpathy's tweet about this project and immediately went to star the repo, but I was busy writing acore and never got around to reading the code carefully. Now that things have quieted down, I'm writing this post as a record and as a source-code walkthrough of the repo.


The walkthrough below follows the order of the main function. The first step is loading the model weights from gpt2_124M.bin, the checkpoint written out by the repo's Python script (train_gpt2.py).

    // build the GPT-2 model from a checkpoint
    GPT2 model;
    gpt2_build_from_checkpoint(&model, "gpt2_124M.bin");

This defines an object of type GPT2, which presumably holds everything about the model. Its contents are:

typedef struct {
    GPT2Config config;
    // the weights of the model, and their sizes
    ParameterTensors params;
    size_t param_sizes[NUM_PARAMETER_TENSORS];
    float* params_memory;
    int num_parameters;
    // gradients of the weights
    ParameterTensors grads;
    float* grads_memory;
    // buffers for the AdamW optimizer
    float* m_memory;
    float* v_memory;
    // the activations of the model, and their sizes
    ActivationTensors acts;
    size_t act_sizes[NUM_ACTIVATION_TENSORS];
    float* acts_memory;
    int num_activations;
    // gradients of the activations
    ActivationTensors grads_acts;
    float* grads_acts_memory;
    // other run state configuration
    int batch_size; // the batch size (B) of current forward pass
    int seq_len; // the sequence length (T) of current forward pass
    int* inputs; // the input tokens for the current forward pass
    int* targets; // the target tokens for the current forward pass
    float mean_loss; // after a forward pass with targets, will be populated with the mean loss
} GPT2;

It contains the model configuration, the weights, the gradients, the activations, and so on, plus some run-time state such as the batch size, sequence length, input tokens and target tokens. Let's not dig into what every field means just yet (jumping to the definitions at this point wouldn't help much anyway). Next, gpt2_build_from_checkpoint reads the model's various parameters.

fread(model_header, sizeof(int), 256, model_file);
model->config.max_seq_len = maxT = model_header[2];
model->config.vocab_size = V = model_header[3];
model->config.num_layers = L = model_header[4];
model->config.num_heads = NH = model_header[5];
model->config.channels = C = model_header[6];

These lines read, in order, the max sequence length, the vocabulary size, the number of layers, the number of attention heads, and the number of channels (the embedding size). From these values we can compute the number of parameters in every tensor of the model.

// allocate space for all the parameters and read them in
model->param_sizes[0] = V * C;
model->param_sizes[1] = maxT * C;
model->param_sizes[2] = L * C;
model->param_sizes[3] = L * C;
model->param_sizes[4] = L * (3 * C) * C;
model->param_sizes[5] = L * (3 * C);
model->param_sizes[6] = L * C * C;
model->param_sizes[7] = L * C;
model->param_sizes[8] = L * C;
model->param_sizes[9] = L * C;
model->param_sizes[10] = L * (4 * C) * C;
model->param_sizes[11] = L * (4 * C);
model->param_sizes[12] = L * C * (4 * C);
model->param_sizes[13] = L * C;
model->param_sizes[14] = C;
model->param_sizes[15] = C;

These sizes may look a bit cryptic on their own, but each line becomes clear once you read it alongside the gpt2_forward code later on.
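
The sizes are then summed into the total parameter count, and all parameters are read into one contiguous buffer. A simplified sketch of what gpt2_build_from_checkpoint does next (in the real code a helper also points each ParameterTensors field at its slice of this buffer):

// sum the per-tensor sizes into the total parameter count
size_t num_parameters = 0;
for (size_t i = 0; i < NUM_PARAMETER_TENSORS; i++) {
    num_parameters += model->param_sizes[i];
}
model->num_parameters = num_parameters;
// all parameters live in a single contiguous buffer, read straight from the checkpoint
model->params_memory = (float*)malloc(num_parameters * sizeof(float));
fread(model->params_memory, sizeof(float), num_parameters, model_file);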

Forward Pass

The first entry should be the token-embedding parameters: vocab size entries times the embedding size. The second is the position embedding: max sequence length positions times the embedding size. Every input token is looked up in both tables, and the two vectors are summed into the embedding that is fed into the transformer blocks. The corresponding code in the forward pass is:

void encoder_forward(float* out,
                   int* inp, float* wte, float* wpe,
                   int B, int T, int C) {
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            // seek to the output position in out[b,t,:]
            float* out_bt = out + b * T * C + t * C;
            // get the index of the token at inp[b, t]
            int ix = inp[b * T + t];
            // seek to the position in wte corresponding to the token
            float* wte_ix = wte + ix * C;
            // seek to the position in wpe corresponding to the position
            float* wpe_t = wpe + t * C;
            // add the two vectors and store the result in out[b,t,:]
            for (int i = 0; i < C; i++) {
                out_bt[i] = wte_ix[i] + wpe_t[i];
            }
        }
    }
}

So for every token of every input in the batch, we fetch the token's \(C\)-dimensional embedding from the vocabulary table and the \(C\)-dimensional encoding of its position, then add them element-wise to get the final embedding vector.
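
In equation form, the encoder simply computes

\[\mathrm{out}[b, t, :] = \mathrm{wte}[\mathrm{inp}[b, t], :] + \mathrm{wpe}[t, :]\]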

Next the activations go through a stack of transformer blocks, each of which looks like this:

// now do the forward pass
layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C);
matmul_forward(l_qkv, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C);
attention_forward(l_atty, l_preatt, l_att, l_qkv, B, T, C, NH);
matmul_forward(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C);
residual_forward(l_residual2, residual, l_attproj, B*T*C);
layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);
matmul_forward(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C);
gelu_forward(l_fch_gelu, l_fch, B*T*4*C);
matmul_forward(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C);
residual_forward(l_residual3, l_residual2, l_fcproj, B*T*C);

The block starts with a layer norm. Recall the layer norm formula:

\[\mathrm{out} = w \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + b\]

Looking at the code:

for (int t = 0; t < T; t++) {
    // seek to the input position inp[b,t,:]
    float* x = inp + b * T * C + t * C;
    // calculate the mean
    float m = 0.0f;
    for (int i = 0; i < C; i++) {
        m += x[i];
    }
    m = m/C;
    // calculate the variance (without any bias correction)
    float v = 0.0f;
    for (int i = 0; i < C; i++) {
        float xshift = x[i] - m;
        v += xshift * xshift;
    }
    v = v/C;
    // calculate the rstd
    float s = 1.0f / sqrtf(v + eps);
    // seek to the output position in out[b,t,:]
    float* out_bt = out + b * T * C + t * C;
    for (int i = 0; i < C; i++) {
        float n = (s * (x[i] - m)); // normalized output
        float o = n * weight[i] + bias[i]; // scale and shift it
        out_bt[i] = o; // write
    }
    // cache the mean and rstd for the backward pass later
    mean[b * T + t] = m;
    rstd[b * T + t] = s;
}

We can see that the normalization runs over the channels of each individual token, not over all tokens of an input sequence.
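
In other words, the statistics are computed per (b, t) position over the \(C\) channels:

\[\mu_{b,t} = \frac{1}{C}\sum_{i=1}^{C} x_{b,t,i}, \qquad \sigma^2_{b,t} = \frac{1}{C}\sum_{i=1}^{C}\left(x_{b,t,i} - \mu_{b,t}\right)^2\]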

Next comes the attention part. First, a matmul computes the \(Q, K, V\) vectors for every input token (the three projections are fused into a single matmul with \(3C\) output channels):

#pragma omp parallel for collapse(2)
for (int b = 0; b < B; b++) {
    for (int t = 0; t < T; t++) {
        float* out_bt = out + b * T * OC + t * OC;
        float* inp_bt = inp + b * T * C + t * C;
        for (int o = 0; o < OC; o++) {
            float val = (bias != NULL) ? bias[o] : 0.0f;
            float* wrow = weight + o*C;
            for (int i = 0; i < C; i++) {
                val += inp_bt[i] * wrow[i];
            }
            out_bt[o] = val;
        }
    }
}

Note that OpenMP is used here to parallelize the matrix multiplication. inp_bt points at the \(C\)-dimensional embedding of the t-th token, and wrow points at the o-th row of the \(3C \times C\) weight matrix (\(W_Q, W_K, W_V\) stacked together); their dot product gives the o-th output channel. From a \((B, T, C)\) input we therefore get a \((B, T, 3C)\) output.
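
Written out per element, this matmul computes

\[\mathrm{out}[b, t, o] = \mathrm{bias}[o] + \sum_{i=1}^{C} \mathrm{inp}[b, t, i] \cdot \mathrm{weight}[o, i]\]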

Then comes multi-head self-attention: each head handles a \(C/NH\)-dimensional slice of the channels, and the heads are computed independently to produce the final attention result.

int C3 = C*3;                    // the qkv activations have 3C channels
int hs = C / NH;                 // head size
float scale = 1.0f / sqrtf(hs);
for (int t = 0; t < T; t++) {
    for (int h = 0; h < NH; h++) {
        float* query_t = inp + b * T * C3 + t * C3 + h * hs;
        float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
        float* att_bth = att + b*NH*T*T + h*T*T + t*T;
        // pass 1: calculate query dot key and maxval
        float maxval = -10000.0f; // TODO something better
        for (int t2 = 0; t2 <= t; t2++) {
            float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key
            // (query_t) dot (key_t2)
            float val = 0.0f;
            for (int i = 0; i < hs; i++) {
                val += query_t[i] * key_t2[i];
            }
            val *= scale;
            if (val > maxval) {
                maxval = val;
            }
            preatt_bth[t2] = val;
        }
    }
}

The first pass computes \(\frac{Q \cdot K^T}{\sqrt{d_k}}\): for each token \(t\) we take the dot product of its query with the keys of all tokens up to and including \(t\), producing the lower-triangular \((B, NH, T, T)\) pre-attention matrix, while also recording the maximum of each row to help with the softmax computation.
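
The row maximum enables the numerically stable form of softmax, which is mathematically identical but keeps expf from overflowing:

\[\operatorname{softmax}(x)_i = \frac{e^{x_i - \max_j x_j}}{\sum_k e^{x_k - \max_j x_j}}\]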

Then we apply the softmax to get \(\operatorname{softmax}\!\left(\frac{Q \cdot K^T}{\sqrt{d_k}}\right)\), where \(\operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}\):

float expsum = 0.0f;
for (int t2 = 0; t2 <= t; t2++) {
    float expv = expf(preatt_bth[t2] - maxval);
    expsum += expv;
    att_bth[t2] = expv;
}
float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;
// pass 3: normalize to get the softmax
for (int t2 = 0; t2 < T; t2++) {
    if (t2 <= t) {
        att_bth[t2] *= expsum_inv;
    } else {
        // causal attention mask. not strictly necessary to set to zero here
        // only doing this explicitly for debugging and checking to PyTorch
        att_bth[t2] = 0.0f;
    }
}

As before, the softmax is applied row by row to the pre-attention matrix; the result has the same shape, but every row now sums to 1.

Finally, the weighted sum over \(V\) gives the attention output:

// pass 4: accumulate weighted values into the output of attention
float* out_bth = out + b * T * C + t * C + h * hs;
for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }
for (int t2 = 0; t2 <= t; t2++) {
    float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value
    float att_btht2 = att_bth[t2];
    for (int i = 0; i < hs; i++) {
        out_bth[i] += att_btht2 * value_t2[i];
    }
}

Here out_bth points at the start of the h-th head's output for token t; the loop then accumulates, over all tokens up to and including t and over every channel of the head, the attention-weighted sum of the values, giving token t's attention output on those channels. Concatenating the outputs of all heads and passing them through one more matmul (the output projection) yields the final \((B, T, C)\) result:

\[\operatorname{Attention}(Q, K, V) = \operatorname{Concat}\left(\operatorname{softmax}\!\left(\frac{Q K^T}{\sqrt{d_k}}\right) \cdot V\right) \cdot W_O\]

After attention, the block goes through a residual connection, another layer norm, and the two-layer feed-forward network (with a GELU in between), producing the \((B, T, C)\) output of one transformer block.

residual_forward(l_residual2, residual, l_attproj, B*T*C);
layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);
matmul_forward(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C);
gelu_forward(l_fch_gelu, l_fch, B*T*4*C);
matmul_forward(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C);
residual_forward(l_residual3, l_residual2, l_fcproj, B*T*C);

After the \(L\) transformer blocks, a final layer norm is applied to the \((B, T, C)\) activations, which are then multiplied by the (transposed) token-embedding matrix wte to produce the logits; a softmax turns these into a \((B, T, V)\) probability distribution for every token.

residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C);
matmul_forward(acts.logits, acts.lnf, params.wte, NULL, B, T, C, Vp);
softmax_forward(acts.probs, acts.logits, B, T, V, Vp);

This concludes the forward pass: we have a probability distribution for every token. The loss is then computed with cross entropy, \(H(p,q) = -\sum_x p(x) \log q(x)\).

void crossentropy_forward(float* losses,
                          float* probs, int* targets,
                          int B, int T, int Vp) {
    // output: losses is (B,T) of the individual losses at each position
    // input: probs are (B,T,Vp) of the probabilities
    // input: targets is (B,T) of integers giving the correct index in logits
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            // loss = -log(probs[target])
            float* probs_bt = probs + b * T * Vp + t * Vp;
            int ix = targets[b * T + t];
            losses[b * T + t] = -logf(probs_bt[ix]);
        }
    }
}

This just evaluates \(-\log p[\mathrm{correct}]\) at every position; the result is a \((B, T)\) matrix of losses, and the forward pass also returns their mean.
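
The mean is computed right after crossentropy_forward; a minimal sketch of that last step (assuming the per-position losses live in acts.losses, as in llm.c):

// sketch: average the (B, T) per-position losses into the scalar mean_loss
float mean_loss = 0.0f;
for (int i = 0; i < B*T; i++) { mean_loss += model->acts.losses[i]; }
mean_loss /= B*T;
model->mean_loss = mean_loss;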

Backward Pass

Backpropagation starts by initializing the gradient of every position's loss to \(\frac{1}{B\times T}\), since the scalar objective is the mean over all \(B \times T\) positions.
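
A minimal sketch of this seeding step (assuming the gradient buffer is grads_acts.losses, matching the struct above):

// sketch: seed the backward pass; each position's loss contributes 1/(B*T) to the mean
float dloss_mean = 1.0f / (B*T);
for (int i = 0; i < B*T; i++) { model->grads_acts.losses[i] = dloss_mean; }

From there, backpropagation walks the forward graph in reverse, starting with the backward pass through cross entropy (fused with the softmax):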

void crossentropy_softmax_backward(float* dlogits,
                           float* dlosses, float* probs, int* targets,
                           int B, int T, int V, int Vp) {
    // backwards through both softmax and crossentropy
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            float* dlogits_bt = dlogits + b * T * Vp + t * Vp;
            float* probs_bt = probs + b * T * Vp + t * Vp;
            float dloss = dlosses[b * T + t];
            int ix = targets[b * T + t];
            // note we only loop to V, leaving the padded dimensions
            // of dlogits untouched, so gradient there stays at zero
            for (int i = 0; i < V; i++) {
                float p = probs_bt[i];
                float indicator = i == ix ? 1.0f : 0.0f;
                dlogits_bt[i] += (p - indicator) * dloss;
            }
        }
    }
}

Note that this goes backwards through the softmax and the cross entropy in one step. Since \(\mathrm{loss} = -\log p_{ix}\) with \(p = \operatorname{softmax}(z)\), the gradient of the loss with respect to each logit is

\[\frac{\partial \mathrm{loss}}{\partial z_i} = p_i - \mathbb{1}[i = ix]\]

which is exactly the (p - indicator) * dloss accumulation in the code.

Also note that the gradient loop only runs over the \(V\) real vocabulary entries; the padded dimensions (up to Vp) are left untouched, so their gradients stay zero.

Next is the backward pass of matmul:

// backward into inp first, parallelize over B,T
#pragma omp parallel for collapse(2)
for (int b = 0; b < B; b++) {
    for (int t = 0; t < T; t++) {
        float* dout_bt = dout + b * T * OC + t * OC;
        float* dinp_bt = dinp + b * T * C + t * C;
        for (int o = 0; o < OC; o++) {
            float* wrow = weight + o*C;
            float d = dout_bt[o];
            for (int i = 0; i < C; i++) {
                dinp_bt[i] += wrow[i] * d;
            }
        }
    }
}

The first step computes the gradient to propagate back to the previous operation, i.e.

\[\frac{\partial y}{\partial x_i} = \sum_j \frac{\partial y}{\partial y^{out}_j} \frac{\partial y^{out}_j}{\partial x_i} = \left(\frac{\partial y}{\partial y^{out}} \frac{\partial y^{out}}{\partial x}\right)_i\]

where \(\frac{\partial y}{\partial y^{out}}\in \mathbb{R}^{B\times T\times 1\times OC}\) and \(\frac{\partial y^{out}}{\partial x}\in \mathbb{R}^{B\times T\times OC\times C}\), giving \(\frac{\partial y}{\partial x}\in \mathbb{R}^{B\times T\times C}\).
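
In index form, matching the inner loops of the code:

\[\frac{\partial y}{\partial x_{b,t,i}} = \sum_{o=1}^{OC} \mathrm{dout}[b, t, o] \cdot \mathrm{weight}[o, i]\]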

The second step computes the gradients with respect to the weight and the bias:

#pragma omp parallel for
for (int o = 0; o < OC; o++) {
    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            float* dout_bt = dout + b * T * OC + t * OC;
            float* inp_bt = inp + b * T * C + t * C;
            float* dwrow = dweight + o*C;
            float d = dout_bt[o];
            if (dbias != NULL) { dbias[o] += d; }
            for (int i = 0; i < C; i++) {
                dwrow[i] += inp_bt[i] * d;
            }
        }
    }
}

As above, each \((b, t)\) position contributes the outer product of its output gradient and its input to dweight, and its output gradient to dbias:

\[\frac{\partial y}{\partial W} = \sum_{b,t} \left(\frac{\partial y}{\partial y^{out}_{b,t}}\right)^{T} x_{b,t}, \qquad \frac{\partial y}{\partial b} = \sum_{b,t} \frac{\partial y}{\partial y^{out}_{b,t}}\]

The remaining backward passes follow the same pattern, so I won't go through them here.

CUDA Code Analysis

Finally, let's look at the early version of the CUDA code. The optimizations there are relatively modest, but there are still a few things worth learning from.

First up is encoder_forward_kernel2. The implementation is very simple: it just assigns one thread to each output element (one per channel of each position) to look up the corresponding embedding value. An implementation like this has a very low arithmetic intensity (compute-to-memory-access ratio), so the performance gain shouldn't be very large.

__global__ void encoder_forward_kernel2(float* out,
                               int* inp, float* wte, float* wpe,
                               int B, int T, int C) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int N = B * T * C;

    if (idx < N) {
        int bt = idx / C;
        int b = bt / T;
        int t = bt % T;
        int c = idx % C;

        int ix = inp[b * T + t];

        float* out_btc = out + b * T * C + t * C + c;
        float* wte_ix = wte + ix * C + c;
        float* wpe_tc = wpe + t * C + c;
        *out_btc = *wte_ix + *wpe_tc;
    }
}
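
For reference, a host-side launcher for this kernel might look roughly like the following (a sketch; the block size and function name are assumptions):

// sketch of a launcher: one thread per output element, N = B*T*C in total
void encoder_forward_gpu(float* out, int* inp, float* wte, float* wpe,
                         int B, int T, int C) {
    int N = B * T * C;
    int block_size = 256;                                // a typical choice
    int grid_size = (N + block_size - 1) / block_size;   // ceil(N / block_size)
    encoder_forward_kernel2<<<grid_size, block_size>>>(out, inp, wte, wpe, B, T, C);
}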

It's instructive to compare this with the latest version:

x128 packed_out;
x128 wte128 = load128cs(wte_ix);
x128 wpe128 = load128cs(wpe_tc);
for (int k = 0; k < x128::size; k++) {
    packed_out[k] = (floatX)((float)wte128[k] + (float)wpe128[k]);
}
store128(out_btc, packed_out);

As you can see, multiple channels are packed together into a single 128-bit vectorized load/store, which improves efficiency.

Next is the layer norm implementation. layernorm_forward calls three kernels, mean_kernel, rstd_kernel and normalization_kernel, implemented as follows.

__global__ void mean_kernel(float* mean, float* inp, int N, int C, int block_size) {
    extern __shared__ float shared[];
    int idx = blockIdx.x; // range [0, B*T)
    int tid = threadIdx.x; // range [0, block_size)
    float* x = inp + idx * C;
    // thread coarsening
    float sum = 0.0f;
    for (int i = tid; i < C; i += block_size) {
        sum += x[i];
    }
    shared[tid] = sum;
    __syncthreads();
    // reductions
    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
        __syncthreads();
        if (tid < stride) {
            shared[tid] += shared[tid + stride];
        }
    }
    // write the final result (at thread 0) to global memory
    if (tid == 0) {
        mean[idx] = shared[0] / C;
    }
}

The key point here is the binary-tree reduction: each iteration retires half of the active threads, with every surviving thread absorbing the partial sum of its retired partner, until everything ends up accumulated on thread 0. The newer implementation instead uses a warp-level reduction helper, warpReduceSum, built on warp shuffle intrinsics.
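
For reference, such a warp-level helper can be written with shuffle intrinsics; a minimal sketch (the version in the repo may differ in details):

// sketch: butterfly reduction within one warp using shuffle intrinsics;
// after the loop every lane holds the sum of all 32 lanes' values
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}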

__global__ void rstd_kernel(float* rstd, float* inp, float* mean, int N, int C, int block_size) {
    extern __shared__ float shared[];
    int idx = blockIdx.x; // range [0, B*T)
    int tid = threadIdx.x; // range [0, block_size)
    float* x = inp + idx * C;
    float m = mean[idx];
    // thread coarsening
    float sum = 0.0f;
    for (int i = tid; i < C; i += block_size) {
        float diff = x[i] - m;
        sum += diff * diff;
    }
    shared[tid] = sum;
    __syncthreads();
    // reductions
    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
        __syncthreads();
        if (tid < stride) {
            shared[tid] += shared[tid + stride];
        }
    }
    // write the final result (at thread 0) to global memory
    if (tid == 0) {
        rstd[idx] = 1.0f / sqrtf(shared[0] / C + 1e-5f);
    }
}

The rstd kernel is almost identical to the mean kernel above, so the two are obvious candidates for fusion. Sure enough, in the newer kernels the three layer-norm kernels (and the residual add) have been fused into a single kernel. The normalization kernel is similar, so I won't go into it here.

The next step is matmul, which, as you can probably guess, is implemented directly with cuBLAS.
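
As a rough sketch of how the forward matmul maps onto cuBLAS (the exact call in the repo may differ): cuBLAS is column-major, so the row-major out = inp · weight^T is expressed as a transposed GEMM.

#include <cublas_v2.h>

// sketch: out(B*T, OC) = inp(B*T, C) @ weight(OC, C)^T, all buffers row-major on the device
void matmul_forward_cublas(cublasHandle_t handle, float* out,
                           const float* inp, const float* weight,
                           int BT, int C, int OC) {
    const float alpha = 1.0f, beta = 0.0f;
    // column-major view: m = OC, n = B*T, k = C;
    // op(A) = weight read as (OC, C), op(B) = inp read as (C, B*T)
    cublasSgemm(handle, CUBLAS_OP_T, CUBLAS_OP_N,
                OC, BT, C,
                &alpha, weight, C, inp, C,
                &beta, out, OC);
}

The bias add then has to be handled separately (for example with a small custom kernel), since the GEMM itself only produces inp · weight^T.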