当前位置: 首页 > news >正文

nanoGPT复现——bench.py和sample.py

其实bench.py和sample.py是train.py的缩略版,主干是一样的,所以这里就没有把参数给分离出来,只是把代码的每个模块分离出来了,如果不懂参数的用处可以看train.py那一篇,那里更全面

一、bench.py

bench的主要作用是对模型的各项性能进行检测

参数

import os
import timeimport numpy as np
import torch
from contextlib import nullcontext
from model import GPT, GPTConfigbatch_size = 12
block_size = 1024
device = 'cuda'
real_data = True
device_type = 'cuda' if 'cuda' in device else 'cpu'
bias = False
compile = True
profile = False
exec(open('configurator.py', encoding='utf-8').read())dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
ptdtype = {'float16':torch.float16, 'bfloat16':torch.bfloat16, 'float32':torch.float32}[dtype]
ctx = nullcontext if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)seed = 1337
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

数据生成

#---------------data-------------------------------
if real_data:dataset = 'shakeseare-char'data_path = os.path.join('Data', dataset)train_data = np.memmap(os.path.join(data_path, 'train.bin'), dtype=np.uint64, mode='r')def get_batch(split):data = train_dataidx = torch.randint(len(data) - batch_size, (batch_size,))x = torch.stack([torch.from_numpy((data[i:i+batch_size]).astype('int64')) for i in idx])y = torch.stack([torch.from_numpy((data[i+1:i+1+batch_size]).astype('int64')) for i in idx])x, y = x.pin_memory().to(device=device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)return x,y
else:x = torch.randint(50304, (batch_size, block_size), device=device)y = torch.randint(50304, (batch_size, block_size), device=device)get_batch = lambda split:(x, y)

model定义

#--------------------model-----------------------
gptconf = {'n_head':12,'n_embd':768,'n_layer':12,'dropout': 0.0,'block_size':block_size,'bias': bias
}
conf = GPTConfig(**gptconf)
model = GPT(conf)
model.to(device)
optimizer = model.configure_optimizers(weight_decay=1e-2 ,learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type)

compile

#------------------compile-------------------------
if compile:print("model is  compiling, take minutes~")model = torch.compile(model)

主干(profile)

#-------------------主干(profile)-------------------
if profile:wait, warmup, active = 5, 5, 5num_steps = wait + warmup + activewith torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1),on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'),record_shapes=False,profile_memory=False,with_stack=False,with_flops=True,with_modules=False) as proc:X, Y = get_batch('train')for k in range(num_steps):with ctx:logits, loss = model(X, Y)X, Y = get_batch('train')optimizer.zero_grad(set_to_none=True)loss.backward()optimizer.step()lossf = loss.item()print(f"{k}/{num_steps} loss: {lossf:4f}")proc.step()
else:torch.cuda.synchronize()for stage, num_steps in enumerate([10, 20]):t0 = time.time()X, Y = get_batch('train')for k in range(num_steps):with ctx:logits, loss = model(X, Y)X, Y = get_batch('train')optimizer.zero_grad(set_to_none=True)loss.backward()optimizer.step()lossf = loss.item()print(f"{k}/{num_steps} loss : {lossf:.4f}")torch.cuda.synchronize()t1 = time.time()dt = t1 - t0mfu = model.estimate_mfu(fwdbwd_per_iter=batch_size*1*num_steps, dt=dt)if stage == 1:print(f"time per iteration:{dt/num_steps**1000:4f}ms, MFU:{mfu*100:.2f}")

二、sample.py

sample主要是利用定义的模型来生成文本

参数

import pickleimport tiktoken
import torch
from model import GPT, GPTConfig
import numpy as np
import os
from contextlib import nullcontextinit_from = 'resume'
out_path = 'out-shakespeare-char'
start = '\n'
max_new_tokens = 500
temperature = 0.8
top_k = 200
seed = 1337
num_samples = 10device = 'cuda'
device_type = 'cuda' if 'cuda' in device else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
compile = False
exec(open('configurator.py', encoding='utf-8').read())
pydtype = {'float16': torch.float16, 'bfloat16': torch.bfloat16, 'float32': torch.float32}[dtype]torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.allow_tf32 = True
ctx = nullcontext if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=pydtype)

模型定义+compile

if init_from == 'resume':ckpt_path = os.path.join(out_path, 'ckpt.pt')checkpoint = torch.load(ckpt_path, map_location=device)config_args = checkpoint['model_args']config = GPTConfig(**config_args)model = GPT(config)state_dict = checkpoint['model']unwanted_name = '_orig_mod.'for k, v in list(state_dict.items()):if k.startswith(unwanted_name):state_dict[k[len(unwanted_name):]] = state_dict.pop(k)model.load_state_dict(state_dict)
elif init_from.startswith('gpt2'):model = GPT.from_pretrained(init_from, override_arg=dict(dropout=0.0))model.eval()
model.to(device)
if compile:model = torch.compile(model)

meta处理

meta.pkl 一般是用 pickle 库序列化保存的 Python 对象,通常包含模型、数据、词表、标签、配置信息的**“元信息”**,用于训练、推理或加载过程中的辅助。

在nanoGPT中meta是由自己构建词表时保存的词表信息,详见prepare.py的拆解

load_meta = False
if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']:meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')load_meta = os.path.exists(meta_path)if load_meta:with open(meta_path, 'rb') as f:meta = pickle.load(f)itos, stoi = meta['itos'], meta['stoi']encode = lambda x: [stoi[s] for s in x]decode = lambda l: ''.join([itos[i] for i in l])
else:enc = tiktoken.get_encoding('gpt2')encode = lambda x: enc.encode(x, allowed_special={"<|endoftext|>"})decode = lambda l: enc.decode(l)

主干

if start.startswith('FILE:'):file_path = start[5:]with open(file_path, 'r', encoding='utf-8') as f:start = f.read()
start_ids = encode(start)
x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
with torch.no_grad():with ctx:for k in range(num_samples):y = model.generate(x, max_new_tokens, temperature, top_k)print(decode(y[0].tolist()))

http://www.lryc.cn/news/578631.html

相关文章:

  • 【MobaXterm、Vim】使用合集1
  • 【科研绘图系列】基于R语言的复杂热图绘制教程:环境因素与染色体效应的可视化
  • 用lines_gauss的width属性提取缺陷
  • Prompt生成指南
  • Unity-ComputeShader
  • UE5.6 官方文档笔记 [1]——虚幻编辑器界面
  • C#.Net筑基-优雅LINQ的查询艺术
  • 6.2 实现文档加载和切分和简易向量数据库的功能
  • 图像处理专业书籍以及网络资源总结
  • beego打包发布到Centos系统及国产麒麟系统完整教程
  • 前端第二节(Vue)
  • 微信小程序实现table表格
  • 微信小程序21~30
  • CppCon 2018 学习:EFFECTIVE REPLACEMENT OF DYNAMIC POLYMORPHISM WITH std::variant
  • Linux->进程控制(精讲)
  • 《P5522 [yLOI2019] 棠梨煎雪》
  • 如何分析大语言模型(LLM)的内部表征来评估文本的“诚实性”
  • 在 Docker 容器中使用内网穿透
  • 大语言模型推理系统综述
  • NLP——RNN变体LSTM和GRU
  • 关于vue2使用elform的rules校验
  • 深度学习进阶:自然语言处理的推荐点评
  • (LeetCode 面试经典 150 题) 42. 接雨水 (单调栈)
  • Gartner《Choosing Event Brokers to Support Event-DrivenArchitecture》心得
  • 振荡电路Multisim电路仿真实验汇总——硬件工程师笔记
  • .NET跨平台开发工具Rider v2025.1——支持.NET 10、C# 14
  • K8s Pod调度基础——2
  • Langgraph 学习教程
  • 位运算经典题解
  • python+uniapp基于微信小程序的流浪动物救助领养系统nodejs+java