LLMs-From-Scratch Summary

Stage 1

image-20260323010244791

Building the Vocabulary

How these pieces connect:

  1. The input text first has to be converted into tokens.
  2. Each token maps to a token ID in the vocabulary.
  3. The token ID can index into a matrix created with nn.Parameter(); this is essentially a trainable matrix whose number of rows equals the vocabulary size and whose number of columns equals the token-embedding dimension.
  4. The advantage of the nn.Parameter() approach is that a token ID directly selects row ID of the matrix, which is a fast lookup; nn.Parameter() is usually initialized from a Gaussian distribution.
  5. Token embeddings can also be implemented with nn.Linear(). Besides having an optional bias, the advantage of nn.Linear is access to better initialization schemes such as Xavier or Kaiming initialization (a minimal sketch of both options follows after this list).
    1. Xavier initialization (suited to tanh/sigmoid activations): its core idea is to keep the variance of inputs and outputs as close as possible.
    2. Kaiming initialization (suited to ReLU/LeakyReLU activations): its core idea is to counteract the "dead neuron" problem of ReLU-like activations.
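
A minimal sketch of the two options above (the vocabulary size, embedding dimension, and example token IDs are assumptions chosen for illustration):

import torch
import torch.nn as nn

vocab_size, emb_dim = 50257, 768  # assumed GPT-2-style sizes

# Option 1: a trainable matrix indexed directly by token ID (Gaussian init)
tok_emb = nn.Parameter(torch.randn(vocab_size, emb_dim))
token_ids = torch.tensor([15496, 995])  # arbitrary example IDs
vectors = tok_emb[token_ids]            # fast row lookup, shape (2, emb_dim)

# Option 2: nn.Linear applied to one-hot vectors yields the same kind of mapping,
# but with PyTorch's default Kaiming-style initialization (and an optional bias)
linear_emb = nn.Linear(vocab_size, emb_dim, bias=False)
one_hot = nn.functional.one_hot(token_ids, num_classes=vocab_size).float()
vectors_lin = one_hot @ linear_emb.weight.T  # equivalent to linear_emb(one_hot)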

image-20260322222552069

Line breaks, end-of-text boundaries, and unseen words are handled with special tokens, for example:

image-20260322232304473

These special tokens are appended after the regular token IDs (i.e., they receive the last IDs in the vocabulary):

image-20260323005208804

Byte Pair Encoding (BPE)

In practice, token IDs are not produced by simply assigning each whole word its own one-hot entry. Instead, words are split apart and then re-merged according to how frequently character combinations occur; this technique is called Byte Pair Encoding (BPE):

  1. Split every word into individual characters.
  2. Count the frequency of every adjacent character pair.
  3. Merge the most frequent pair into a new symbol.
  4. Repeat steps 2–3 until the vocabulary reaches the desired size.

As for the "byte" in the name: words are first broken down into bytes as the basic units, i.e., strings are mapped through a byte-level encoding (UTF-8/ASCII), which works for Chinese just as well as for English. Merging high-frequency character sequences is therefore really merging high-frequency byte pairs.
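
A toy sketch of the merge loop described above (the corpus and number of merges are made up for illustration; real BPE tokenizers such as GPT-2's operate on bytes and learn tens of thousands of merges):

from collections import Counter

corpus = ["low", "lower", "newest", "widest"]  # assumed toy corpus
words = [list(w) for w in corpus]              # 1. split every word into characters
num_merges = 4                                 # assumed number of merges

for _ in range(num_merges):
    # 2. count adjacent symbol pairs across all words
    pairs = Counter()
    for w in words:
        for a, b in zip(w, w[1:]):
            pairs[(a, b)] += 1
    if not pairs:
        break
    # 3. merge the most frequent pair into a single new symbol
    (a, b), _count = pairs.most_common(1)[0]
    merged = a + b
    for w in words:
        i = 0
        while i < len(w) - 1:
            if w[i] == a and w[i + 1] == b:
                w[i:i + 2] = [merged]
            else:
                i += 1

print(words)
# [['low'], ['low', 'e', 'r'], ['n', 'e', 'w', 'est'], ['w', 'i', 'd', 'est']]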

image-20260322233928925

Positional Encoding

In addition, these token embeddings need to be combined with positional-encoding information before they are fed into the network:

image-20260323010041618
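
A minimal sketch of adding learned absolute position embeddings to the token embeddings (sizes are assumed GPT-2-style; nn.Embedding is used here as the standard lookup module):

import torch
import torch.nn as nn

vocab_size, emb_dim, context_length = 50257, 768, 1024  # assumed sizes

tok_emb = nn.Embedding(vocab_size, emb_dim)      # token-ID -> vector lookup
pos_emb = nn.Embedding(context_length, emb_dim)  # one learned vector per position

token_ids = torch.tensor([[15496, 995, 11, 995]])  # (batch, num_tokens), arbitrary IDs
positions = torch.arange(token_ids.shape[1])       # 0, 1, 2, 3

x = tok_emb(token_ids) + pos_emb(positions)  # broadcasts over the batch dimension
print(x.shape)  # torch.Size([1, 4, 768])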

Since training here is a next-token (text) prediction task, the dataloader can use a sliding window to sample training examples:

image-20260323010134797
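
A minimal sliding-window sketch (the Dataset name, max_length, and stride values are assumptions; the idea is that each target sequence is the input window shifted by one token):

import torch
from torch.utils.data import Dataset

class SlidingWindowDataset(Dataset):
    """Pairs each window of token IDs with the same window shifted by one."""
    def __init__(self, token_ids, max_length=4, stride=1):
        self.inputs, self.targets = [], []
        for i in range(0, len(token_ids) - max_length, stride):
            self.inputs.append(torch.tensor(token_ids[i:i + max_length]))
            self.targets.append(torch.tensor(token_ids[i + 1:i + max_length + 1]))

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

ds = SlidingWindowDataset(list(range(10)), max_length=4, stride=2)
print(ds[0])  # (tensor([0, 1, 2, 3]), tensor([1, 2, 3, 4]))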

Stage 2

image-20260323010348123

Self-Attention

Each token embedding is projected through separate learnable weight matrices to obtain Q, K, and V. Q and K produce an attention matrix, which is then multiplied with V to obtain the context vectors.

image-20260323010522680
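
A minimal single-head, non-causal sketch of the computation described above (the input of 6 tokens with dimension 3 is an arbitrary assumption):

import torch
import torch.nn as nn

torch.manual_seed(123)
x = torch.randn(6, 3)  # 6 tokens, embedding dim 3 (assumed)
d_in, d_out = 3, 2

W_query = nn.Linear(d_in, d_out, bias=False)
W_key = nn.Linear(d_in, d_out, bias=False)
W_value = nn.Linear(d_in, d_out, bias=False)

queries, keys, values = W_query(x), W_key(x), W_value(x)

attn_scores = queries @ keys.T                                     # (6, 6) attention scores
attn_weights = torch.softmax(attn_scores / d_out ** 0.5, dim=-1)   # scaled softmax
context_vec = attn_weights @ values                                # (6, 2) context vectors
print(context_vec.shape)  # torch.Size([6, 2])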

Causal Self-Attention

Build an upper-triangular mask matrix (the lower triangle is visible, the upper triangle is hidden) that blocks each token from attending to future positions. When computing the attention weights, the masked positions are set to a very large negative value (such as -1e9 or -inf), so after the softmax they become essentially 0, which is equivalent to not being able to "see" future information.

image-20260323010744172
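
A small stand-alone illustration of the masking step on a 3×3 score matrix (values are random; dropout is not yet applied):

import torch

num_tokens = 3
attn_scores = torch.randn(num_tokens, num_tokens)

# positions above the diagonal correspond to "future" tokens
mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1).bool()
masked_scores = attn_scores.masked_fill(mask, -torch.inf)

attn_weights = torch.softmax(masked_scores, dim=-1)
print(attn_weights)
# row 0 attends only to token 0, row 1 to tokens 0-1, row 2 to tokens 0-2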

Dropout

In practice, dropout can also be applied to reduce overfitting.

image-20260323010949203

Implementation 1 (CausalAttention)

import torch
import torch.nn as nn

class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)  # New
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))  # New

    def forward(self, x):
        b, num_tokens, d_in = x.shape  # New batch dimension b
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)  # Changed transpose
        attn_scores.masked_fill_(  # New, _ ops are in-place
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)  # New

        context_vec = attn_weights @ values
        return context_vec
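
A quick usage check for the class above (batch size and dimensions are arbitrary assumptions):

torch.manual_seed(123)
batch = torch.randn(2, 6, 3)  # (b, num_tokens, d_in)
ca = CausalAttention(d_in=3, d_out=2, context_length=6, dropout=0.0)
print(ca(batch).shape)  # torch.Size([2, 6, 2])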

Multi-Head Attention

Attention is split into multiple heads so the model can interpret the semantics from different perspectives and different subspaces at the same time. In implementation terms, multiple sets of W_q, W_k, W_v provide the projections into the different subspaces.

image-20260323011118768

Implementation 2 (stacking multiple heads)

class Ch03_MHA(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec

Implementation 3 (combined QKV weight)

class MultiHeadAttentionCombinedQKV(nn.Module):
    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "embed_dim is indivisible by num_heads"

        self.num_heads = num_heads
        self.context_length = context_length
        self.head_dim = d_out // num_heads

        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        batch_size, num_tokens, embed_dim = x.shape

        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)
        qkv = self.qkv(x)

        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)
        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)

        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)
        queries, keys, values = qkv.unbind(0)

        # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_scores = attn_scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)
        context_vec = attn_weights @ values

        # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)
        context_vec = context_vec.transpose(1, 2)

        # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)
        context_vec = context_vec.contiguous().view(batch_size, num_tokens, embed_dim)

        context_vec = self.proj(context_vec)

        return context_vec

Implementation 4 (einsum)

import math

class MHAEinsum(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        # Initialize parameters for Q, K, V
        self.W_query = nn.Parameter(torch.randn(d_out, d_in))
        self.W_key = nn.Parameter(torch.randn(d_out, d_in))
        self.W_value = nn.Parameter(torch.randn(d_out, d_in))

        if qkv_bias:
            self.bias_q = nn.Parameter(torch.zeros(d_out))
            self.bias_k = nn.Parameter(torch.zeros(d_out))
            self.bias_v = nn.Parameter(torch.zeros(d_out))
        else:
            self.register_parameter("bias_q", None)
            self.register_parameter("bias_k", None)
            self.register_parameter("bias_v", None)

        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

        # Initialize parameters
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.W_query, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_key, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_value, a=math.sqrt(5))
        if self.bias_q is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_query)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias_q, -bound, bound)
            nn.init.uniform_(self.bias_k, -bound, bound)
            nn.init.uniform_(self.bias_v, -bound, bound)

    def forward(self, x):
        b, n, _ = x.shape

        # Calculate Q, K, V using einsum, first perform linear transformations
        Q = torch.einsum("bnd,di->bni", x, self.W_query)
        K = torch.einsum("bnd,di->bni", x, self.W_key)
        V = torch.einsum("bnd,di->bni", x, self.W_value)

        # Add biases if they are used
        if self.bias_q is not None:
            Q += self.bias_q
            K += self.bias_k
            V += self.bias_v

        # Reshape for multi-head attention
        Q = Q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention
        scores = torch.einsum("bhnd,bhmd->bhnm", Q, K) / (self.head_dim ** 0.5)

        # Apply mask
        mask = self.mask[:n, :n].unsqueeze(0).unsqueeze(1).expand(b, self.num_heads, n, n)
        scores = scores.masked_fill(mask.bool(), -torch.inf)

        # Softmax and dropout
        attn_weights = torch.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Aggregate the attended context vectors
        context_vec = torch.einsum("bhnm,bhmd->bhnd", attn_weights, V)

        # Combine heads and project the output
        context_vec = context_vec.transpose(1, 2).reshape(b, n, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec

Implementation 5 (FlashAttention interface)

class MHAPyTorchScaledDotProduct(nn.Module):
    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "embed_dim is indivisible by num_heads"

        self.num_heads = num_heads
        self.context_length = context_length
        self.head_dim = d_out // num_heads
        self.d_out = d_out

        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.proj = nn.Linear(d_out, d_out)
        self.dropout = dropout

    def forward(self, x):
        batch_size, num_tokens, embed_dim = x.shape

        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)
        qkv = self.qkv(x)

        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)
        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)

        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)
        queries, keys, values = qkv

        use_dropout = 0. if not self.training else self.dropout

        context_vec = nn.functional.scaled_dot_product_attention(
            queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)

        context_vec = self.proj(context_vec)

        return context_vec

Performance Comparison

Speed comparison (Nvidia A100 GPU) with warmup and compilation (forward and backward pass)

Compilation here refers to the process of compiling the dynamic-graph code of frameworks such as PyTorch into optimized machine code that the GPU can execute directly; it is a core step in getting the acceleration of high-performance operators such as FlashAttention.
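
A minimal sketch of how such a comparison might be set up with torch.compile (the shapes, warmup count, and timing approach are assumptions; the book's benchmark notebook is more thorough). It reuses MHAPyTorchScaledDotProduct from Implementation 5:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
mha = MHAPyTorchScaledDotProduct(
    d_in=768, d_out=768, num_heads=12, context_length=1024).to(device)

compiled_mha = torch.compile(mha)   # compile the module into optimized kernels

x = torch.randn(8, 1024, 768, device=device)

for _ in range(3):                  # warmup so compilation cost is not measured
    compiled_mha(x)

with torch.no_grad():
    out = compiled_mha(x)
print(out.shape)                    # torch.Size([8, 1024, 768])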

image-20260323012529668