llm.c Source Code Walkthrough
A while ago I came across Karpathy's tweet and immediately starred the repo, but since I was busy writing acore I never got around to reading the code carefully. Now that things have quieted down, I'm writing this post to record my notes, and to serve as a source-code analysis of the repo.
The walkthrough follows the order of the main function. The first step is loading the model weights from gpt2_124M.bin, the checkpoint exported by the repo's Python script (train_gpt2.py).
// build the GPT-2 model from a checkpoint
GPT2 model;
gpt2_build_from_checkpoint(&model, "gpt2_124M.bin");
This defines an object of type GPT2, which presumably holds everything about the model. Its contents are:
typedef struct {
GPT2Config config;
// the weights of the model, and their sizes
ParameterTensors params;
size_t param_sizes[NUM_PARAMETER_TENSORS];
float* params_memory;
int num_parameters;
// gradients of the weights
ParameterTensors grads;
float* grads_memory;
// buffers for the AdamW optimizer
float* m_memory;
float* v_memory;
// the activations of the model, and their sizes
ActivationTensors acts;
size_t act_sizes[NUM_ACTIVATION_TENSORS];
float* acts_memory;
int num_activations;
// gradients of the activations
ActivationTensors grads_acts;
float* grads_acts_memory;
// other run state configuration
int batch_size; // the batch size (B) of current forward pass
int seq_len; // the sequence length (T) of current forward pass
int* inputs; // the input tokens for the current forward pass
int* targets; // the target tokens for the current forward pass
float mean_loss; // after a forward pass with targets, will be populated with the mean loss
} GPT2;
It holds the model configuration, the weights, gradients and activations, plus some run-time state such as the batch size, sequence length, input tokens and target tokens. I won't dig into each field yet (jumping to the definitions would not make much sense at this point). Next, gpt2_build_from_checkpoint reads the model's hyperparameters:
fread(model_header, sizeof(int), 256, model_file);
model->config.max_seq_len = maxT = model_header[2];
model->config.vocab_size = V = model_header[3];
model->config.num_layers = L = model_header[4];
model->config.num_heads = NH = model_header[5];
model->config.channels = C = model_header[6];
These lines are, in order, the max sequence length, the vocabulary size, the number of layers, the number of attention heads, and the number of channels (the embedding size). From these, the size of every parameter tensor in the model can be computed.
// allocate space for all the parameters and read them in
model->param_sizes[0] = V * C;
model->param_sizes[1] = maxT * C;
model->param_sizes[2] = L * C;
model->param_sizes[3] = L * C;
model->param_sizes[4] = L * (3 * C) * C;
model->param_sizes[5] = L * (3 * C);
model->param_sizes[6] = L * C * C;
model->param_sizes[7] = L * C;
model->param_sizes[8] = L * C;
model->param_sizes[9] = L * C;
model->param_sizes[10] = L * (4 * C) * C;
model->param_sizes[11] = L * (4 * C);
model->param_sizes[12] = L * C * (4 * C);
model->param_sizes[13] = L * C;
model->param_sizes[14] = C;
model->param_sizes[15] = C;
Looking at these sizes in isolation can be a bit confusing, but each line becomes clear once read together with the gpt2_forward code later on.
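As a sanity check, plugging in the GPT-2 124M configuration (V = 50257, maxT = 1024, L = 12, C = 768, values assumed here rather than read from the checkpoint) and summing these sizes gives about 124M parameters:
#include <stdio.h>

int main(void) {
    // standard GPT-2 small configuration (assumed values)
    long V = 50257, maxT = 1024, L = 12, C = 768;
    long sizes[16] = {
        V * C, maxT * C,               // token and position embeddings
        L * C, L * C,                  // ln1 weight, bias
        L * (3 * C) * C, L * (3 * C),  // qkv weight, bias
        L * C * C, L * C,              // attention projection weight, bias
        L * C, L * C,                  // ln2 weight, bias
        L * (4 * C) * C, L * (4 * C),  // fc weight, bias
        L * C * (4 * C), L * C,        // fc projection weight, bias
        C, C                           // final layer norm weight, bias
    };
    long total = 0;
    for (int i = 0; i < 16; i++) { total += sizes[i]; }
    printf("total parameters: %ld\n", total);  // 124439808
    return 0;
}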
Forward Pass
The first entry should be the token embedding parameters (vocab size words × embedding size), and the second the position embedding parameters (max sequence length positions × embedding size). Each input token goes through both the token table and the position encoding, and the resulting embedding vector is fed into the transformer blocks. The corresponding code in the forward pass is:
void encoder_forward(float* out,
int* inp, float* wte, float* wpe,
int B, int T, int C) {
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
// seek to the output position in out[b,t,:]
float* out_bt = out + b * T * C + t * C;
// get the index of the token at inp[b, t]
int ix = inp[b * T + t];
// seek to the position in wte corresponding to the token
float* wte_ix = wte + ix * C;
// seek to the position in wpe corresponding to the position
float* wpe_t = wpe + t * C;
// add the two vectors and store the result in out[b,t,:]
for (int i = 0; i < C; i++) {
out_bt[i] = wte_ix[i] + wpe_t[i];
}
}
}
}
For every token of every input in the batch, the C-dimensional embedding of that token and the position encoding are fetched from the two tables, and the two vectors are added element-wise to form the final embedding vector.
Next come a number of transformer blocks, each structured like this:
// now do the forward pass
layernorm_forward(l_ln1, l_ln1_mean, l_ln1_rstd, residual, l_ln1w, l_ln1b, B, T, C);
matmul_forward(l_qkv, l_ln1, l_qkvw, l_qkvb, B, T, C, 3*C);
attention_forward(l_atty, l_preatt, l_att, l_qkv, B, T, C, NH);
matmul_forward(l_attproj, l_atty, l_attprojw, l_attprojb, B, T, C, C);
residual_forward(l_residual2, residual, l_attproj, B*T*C);
layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);
matmul_forward(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C);
gelu_forward(l_fch_gelu, l_fch, B*T*4*C);
matmul_forward(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C);
residual_forward(l_residual3, l_residual2, l_fcproj, B*T*C);
The block first goes through a layer norm. Recall the layer norm formula
\[out_i = w_i\cdot\frac{x_i - \mu}{\sigma} + b_i\]
(per-channel scale \(w\) and shift \(b\)) and compare with the code:
for (int t = 0; t < T; t++) {
// seek to the input position inp[b,t,:]
float* x = inp + b * T * C + t * C;
// calculate the mean
float m = 0.0f;
for (int i = 0; i < C; i++) {
m += x[i];
}
m = m/C;
// calculate the variance (without any bias correction)
float v = 0.0f;
for (int i = 0; i < C; i++) {
float xshift = x[i] - m;
v += xshift * xshift;
}
v = v/C;
// calculate the rstd
float s = 1.0f / sqrtf(v + eps);
// seek to the output position in out[b,t,:]
float* out_bt = out + b * T * C + t * C;
for (int i = 0; i < C; i++) {
float n = (s * (x[i] - m)); // normalized output
float o = n * weight[i] + bias[i]; // scale and shift it
out_bt[i] = o; // write
}
// cache the mean and rstd for the backward pass later
mean[b * T + t] = m;
rstd[b * T + t] = s;
}
As you can see, the normalization is over the channels of each individual token, not over all tokens of an input. The per-token mean and rstd are also cached here, because the backward pass will need them.
Then comes the attention part. First, a single matmul with 3C output channels computes the \(Q, K, V\) vectors for every input token:
for (int t = 0; t < T; t++) {
float* out_bt = out + b * T * OC + t * OC;
float* inp_bt = inp + b * T * C + t * C;
for (int o = 0; o < OC; o++) {
float val = (bias != NULL) ? bias[o] : 0.0f;
float* wrow = weight + o*C;
for (int i = 0; i < C; i++) {
val += inp_bt[i] * wrow[i];
}
out_bt[o] = val;
}
}
Note that the full matmul_forward parallelizes this loop with OpenMP to speed up the matrix multiplication. Here inp_bt points to the C-dimensional embedding vector of the t-th token, and wrow points to the o-th row of the \(3C \times C\) weight matrix (\(W_Q, W_K, W_V\) concatenated together); their dot product gives the o-th output channel. Overall, the (B, T, C) input is turned into a (B, T, 3C) output.
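In equation form, matmul_forward computes, for every position,
\[out_{b,t,o} = \sum_{i=0}^{C-1} inp_{b,t,i}\, W_{o,i} + bias_o\]
i.e. out = inp · Wᵀ + bias, with W stored row-major as (OC, C).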
Then comes multi-head self-attention: each head processes a slice of C/NH channels, and the heads are computed independently and concatenated to form the final attention result.
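The excerpt below uses hs and scale, which are defined near the top of attention_forward; roughly:
int C3 = C * 3;                 // Q, K, V are stored interleaved per token
int hs = C / NH;                // head size
float scale = 1.0f / sqrtf(hs); // the 1/sqrt(d_k) factor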
for (int t = 0; t < T; t++) {
for (int h = 0; h < NH; h++) {
float* query_t = inp + b * T * C3 + t * C3 + h * hs;
float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
float* att_bth = att + b*NH*T*T + h*T*T + t*T;
// pass 1: calculate query dot key and maxval
float maxval = -10000.0f; // TODO something better
for (int t2 = 0; t2 <= t; t2++) {
float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key
// (query_t) dot (key_t2)
float val = 0.0f;
for (int i = 0; i < hs; i++) {
val += query_t[i] * key_t2[i];
}
val *= scale;
if (val > maxval) {
maxval = val;
}
preatt_bth[t2] = val;
}
}
}
The first pass computes \(\frac{Q \cdot K}{\sqrt{d_k}}\): for each token \(t\), the dot product with every preceding token is taken, producing a lower-triangular (B, NH, T, T) pre-attention matrix. The maximum of each row is recorded along the way, which makes the softmax numerically stable.
Then a softmax is applied to obtain \(\operatorname{softmax}({\frac{Q \cdot K}{\sqrt{d_k}}})\), where \(\operatorname{softmax}(x)_i = \frac{e^{x_i}}{\sum_j e^{x_j}}\):
float expsum = 0.0f;
for (int t2 = 0; t2 <= t; t2++) {
float expv = expf(preatt_bth[t2] - maxval);
expsum += expv;
att_bth[t2] = expv;
}
float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;
// pass 3: normalize to get the softmax
for (int t2 = 0; t2 < T; t2++) {
if (t2 <= t) {
att_bth[t2] *= expsum_inv;
} else {
// causal attention mask. not strictly necessary to set to zero here
// only doing this explicitly for debugging and checking to PyTorch
att_bth[t2] = 0.0f;
}
}
As before, the softmax is applied to each row of the pre-attention matrix; the resulting matrix has the same shape, but every row now sums to 1.
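Subtracting maxval before exponentiating (pass 2 above) does not change the result; it only keeps expf from overflowing:
\[\frac{e^{x_i - m}}{\sum_j e^{x_j - m}} = \frac{e^{-m}\,e^{x_i}}{e^{-m}\sum_j e^{x_j}} = \frac{e^{x_i}}{\sum_j e^{x_j}}\]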
Finally, the weighted sum over \(V\) gives the attention result:
// pass 4: accumulate weighted values into the output of attention
float* out_bth = out + b * T * C + t * C + h * hs;
for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }
for (int t2 = 0; t2 <= t; t2++) {
float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value
float att_btht2 = att_bth[t2];
for (int i = 0; i < hs; i++) {
out_bth[i] += att_btht2 * value_t2[i];
}
}
Here out_bth points to the start of the h-th head's output for the t-th token. For every channel in this head's slice, the values of all preceding tokens are accumulated with their attention weights, giving the t-th token's attention output on these channels. Concatenating the outputs of all heads and passing them through one more matmul (the attention projection) yields the (B, T, C) attention output.
After that come, in order, a residual connection, a second layer norm, and the two-layer feed-forward network with GELU in between, which finally produces the (B, T, C) output of one transformer block.
residual_forward(l_residual2, residual, l_attproj, B*T*C);
layernorm_forward(l_ln2, l_ln2_mean, l_ln2_rstd, l_residual2, l_ln2w, l_ln2b, B, T, C);
matmul_forward(l_fch, l_ln2, l_fcw, l_fcb, B, T, C, 4*C);
gelu_forward(l_fch_gelu, l_fch, B*T*4*C);
matmul_forward(l_fcproj, l_fch_gelu, l_fcprojw, l_fcprojb, B, T, 4*C, C);
residual_forward(l_residual3, l_residual2, l_fcproj, B*T*C);
After the \(L\) transformer blocks, a final layer norm gives (B, T, C) activations, which are multiplied by the transpose of the token embedding matrix wte (weight tying) to produce the logits, and a softmax turns these into a (B, T, V) probability distribution for every position (Vp is the padded vocabulary size; only the first V entries are meaningful).
residual = acts.residual3 + (L-1) * B * T * C; // last residual is in residual3
layernorm_forward(acts.lnf, acts.lnf_mean, acts.lnf_rstd, residual, params.lnfw, params.lnfb, B, T, C);
matmul_forward(acts.logits, acts.lnf, params.wte, NULL, B, T, C, Vp);
softmax_forward(acts.probs, acts.logits, B, T, V, Vp);
This completes the forward pass: we now have a probability distribution at every token position. The loss is then computed with cross entropy, \(H(p,q) = -\sum_x p(x) \log q(x)\).
void crossentropy_forward(float* losses,
float* probs, int* targets,
int B, int T, int Vp) {
// output: losses is (B,T) of the individual losses at each position
// input: probs are (B,T,Vp) of the probabilities
// input: targets is (B,T) of integers giving the correct index in logits
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
// loss = -log(probs[target])
float* probs_bt = probs + b * T * Vp + t * Vp;
int ix = targets[b * T + t];
losses[b * T + t] = -logf(probs_bt[ix]);
}
}
}
In effect, each position simply gets the value \(-\log p[correct]\); the result is a (B, T) matrix of losses, whose mean is returned as the final loss.
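gpt2_forward then averages these per-position losses into the model's mean_loss; a minimal sketch of that step (field names follow the GPT2 struct above):
float mean_loss = 0.0f;
for (int i = 0; i < B * T; i++) { mean_loss += model->acts.losses[i]; }
mean_loss /= B * T;  // mean over all positions in the batch
model->mean_loss = mean_loss;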
Backward Pass
The backward pass starts by initializing the gradient of every per-position loss to \(\frac{1}{B\times T}\), because the final loss is the mean over all \(B\times T\) positions.
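In gpt2_backward this initialization looks roughly like the following sketch (names follow the GPT2 struct above):
// every position contributes equally to the mean loss,
// so each per-position loss receives a gradient of 1/(B*T)
float dloss_mean = 1.0f / (B * T);
for (int i = 0; i < B * T; i++) { grads_acts.losses[i] = dloss_mean; }
From there the layers are backpropagated in reverse order; the first step is the backward pass of cross entropy, fused with softmax.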
void crossentropy_softmax_backward(float* dlogits,
float* dlosses, float* probs, int* targets,
int B, int T, int V, int Vp) {
// backwards through both softmax and crossentropy
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
float* dlogits_bt = dlogits + b * T * Vp + t * Vp;
float* probs_bt = probs + b * T * Vp + t * Vp;
float dloss = dlosses[b * T + t];
int ix = targets[b * T + t];
// note we only loop to V, leaving the padded dimensions
// of dlogits untouched, so gradient there stays at zero
for (int i = 0; i < V; i++) {
float p = probs_bt[i];
float indicator = i == ix ? 1.0f : 0.0f;
dlogits_bt[i] += (p - indicator) * dloss;
}
}
}
}
Fusing softmax and cross entropy gives a very simple gradient with respect to the logits: for \(\ell = -\log p_{ix}\) with \(p = \operatorname{softmax}(z)\),
\[\frac{\partial \ell}{\partial z_i} = p_i - \mathbb{1}[i = ix]\]
which the code scales by the upstream gradient dloss, giving the (p - indicator) * dloss accumulation. Note that the gradient is only propagated to the \(V\) real vocabulary entries; the padded dimensions up to Vp are left untouched and stay at zero.
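A quick derivation of the (p - indicator) form, using the softmax Jacobian \(\frac{\partial p_j}{\partial z_i} = p_j(\mathbb{1}[i=j] - p_i)\):
\[\frac{\partial \ell}{\partial z_i} = -\frac{1}{p_{ix}}\frac{\partial p_{ix}}{\partial z_i} = -\frac{1}{p_{ix}}\,p_{ix}\left(\mathbb{1}[i=ix] - p_i\right) = p_i - \mathbb{1}[i=ix]\]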
Next is the backward pass of matmul:
// backward into inp first, parallelize over B,T
#pragma omp parallel for collapse(2)
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
float* dout_bt = dout + b * T * OC + t * OC;
float* dinp_bt = dinp + b * T * C + t * C;
for (int o = 0; o < OC; o++) {
float* wrow = weight + o*C;
float d = dout_bt[o];
for (int i = 0; i < C; i++) {
dinp_bt[i] += wrow[i] * d;
}
}
}
}
The first step computes the gradient flowing back to the input, i.e.
\[\frac{\partial y}{\partial x_i} = \sum_j \frac{\partial y}{\partial y^{out}_j} \frac{\partial y^{out}_j}{\partial x_i} = \left(\frac{\partial y}{\partial y^{out}} \frac{\partial y^{out}}{\partial x}\right)_i\]
where \(\frac{\partial y}{\partial y^{out}}\in \mathbb{R}^{B\times T\times 1\times OC}\) and \(\frac{\partial y^{out}}{\partial x}\in \mathbb{R}^{B\times T\times OC\times C}\), so the result is \(\frac{\partial y}{\partial x}\in \mathbb{R}^{B\times T\times C}\). Since \(\frac{\partial y^{out}_o}{\partial x_i} = W_{o,i}\), this is exactly the dinp_bt[i] += wrow[i] * d accumulation in the code.
The second step computes the gradients with respect to the weights and bias:
#pragma omp parallel for
for (int o = 0; o < OC; o++) {
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
float* dout_bt = dout + b * T * OC + t * OC;
float* inp_bt = inp + b * T * C + t * C;
float* dwrow = dweight + o*C;
float d = dout_bt[o];
if (dbias != NULL) { dbias[o] += d; }
for (int i = 0; i < C; i++) {
dwrow[i] += inp_bt[i] * d;
}
}
}
}
Similarly,
\[\frac{\partial y}{\partial W} = \sum_{b,t} \left(\frac{\partial y}{\partial y^{out}_{b,t}}\right)^{T} x_{b,t}\]
which matches dwrow[i] += inp_bt[i] * d in the code, while the bias gradient simply accumulates dout over b and t. The remaining backward passes follow the same pattern, so I won't go through them one by one.
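For completeness, after gpt2_backward the main loop updates the parameters using the m_memory / v_memory buffers from the GPT2 struct above. A rough sketch assuming the standard AdamW formulas (the hyperparameter names here are illustrative, not taken from the repo):
// one AdamW step; step is the 1-based update counter
for (int i = 0; i < model->num_parameters; i++) {
    float param = model->params_memory[i];
    float grad = model->grads_memory[i];
    // exponential moving averages of the gradient and its square
    float m = beta1 * model->m_memory[i] + (1.0f - beta1) * grad;
    float v = beta2 * model->v_memory[i] + (1.0f - beta2) * grad * grad;
    // bias-corrected estimates
    float m_hat = m / (1.0f - powf(beta1, step));
    float v_hat = v / (1.0f - powf(beta2, step));
    model->m_memory[i] = m;
    model->v_memory[i] = v;
    // decoupled weight decay, the "W" in AdamW
    model->params_memory[i] = param - learning_rate * (m_hat / (sqrtf(v_hat) + eps) + weight_decay * param);
}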
CUDA Code Analysis
Finally, let's look at the original version of the CUDA code. The optimizations here are relatively modest, but there are still things worth learning.
First is encoder_forward_kernel2. The implementation is very simple: one thread is assigned to each output element, which looks up and adds the corresponding embedding values. Such a kernel has a very low arithmetic intensity (compute-to-memory ratio), so the performance gain is limited.
__global__ void encoder_forward_kernel2(float* out,
int* inp, float* wte, float* wpe,
int B, int T, int C) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int N = B * T * C;
if (idx < N) {
int bt = idx / C;
int b = bt / T;
int t = bt % T;
int c = idx % C;
int ix = inp[b * T + t];
float* out_btc = out + b * T * C + t * C + c;
float* wte_ix = wte + ix * C + c;
float* wpe_tc = wpe + t * C + c;
*out_btc = *wte_ix + *wpe_tc;
}
}
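For reference, a kernel like this is launched with one thread per output element; a minimal launch sketch (the block size is an assumption, not necessarily what the repo uses):
void encoder_forward_gpu(float* out, int* inp, float* wte, float* wpe, int B, int T, int C) {
    const int block_size = 256;                               // assumed block size
    const int N = B * T * C;                                  // one thread per output element
    const int grid_size = (N + block_size - 1) / block_size;  // ceiling division
    encoder_forward_kernel2<<<grid_size, block_size>>>(out, inp, wte, wpe, B, T, C);
}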
It is worth comparing this with the latest version:
x128 packed_out;
x128 wte128 = load128cs(wte_ix);
x128 wpe128 = load128cs(wpe_tc);
for (int k = 0; k < x128::size; k++) {
packed_out[k] = (floatX)((float)wte128[k] + (float)wpe128[k]);
}
store128(out_btc, packed_out);
Here multiple channels are packed together into a single x128 vector (one 128-bit load/store), which makes much better use of memory bandwidth.
Next is the layer norm. layer_norm_forward calls three kernels, mean_kernel, rstd_kernel and normalization_kernel, whose implementations are as follows:
__global__ void mean_kernel(float* mean, float* inp, int N, int C, int block_size) {
extern __shared__ float shared[];
int idx = blockIdx.x; // range [0, B*T)
int tid = threadIdx.x; // range [0, block_size)
float* x = inp + idx * C;
// thread coarsening
float sum = 0.0f;
for (int i = tid; i < C; i += block_size) {
sum += x[i];
}
shared[tid] = sum;
__syncthreads();
// reductions
for (int stride = block_size / 2; stride >= 1; stride /= 2) {
__syncthreads();
if (tid < stride) {
shared[tid] += shared[tid + stride];
}
}
// write the final result (at thread 0) to global memory
if (tid == 0) {
mean[idx] = shared[0] / C;
}
}
The key point here is the tree-style reduction: each iteration retires half of the active threads, and each surviving thread absorbs the partial sum of its retired partner, until thread 0 ends up with the total and writes it out. In the newer implementation the reduction is instead done by a warpReduceSum helper built on warp shuffle intrinsics.
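A warp-level reduction of this kind is typically built on shuffle intrinsics; a sketch of the technique (not copied from the repo):
// sum val across the 32 threads of a warp
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}
Next is rstd_kernel: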
__global__ void rstd_kernel(float* rstd, float* inp, float* mean, int N, int C, int block_size) {
extern __shared__ float shared[];
int idx = blockIdx.x; // range [0, B*T)
int tid = threadIdx.x; // range [0, block_size)
float* x = inp + idx * C;
float m = mean[idx];
// thread coarsening
float sum = 0.0f;
for (int i = tid; i < C; i += block_size) {
float diff = x[i] - m;
sum += diff * diff;
}
shared[tid] = sum;
__syncthreads();
// reductions
for (int stride = block_size / 2; stride >= 1; stride /= 2) {
__syncthreads();
if (tid < stride) {
shared[tid] += shared[tid + stride];
}
}
// write the final result (at thread 0) to global memory
if (tid == 0) {
rstd[idx] = 1.0f / sqrtf(shared[0] / C + 1e-5f);
}
}
The rstd kernel is almost identical to the mean kernel above, so the two are obvious candidates for fusion. And indeed, looking at the newer kernels, the three layer norm passes and the residual connection have been fused together. The normalization kernel is similar, so I won't go into it here.
The next step is matmul, which, as you can probably guess, is implemented directly with cuBLAS.
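A sketch of what such a call can look like: cuBLAS is column-major, so the row-major out = inp · weightᵀ is expressed through its transpose (this is an illustration of the technique, not necessarily the exact call in the repo):
// out (B*T, OC) = inp (B*T, C) @ weight (OC, C)^T, all buffers row-major
const float alpha = 1.0f, beta = 0.0f;
cublasSgemm(cublas_handle,
            CUBLAS_OP_T, CUBLAS_OP_N,   // transpose weight, leave inp as-is
            OC, B * T, C,               // m, n, k in column-major terms
            &alpha,
            weight, C,                  // A = weight, lda = C
            inp, C,                     // B = inp,    ldb = C
            &beta,
            out, OC);                   // C = out,    ldc = OC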