斯坦福 CS336 动手大语言模型 Assignment1 BPE Tokenizer TransformerLM
所有代码更新至 https://github.com/WangYuHang-cmd/CS336/tree/main/assignment1-basics
作业文件结构:
CS336/assignment1-basics/
├── tests/ # 测试文件目录
│ ├── adapters.py # 适配器测试
│ ├── conftest.py # pytest配置
│ ├── __init__.py # 包初始化
│ ├── snapshots/ # 测试快照
│ ├── test_data.py # 数据处理测试
│ ├── test_model.py # 模型测试
│ ├── test_nn_utils.py # 神经网络工具测试
│ ├── test_optimizer.py # 优化器测试
│ ├── test_serialization.py # 序列化测试
│ ├── test_tokenizer.py # 分词器测试
│ └── test_train_bpe.py # BPE训练测试
│
└── cs336_basics/ # 实现文件目录├── attention.py # 注意力机制实现├── embedding.py # 嵌入层实现├── linear.py # 线性层实现├── optimizer.py # 优化器实现├── tokenizer.py # 分词器实现├── transformerLM.py # Transformer语言模型├── rope.py # RoPE位置编码├── rmsnorm.py # RMSNorm层├── softmax.py # Softmax实现├── swiglu.py # SwiGLU激活函数├── utils.py # 工具函数└── debug_*.py # 调试文件
BPE Tokenizer
BPE Class
首先是BPE类, 我们需要正确处理作业已经定义好的接口:
class BPETokenizer:def __init__(self, vocab_size: int, special_tokens: list[str] | None = None):self.vocab_size = vocab_sizeself.special_tokens = special_tokens or []self.special_tokens_bytes = [token.encode("utf-8") for token in self.special_tokens]self.merges: List[Tuple[bytes, bytes]] = []self.stoi: Dict[bytes, int] = {}self.itos: Dict[int, bytes] = {}self.merges_rank: Dict[Tuple[bytes, bytes], int] = {}# init vocabfor i, token_bytes in enumerate(self.special_tokens_bytes): # special tokensself.stoi[token_bytes] = iself.itos[i] = token_bytesoffset = len(self.special_tokens_bytes)for i in range(256):self.stoi[bytes([i])] = i + offsetself.itos[i + offset] = bytes([i])self.vocab = self.itos.copy() # for serializationself.merges_rank = {} # for fast lookup# pair2new: (p1, p2) -> new_token_idself.pair2new = {(p1, p2): self.stoi[p1 + p2] for (p1, p2) in self.merges}
其中stoi用来记录每一个toekn对应的token id, itos用来记录每一个token id对应的token, 在初始化的时候我们需要首先载入所有的special_tokens然后再依次将0-255对应字节值载入。
BPE Training
BPE Tokenizer是一个从data中进行学习的一个分词器,其以Byte为单位进行学习, 然后最终学校的结果包括了单词,词根等各种各样的形式。
BPE Tokenizer的核心就是首先经过预分词得到一个token列表, 此时全文被拆成了多个pre_token组成的列表, 然后对这个列表中的special token进行提取(special token不参与合并),我们得到由一整个大列表拆出来的多个小列表,然后我们需要依次统计每一个小列表中的前后相邻的字符pair的个数并计数, 然后按照以下规则进行合并:
1. 首先找到pair计数最多的pair <token1, token2>, 可能会有多个一样数量的pair
2. 然后优先找token1字典序更大的,进行合并
3. 其次找token2字典序更大的进行合并
- Pre_tokenize
pretokenize这个函数主要用来将文本切分成规范的词块列表,例如
GPT2_SPLIT_PATTERN = (r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
)
def pretokenize(text: str) -> list[bytes]:str_tokens = re.findall(GPT2_SPLIT_PATTERN, text)byte_tokens = [s.encode("utf-8") for s in str_tokens]return byte_tokens
例如"Hello world, this is user-123!" 会被pretokenize转换为 [‘Hello’, ’ world’, ‘,’, ’ this’, ’ is’, ’ user’, ‘-’, ‘123’, ‘!’]
Train
我们可以很轻易写出一个暴力的训练方法(见代码中的slow_train函数), 在这个函数中我们
num_merges_needed = self.vocab_size - len(self.stoi) # 需要合并的次数, 每一次合并会扩大vocab_size
for merge_cnt in range(num_merges_needed):pair_counts = self._get_stats(token_groups) # 遍历当前的训练列表,统计所有相邻token_id的个数best_pair = max( pair_counts,key=lambda p: (pair_counts[p], self.itos[p[0]], self.itos[p[1]]),) # 按照合并规则找到需要合并的pair# 更新合并后的所有字典new_token_id = len(self.itos)p1_bytes, p2_bytes = self.itos[best_pair[0]], self.itos[best_pair[1]]new_token_bytes = p1_bytes + p2_bytesself.merges.append((p1_bytes, p2_bytes))self.stoi[new_token_bytes] = new_token_idself.itos[new_token_id] = new_token_bytes...
但是这不仅有可能无法通过tests/test_train_bpe.py::test_train_bpe_speed测试(我的暴力解法大约使用了5.8s远大于限制的1.5秒), 在tests/test_train_bpe.py::test_train_bpe_special_tokens 测试中大约使用了将近7分钟。
==================================================== 1 failed, 2 passed in 476.21s (0:07:56) ====================================================
因此我们需要考虑优化这个合并的过程:耗时的大头是 1. “每一次都需要重新统计所有pair” 2.“更新后需要每一次重写当前的token_id序列”, 而这些都可以通过数据结构来优化:对于token_id序列我们可以使用双向链表来构建,然后对于每一个token_id对应的列表的节点位置我们可以存储到token_id为key的set中。然后我们只需要扫一遍整个token_id的序列, 记录每一个pair的个数然后全部push到一个堆中, 这个堆每一次会从堆顶优先pop出我们需要合并的pair. 记住,这里我们并不需要在合并pair后从内部修改这个堆,我们只需要pop出来的时候判断一下当前的pair是否存在或者其计数是否和我们的pair_counts中一致即可。此一次修改合并后我们也只需要将合并后的pair对应的计数重新push进堆即可。考虑到每一次修改的数量不会很多, 因此总的复杂度大约是nlogn级别的
综上我们的思路是:
数据结构:
- 维护一个大根堆,里面维护按照BPE合并的顺序进行排序的token_id pair
- 维护一个双向链表 用来记录当前的token_id序列
- 维护一个为每一个token_id维护一个set,用来存储每一个token_id对应的所有的双向链表的节点的位置
更新方式:
1. 从heap顶部取出token_id的pair,判断是否和pair_count中记录的数量一致,若不一致则找下一个,直到一致为止,此时就是需要合并的BPE Pair
2. 通过heap中记录的链表节电找到当前当前pair的 pos_idx, nxt_idx然后找到向前向后的链表pre_idx和nnxt_idxpre_idx <-> pos_idx <-> nxt_idx <-> nnxt_idx 我们合并后会变成pre_idx <-> (pos_idx,nxt_idx) <-> nnxt_idxnew_token = token[pos_idx] + token[nxt_idx]3. 更新pair_count,遍历pos[token[pos_idx]]的所有链表节点,找到所有nxt[]对应的token_id是token[nxt_idx]的位置,然后删除这些位置pair_count[(token[pre_idx], token[pos_idx])] - 1 pair_count[(token[pre_idx], new_token)] + 1pair_count[(token[nxt_idx], token[nnxt_idx])] - 1 pair_count[(new_token, token[nnxt_idx])] + 1pos[new_token].add(pos_idx)pre[nnxt_idx] = pos_idxnxt[pos_idx] = nnxt_idxpre[nxt_idx] = nxt[nxt_idx] = None # 删除被合并的pair中靠后的那一个token对应的链表
由于python中的heapq默认使用小根堆, 因此我们需要重写一个类来实现大根堆
class PairItem:def __init__(self, count, token_id1, token_id2, itos):self.count = countself.token_id1 = token_id1self.token_id2 = token_id2self.itos = itosself.bytes1 = itos[token_id1]self.bytes2 = itos[token_id2]def __lt__(self, other):# 首先按频次降序(大的在前)if self.count != other.count:return self.count > other.count# 频次相同时,按第一个token的字节降序if self.bytes1 != other.bytes1:return self.bytes1 > other.bytes1# 第一个token相同时,按第二个token的字节降序return self.bytes2 > other.bytes2def __eq__(self, other):return (self.count == other.count and self.bytes1 == other.bytes1 and self.bytes2 == other.bytes2)def get_pair(self):return (self.token_id1, self.token_id2)
然后我们读取文本直到处理好pretokenize的结果后
# Pre-Tokenizer
assert self.vocab_size >= len(self.stoi)with open(path, "r", encoding="utf-8") as f:text = f.read()if self.special_tokens: # Special Tokenspecial_pattern = f"({'|'.join(re.escape(s) for s in self.special_tokens)})"text_parts = re.split(special_pattern, text)
else:text_parts = [text]# Pre-Tokenizer
initial_vocab_map = {v: k for k, v in self.itos.items()}
token_groups = []
for part in text_parts:if part in self.special_tokens or not part:continuewords_in_bytes = pretokenize(part)for word in words_in_bytes:token_groups.append([initial_vocab_map[bytes([b])] for b in word])
首先只需要扫一遍整体的token_id序列进行统计:
# BPE Merge
idx = 0
pair_counts = {}
token = {}
pre = {}
nxt = {}
pos = {}for i, token_lst in enumerate(token_groups):if not token_lst or len(token_lst) <= 1:continuetoken_lst_len = len(token_lst)for j, token_id in enumerate(token_lst):idx += 1token[idx] = token_idnxt[idx] = None if j == token_lst_len - 1 else idx + 1pre[idx] = None if j == 0 else idx - 1if j == token_lst_len - 1:continuetoken_pair = (token_id, token_lst[j + 1])pair_counts[token_pair] = pair_counts.get(token_pair, 0) + 1if pos.get(token_pair) is None:pos[token_pair] = set()pos[token_pair].add(idx)heap = []
for (a, b), cnt in pair_counts.items():item = PairItem(cnt, a, b, self.itos)heapq.heappush(heap, item)
然后我们可以开始BPE Merge,merge的顺序和细节需要十分注意,尤其是更新的顺序和对于是否更新的还存在的pair的判断
def update_pair(pair: tuple[int, int], delta: int, pos_idx: int | None = None):if pair is None or None in pair: returnpair_counts[pair] = pair_counts.get(pair, 0) + deltacnt = pair_counts[pair]if cnt <= 0:pair_counts.pop(pair, None)pos.pop(pair, None)returnif pos_idx is not None:ds = pos.setdefault(pair, set())if delta > 0:ds.add(pos_idx)elif delta < 0:ds.discard(pos_idx)a, b = pairitem = PairItem(cnt, a, b, self.itos)heapq.heappush(heap, item)num_merges_needed = self.vocab_size - len(self.stoi)
while num_merges_needed > 0 and heap:if not pair_counts: breaknum_merges_needed -= 1while heap:item = heapq.heappop(heap)p1, p2 = item.get_pair()# 检查这个 pair 是否仍然有效if (p1, p2) not in pair_counts or pair_counts[(p1, p2)] != item.count:continue # 已经被合并过了# merge the new tokenself.merges.append((self.itos[p1], self.itos[p2]))p1_bytes, p2_bytes = self.itos[p1], self.itos[p2]new_token_bytes = p1_bytes + p2_bytesnew_token_id = (len(self.stoi)if self.stoi.get(new_token_bytes) is Noneelse self.stoi[new_token_bytes])self.stoi[new_token_bytes] = new_token_idself.itos[new_token_id] = new_token_bytespos_lst = list(pos.get((p1, p2), set()))# modify the token groupfor pos_idx in pos_lst:pre_idx = pre[pos_idx]nxt_idx = nxt[pos_idx]nnxt_idx = nxt[nxt_idx] if nxt_idx is not None else Noneif nxt_idx is None or token[pos_idx] != p1 or token[nxt_idx] != p2: continueif pre_idx is not None:nxt[pre_idx] = pos_idx # keep unchangedupdate_pair((token[pre_idx], token[pos_idx]), -1, pre_idx)update_pair((token[pre_idx], new_token_id), 1, pre_idx)if nnxt_idx is not None:pre[nnxt_idx] = pos_idxupdate_pair((token[nxt_idx], token[nnxt_idx]), -1, nxt_idx)update_pair((new_token_id, token[nnxt_idx]), 1, pos_idx)pre[pos_idx] = pre_idxnxt[pos_idx] = nnxt_idxtoken[pos_idx] = new_token_idtoken[nxt_idx] = None # remove the old tokenpre[nxt_idx] = Nonenxt[nxt_idx] = Nonepair_counts.pop((p1, p2), None)pos.pop((p1, p2), None)breakself.merges_rank = {pair: i for i, pair in enumerate(self.merges)}
self.vocab = self.itos.copy()
self.pair2new = {(p1, p2): self.stoi[p1 + p2] for (p1, p2) in self.merges}
然后测试发现最终用时会快很多
============================================================== 3 passed in 30.85s ====================================
其中对于第一个测试从
# 暴力用时
(1752185555.6502326 - 1752185549.8956482) < 1.5
# 优化之后
tests/test_train_bpe.py::test_train_bpe_speed time using toy implementation: 0.32 seconds
当然除了重载这个堆内的排序方式外,我们还可以手动来写比较字符串时的一个比较方式,只不过需要注意的是我们需要在短的序列末尾补大字符直到和长的一样长(可以手动指定max_len为一个比较大的数,这个的速度也很快)
def bytes_desc(b):return bytes(255 - x for x in b)def pair_desc(pair):a = self.itos[pair[0]]b = self.itos[pair[1]]max_len = 2a_pad = a + bytes([0] * (max_len - len(a)))b_pad = b + bytes([0] * (max_len - len(b)))return (bytes_desc(a_pad), bytes_desc(b_pad))heap = [(-cnt, # 频次取负,freq 高 → 数值小pair_desc((a, b)),a, b,) # token-1 id, token-2 idfor (a, b), cnt in pair_counts.items()
]
heapq.heapify(heap)
BPE Encode & Decode
首先是Encode部分, 这个部分需要我们将输入的文本字符串转换为整数ID序列,然后我们需要注意在处理的时候1.特殊token优先处理:先识别并保护特殊token(如<|endoftext|>)2. 按长度排序:避免短特殊token被长特殊token包含的情况 3.分段处理:将文本分割为特殊token和普通文本段落.
我们首先来完成不含有special token的encoder:
def _encode_ordinary_text(self, text_bytes: bytes) -> list[int]:if not text_bytes:return []try:text = text_bytes.decode("utf-8")except UnicodeDecodeError:text = text_bytes.decode("utf-8", errors="replace")ids_out = array("H") # uint16 足够 ≤ 65k vocabpair_rank = self.merges_rankpair2new = self.pair2newbyte2id = self.stoi # 局部 alias,加速# 逐个“词块”处理,避免一次性 listfor word_b in iter_pretokenize(text):token_ids = array("H", (byte2id[bytes([b])] for b in word_b))# b. 就地合并:“greedy smallest-rank merge”while True:best_rank = 1000000000best_pos = -1# ——— 找当前序列里 rank 最小的 pair ———for i in range(len(token_ids) - 1):r = pair_rank.get( # ——— 替换 best_pos & best_pos+1 为新的 token ———(self.itos[token_ids[i]], self.itos[token_ids[i + 1]]),1000000000,)if r < best_rank:best_rank, best_pos = r, iif best_pos == -1:breaknew_id = pair2new[(self.itos[token_ids[best_pos]], self.itos[token_ids[best_pos + 1]])]token_ids[best_pos : best_pos + 2] = array("H", [new_id])ids_out.extend(token_ids)# array → listreturn ids_out.tolist()
在这里我使用了array而不是list,这样每个token_id只占用2字节,逐个字符处理是防止内存爆炸
然后处理带有特殊字符的encoder:
def encode(self, text: str) -> list[int]:"""Encode str"""if not text:return []sorted_special_tokens = sorted(self.special_tokens, key=len, reverse=True)if not sorted_special_tokens:return self._encode_ordinary_text(text.encode("utf-8"))special_pattern = f"({'|'.join(re.escape(s) for s in sorted_special_tokens)})"text_parts = re.split(special_pattern, text)all_ids = []for part in text_parts:if part in self.special_tokens:all_ids.append(self.stoi[part.encode("utf-8")])elif part:all_ids.extend(self._encode_ordinary_text(part.encode("utf-8")))return all_ids
对于decode函数则很简单, 我们需要将一个token id序列转换成字符串,按照BPE训练时的合并顺序:
def decode(self, ids: list[int]) -> str:all_bytes = b"".join(self.itos.get(id, b"") for id in ids)return all_bytes.decode("utf-8", errors="replace")
最后我们需要对BPETokenizer这个类进行一个序列化:
@classmethoddef from_serialized(cls,vocab: dict[int, bytes],merges: list[tuple[bytes, bytes]],special_tokens: list[str],):instance = cls(vocab_size=len(vocab), special_tokens=special_tokens)instance.stoi = {v: k for k, v in vocab.items()}instance.itos = vocabinstance.merges = mergesinstance.merges_rank = {pair: i for i, pair in enumerate(merges)}instance.vocab = vocabinstance.pair2new = {(p1, p2): instance.stoi[p1 + p2] for (p1, p2) in merges}return instance
测试结果 (注意最后一个点的XFail是正常的 说明你没有作弊…)
============================================================== 3 passed in 30.85s ===============================================================
(llm) henry@motif-gpu:~/Desktop/LLM/CS336/assignment1-basics$ python -m pytest -q tests/test_train_bpe.py tests/test_train_bpe.py::test_train_bpe_speed time using toy implementation: 0.32 seconds
PASSED
tests/test_train_bpe.py::test_train_bpe PASSED
tests/test_train_bpe.py::test_train_bpe_special_tokens PASSED
(llm) henry@motif-gpu:~/Desktop/LLM/CS336/assignment1-basics$ python -m pytest -q tests/test_tokenizer.py tests/test_tokenizer.py::test_roundtrip_empty PASSED
tests/test_tokenizer.py::test_empty_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_single_character PASSED
tests/test_tokenizer.py::test_single_character_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_single_unicode_character PASSED
tests/test_tokenizer.py::test_single_unicode_character_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_ascii_string PASSED
tests/test_tokenizer.py::test_ascii_string_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_unicode_string PASSED
tests/test_tokenizer.py::test_unicode_string_matches_tiktoken PASSED
tests/test_tokenizer.py::test_roundtrip_unicode_string_with_special_tokens PASSED
tests/test_tokenizer.py::test_unicode_string_with_special_tokens_matches_tiktoken PASSED
tests/test_tokenizer.py::test_overlapping_special_tokens PASSED
tests/test_tokenizer.py::test_address_roundtrip PASSED
tests/test_tokenizer.py::test_address_matches_tiktoken PASSED
tests/test_tokenizer.py::test_german_roundtrip PASSED
tests/test_tokenizer.py::test_german_matches_tiktoken PASSED
tests/test_tokenizer.py::test_tinystories_sample_roundtrip PASSED
tests/test_tokenizer.py::test_tinystories_matches_tiktoken PASSED
tests/test_tokenizer.py::test_encode_special_token_trailing_newlines PASSED
tests/test_tokenizer.py::test_encode_special_token_double_newline_non_whitespace PASSED
tests/test_tokenizer.py::test_encode_iterable_tinystories_sample_roundtrip PASSED
tests/test_tokenizer.py::test_encode_iterable_tinystories_matches_tiktoken PASSED
tests/test_tokenizer.py::test_encode_iterable_memory_usage PASSED
tests/test_tokenizer.py::test_encode_memory_usage XFAIL (Tokenizer.encode is expected to take more memory than allotted (1MB).)========================================================= 24 passed, 1 xfailed in 4.50s =========================================================
TransformerLM
对于transformerLM我认为着一块的难度比较常规,跟着课程的pdf照着写就可以,不过很适合用来熟悉einops中einsum, reduce和rearrange的用法。以下是一些需要注意的地方。
Rope
这里forward可能会有精度问题,因此需要首先转成torch.float32然后再转回去即可
class RoPE(nn.Module):def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None,dtype=None):super().__init__()self.theta = thetaself.d_k = d_kself.max_seq_len = max_seq_lenself.half_dim = d_k // 2freq_seq = torch.arange(self.half_dim, dtype=torch.float32, device=device)inv_freq = 1.0 / (theta ** (freq_seq / self.half_dim))t = torch.arange(max_seq_len, dtype=torch.float32, device=device)freqs = einsum(t, inv_freq, "i, j -> i j")cos = torch.cos(freqs)sin = torch.sin(freqs)self.register_buffer("cos_cached", cos, persistent=False)self.register_buffer("sin_cached", sin, persistent=False)def forward(self,x: Float[Tensor, "... seq_len d_k"],token_positions: Int[Tensor, "... seq_len"],) -> Float[Tensor, "... seq_len d_k"]:assert x.shape[-1] == self.d_k, f"x's last dim {x.shape[-1]} != d_k {self.d_k}"assert self.d_k % 2 == 0, "d_k must be even for RoPE"in_type = x.dtypex = x.to(torch.float32)# (... seq_len d_k) -> (... seq_len d_pair 2) 2D-Tensorx_pair = rearrange(x, "... seq_len (d_pair two) -> ... seq_len d_pair two", two = 2)# cos/sin tensor buildcos = self.cos_cached[token_positions]sin = self.sin_cached[token_positions]rot_mat = torch.stack((torch.stack((cos, -sin), dim = -1),torch.stack((sin, cos), dim = -1),),dim = -2,)# rotate "i j, j -> i"x_rot = einsum(rot_mat, x_pair, "... d_pair i j, ... d_pair j -> ... d_pair i")out = rearrange(x_rot, "... seq_len d_pair two -> ... seq_len (d_pair two)", two = 2)return out.to(in_type)
TransformerBlock
TransformerBlock按照pdf的要求写,注意模块的复用
class TransformerBlock(nn.Module):def __init__(self,d_model: int,num_heads: int,d_ff: int,max_seq_len: int,theta: float,device=None,dtype=None,):super().__init__()assert d_model % num_heads == 0, "d_model must be divisible by num_heads"self.ln1 = RMSNorm(d_model, device=device, dtype=dtype)self.attn = MultiheadSelfAttentionWithRoPE(d_model, num_heads, max_seq_len, theta, device, dtype)self.ln2 = RMSNorm(d_model, device=device, dtype=dtype)self.ffn = SwiGLUFFN(d_model, d_ff, device, dtype)def forward(self,x: Float[Tensor, "batch seq_len d_model"],token_positions: Int[Tensor, "batch seq_len"] | None = None,) -> Float[Tensor, "batch seq_len d_model"]:if token_positions is None:token_positions = torch.arange(x.size(1), device=x.device).expand(x.size(0), -1)x = x + self.attn(self.ln1(x), token_positions)x = x + self.ffn(self.ln2(x))return x
TransformerLM
这里最后不需要返回softmax之后的logits, 返回softmax前一层的tensor即可
class TransformerLM(nn.Module):def __init__(self,vocab_size: int,d_model: int,num_heads: int,d_ff: int,context_length: int,theta: float,num_layers: int,device=None,dtype=None,):super().__init__()self.vocab_size = vocab_sizeself.d_model = d_modelself.num_heads = num_headsself.d_ff = d_ffself.context_length = context_lengthself.theta = thetaself.num_layers = num_layersself.device = deviceself.dtype = dtypeparam_dtype = (dtypeif (dtype is not Noneand torch.is_floating_point(torch.tensor([], dtype=dtype)))else torch.float32)self.token_embeddings = Embedding(vocab_size, d_model, device=device, dtype=param_dtype)self.layers = MyLayerList([TransformerBlock(d_model=d_model,num_heads=num_heads,d_ff=d_ff,max_seq_len=context_length,theta=theta,device=device,dtype=param_dtype,)for _ in range(num_layers)])self.ln_final = RMSNorm(d_model, device=device, dtype=param_dtype)self.lm_head = Linear(d_model, vocab_size, device=device, dtype=param_dtype)@torch.no_grad()def forward(self,input_indices: Int[Tensor, "batch seq_len"],token_positions: Int[Tensor, "batch seq_len"] | None = None,) -> Float[Tensor, "batch seq_len vocab_size"]:x = self.token_embeddings(input_indices)if token_positions is None:token_positions = torch.arange(x.size(1), device=x.device).expand(x.size(0), -1)for layer in self.layers:x = layer(x, token_positions)x = self.ln_final(x)logits = self.lm_head(x)return logits
get_batch
get_batch的测试写的不是很完善,这里可以写成保证每一个Epoch rand的数据都不重复
def get_batch(dataset: npt.NDArray,batch_size: int,context_length: int,device: torch.device = torch.device("cpu"),
) -> tuple[npt.NDArray, npt.NDArray]:B, T = batch_size, context_lengthdata_t = torch.as_tensor(dataset, dtype=torch.long, device=device)N = data_t.numel()# starts = torch.randint(0, N - T, (B,), device=device)starts = torch.randperm(N - T, device=device)[:B] # 无放回采样offsets = rearrange(torch.arange(T + 1, device=device), 'n -> 1 n') # [1, T+1]positions = rearrange(starts, 'b -> b 1') + offsets tokens = data_t[positions] # [B, T+1]x, y = tokens[:, :-1], tokens[:, 1:] # Next token prediction [B, T]return x, yclass EpochSampler:def __init__(self, num_positions: int, device: torch.device):self.N = num_positions self.device = deviceself._shuffle() def _shuffle(self):self.perm = torch.randperm(self.N, device=self.device)self.cursor = 0 def next(self, k: int) -> torch.Tensor:if self.cursor + k > self.N: self._shuffle()idx = self.perm[self.cursor : self.cursor + k]self.cursor += kreturn idxdef get_batch_without_same(dataset: npt.NDArray,batch_size: int,context_length: int,sampler: EpochSampler,device: torch.device = torch.device("cpu"),
) -> tuple[torch.Tensor, torch.Tensor]:B, T = batch_size, context_lengthdata_t = torch.as_tensor(dataset, dtype=torch.long, device=device) # [N_total]N = data_t.numel()starts = sampler.next(B) # shape (B,)# offsets: [1, T+1],数值 0‥Toffsets = torch.arange(T + 1, device=device).unsqueeze(0) # (1, T+1)# positions: broadcast → (B, T+1)positions = starts.unsqueeze(1) + offsetstokens = data_t[positions] # (B, T+1)x, y = tokens[:, :-1], tokens[:, 1:] # (B, T)return x, y
此外我的代码仓库中还提供一些debug函数,可以用来debug tokenizer和bpe_train, 在cs336_basics文件夹下
最后帖一张全部通过的图片: