当前位置: 首页 > news >正文

LoRA微调的代码细节

本文主要是从代码角度进行解析。具体原理可以看这个

《LoRA:高效的深度学习模型微调技术及其应用》-CSDN博客

LoRA、ControlNet与T2I Adapter的区别-CSDN博客

1、核心原理

在 UNet 的交叉注意力层(Cross-Attention)中,将权重矩阵分解为低秩矩阵乘积(ΔW=BA),仅训练 A 和 B 矩阵。

  • UNet 交叉注意力层:LoRA 主要应用于交叉注意力的 Q、K、V 矩阵,控制文本特征与图像特征的对齐。
  • 多模态支持:在 Flux.1 等模型中,LoRA 可同时适配 CLIP 和 T5 文本编码器,实现双文本条件生成。

低秩分解数学原理

  • 对于to_q层(原始权重矩阵W0,维度m×n),插入后会新增两个小矩阵:
    • A:维度m×r(从输入维度m映射到低秩维度r);
    • B:维度r×n(从低秩维度r映射回输出维度n);
    • 原始权重W0保持冻结,仅适配器的矩阵A和矩阵B可训练。
    • 可训练参数数量由 LoRA 的秩r决定:对于单个原始矩阵W0m×n),LoRA 的参数为m×r + r×n = r×(m+n),远小于原始的m×n
  • to_q、to_v层同理

2、计算逻辑

插入前:q、k、v层的输出直接由原始权重计算:output = W0 × input(矩阵乘法)。

插入后:q、k、v层的输出是原始输出 + LoRA 调整量,公式为:

  • output = W0 × input + (α/r) × (B × (A × input))
  • W0 × input:原始预训练模型的输出(保证基础能力不丢失);
  • B × (A × input):LoRA 适配器的输出(低秩分解后的调整量,控制任务特异性变化);
  • α/r:缩放因子(代码中lora_alpha=32r=16,即缩放因子为32/16=2),平衡 LoRA 调整量的强度。
# 配置LoRA参数# 原始权重矩阵维度为 (m, n),LoRA 会将其分解为两个小矩阵:A(维度 m×r)和 B(维度 r×n),可训练参数仅为 (m+n)×r,远小于原始的 m×nunet_lora_config = LoraConfig(r=16,  # 低秩矩阵的秩(值越小参数越少,泛化性越好;越大表达能力越强,易过拟合)lora_alpha=32,  # 缩放因子, LoRA 对原始权重的更新公式为 W = W₀ + α/r × B·A,其中 α 即 lora_alpha,r 是秩。init_lora_weights="gaussian",# 目标模块:Stable Diffusion XL的UNet中负责注意力计算的层target_modules=["to_q",  # 注意力层的Query矩阵(查询)"to_k",  # 注意力层的Key矩阵(键)"to_v",  # 注意力层的Value矩阵(值)"to_out.0"  # 注意力输出的线性层(确保输出特征维度匹配)],bias="none",  # 默认,不微调偏置参数(减少参数量,降低过拟合风险))# 将LoRA适配器注入UNetunet.add_adapter(unet_lora_config)

插入前(全量微调):

  • 训练会更新所有参数,可能导致 “灾难性遗忘”(原始模型的通用能力被覆盖),且需要大量数据才能稳定训练。

插入后(LoRA 微调):

  • 仅更新 LoRA 适配器参数,原始参数冻结,保留预训练模型的通用能力(如基本图像生成、细节建模)。
  • 适配器专注学习 “新任务特异性知识”(如特定风格、角色特征),通过低秩分解限制参数容量,减少过拟合风险(尤其适合小数据集)。

3、简单从0实现自注意力机制,对比有无LoRA的区别

以下代码应该比较清晰的解释了在模型中LoRA的具体作用机制

import torch
import torch.nn as nn
import torch.nn.functional as F# ------------------------------
# 1. 图像特征映射层(将图像转为序列特征)
# ------------------------------
class ImageProjection(nn.Module):def __init__(self, img_size=224, in_chans=3, embed_dim=512, patch_size=16):super().__init__()self.patch_size = patch_sizeself.num_patches = (img_size // patch_size) **2  # 14x14=196个patch# 卷积层:将每个patch映射到512维self.proj = nn.Conv2d(in_channels=in_chans,out_channels=embed_dim,kernel_size=patch_size,stride=patch_size,bias=True)def forward(self, x):# x: (batch, 3, 224, 224) → 卷积后: (batch, 512, 14, 14)x = self.proj(x)# 展平为序列: (batch, 512, 196) → (batch, 196, 512)x = x.flatten(2).transpose(1, 2)return x# ------------------------------
# 2. 原始单头自注意力
# ------------------------------
class SelfAttention(nn.Module):def __init__(self, embed_dim=512):super().__init__()self.embed_dim = embed_dim  # 特征维度# Q、K、V三个独立的线性层self.to_q = nn.Linear(embed_dim, embed_dim)  # 输入维度→Q维度self.to_k = nn.Linear(embed_dim, embed_dim)  # 输入维度→K维度self.to_v = nn.Linear(embed_dim, embed_dim)  # 输入维度→V维度# 输出投影层self.to_out = nn.Linear(embed_dim, embed_dim)def forward(self, x):# 输入形状:(batch_size, seq_len, embed_dim)batch_size, seq_len, _ = x.shape# 1. 计算Q、K、V(核心步骤)q = self.to_q(x)  # (batch, seq_len, embed_dim)k = self.to_k(x)v = self.to_v(x)# 2. 计算注意力分数(scaled dot-product)# 公式:注意力分数 = (Q × K^T) / sqrt(embed_dim)attn_scores = torch.matmul(q, k.transpose(-2, -1))  # (batch, seq_len, seq_len)attn_scores = attn_scores / (self.embed_dim ** 0.5)  # 缩放,避免梯度消失# 3. 计算注意力概率分布(softmax归一化)attn_probs = F.softmax(attn_scores, dim=-1)  # (batch, seq_len, seq_len)# 4. 加权求和得到注意力输出(V × 注意力概率)attn_output = torch.matmul(attn_probs, v)  # (batch, seq_len, embed_dim)# 5. 输出投影return self.to_out(attn_output)# ------------------------------
# 3. LoRA适配器
# ------------------------------
class LoRALayer(nn.Module):def __init__(self, in_dim, out_dim, rank=8, alpha=16):super().__init__()self.rank = rankself.scaling = alpha / rank  # 缩放因子:控制LoRA对原始输出的影响强度# 低秩分解矩阵(核心创新点),注意bias=Falseself.A = nn.Linear(in_dim, rank, bias=False)  # 降维矩阵:in_dim → rank(低秩空间)self.B = nn.Linear(rank, out_dim, bias=False)  # 升维矩阵:rank → out_dim(还原维度)# 初始化策略(关键细节)nn.init.normal_(self.A.weight, std=0.01)  # A用小随机值初始化nn.init.zeros_(self.B.weight)  # B初始化为0,保证训练开始时LoRA不影响输出def forward(self, x):# LoRA的输出return self.B(self.A(x)) * self.scaling# ------------------------------
# 4. 带LoRA的单头自注意力
# ------------------------------
class SelfAttention_LoRA(nn.Module):def __init__(self, embed_dim=512, lora_rank=8, lora_alpha=16):super().__init__()self.embed_dim = embed_dim# 原始Q、K、V和输出层(冻结预训练参数)self.to_q = nn.Linear(embed_dim, embed_dim)self.to_k = nn.Linear(embed_dim, embed_dim)self.to_v = nn.Linear(embed_dim, embed_dim)self.to_out = nn.Linear(embed_dim, embed_dim)# 冻结原始权重(核心:不修改预训练参数)for param in [self.to_q, self.to_k, self.to_v, self.to_out]:for p in param.parameters():p.requires_grad = False# 插入LoRA适配器(仅这些参数可训练)self.lora_q = LoRALayer(embed_dim, embed_dim, lora_rank, lora_alpha)self.lora_k = LoRALayer(embed_dim, embed_dim, lora_rank, lora_alpha)self.lora_v = LoRALayer(embed_dim, embed_dim, lora_rank, lora_alpha)self.lora_out = LoRALayer(embed_dim, embed_dim, lora_rank, lora_alpha)def forward(self, x):batch_size, seq_len, _ = x.shape# 1. 计算Q、K、V:原始输出 + LoRA调整量(核心变化)q = self.to_q(x) + self.lora_q(x)  # 原始Q + LoRA对Q的调整k = self.to_k(x) + self.lora_k(x)  # 原始K + LoRA对K的调整v = self.to_v(x) + self.lora_v(x)  # 原始V + LoRA对V的调整# 2. 注意力计算步骤与原始完全相同(无变化)attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (self.embed_dim ** 0.5)attn_probs = F.softmax(attn_scores, dim=-1)attn_output = torch.matmul(attn_probs, v)# 3. 输出投影:原始输出 + LoRA调整量output = self.to_out(attn_output) + self.lora_out(attn_output)return output# ------------------------------
# 主流程:随机图像输入 + 前后对比
# ------------------------------
if __name__ == "__main__":# 1. 随机初始化图像张量 (batch_size=2, 3, 224, 224)batch = torch.randn(2, 3, 224, 224)  # 模拟2张224x224x3的图像print(f"输入图像形状: {batch.shape}")  # torch.Size([2, 3, 224, 224])# 2. 图像→特征映射proj = ImageProjection()features = proj(batch)print(f"映射后序列形状: {features.shape}")  # torch.Size([2, 196, 512])# 3. 原始自注意力计算original_attn = SelfAttention()original_output = original_attn(features)print(f"原始注意力输出形状: {original_output.shape}")  # torch.Size([2, 196, 512])# 4. LoRA自注意力计算lora_attn = SelfAttention_LoRA()lora_output = lora_attn(features)print(f"LoRA注意力输出形状: {lora_output.shape}")  # torch.Size([2, 196, 512])# 5. 参数对比original_params = sum(p.numel() for p in original_attn.parameters())lora_trainable_params = sum(p.numel() for p in lora_attn.parameters() if p.requires_grad)print(f"原始注意力总参数: {original_params:,}")  # 1,050,624print(f"LoRA可训练参数: {lora_trainable_params:,}")  # 32,768print(f"参数占比: {lora_trainable_params/original_params:.2%}")  # 3.12%

4、LoRA 层的 A/B 矩阵初始化

LoRA 论文里明确建议:

A 矩阵(in_dim → rank):正态分布随机初始化,标准差很小(如 std=0.01 或 std=0.02)。

B 矩阵(rank → out_dim)全 0 初始化

一、A 矩阵的初始化:小随机值初始化

A 矩阵的作用是将高维输入(如embed_dim=512)投影到低秩空间(rank=8),其初始化策略为:
使用均值为 0、小标准差(通常 0.01)的高斯分布随机初始化

代码实现(对应前文示例):

nn.init.normal_(self.A.weight, std=0.01)  # 均值默认0,标准差0.01
为什么这样设计?
  1. 避免初始干扰过大
    若 A 矩阵初始值过大,会导致 LoRA 的调整量(B(Ax))在训练初期就显著偏离原始模型输出,破坏预训练模型的基础能力。小随机值能保证初始时 LoRA 的影响微乎其微。

  2. 提供探索空间
    随机初始化确保 A 矩阵能学习到输入特征中与任务相关的低秩成分(如特定图像风格、文本语义的关键维度),而非固定偏向某一方向。

  3. 与梯度缩放匹配
    配合后续的scaling = alpha/rank(如alpha=16, rank=8时缩放因子为 2),小随机值经过缩放后能与原始模型输出的量级保持一致,避免梯度爆炸。

二、B 矩阵的初始化:零矩阵初始化

B 矩阵的作用是将低秩空间(rank=8)映射回高维输出(embed_dim=512),其初始化策略为:
完全初始化为零矩阵

代码实现(对应前文示例):

nn.init.zeros_(self.B.weight)  # 所有元素初始化为0
为什么这样设计?
  1. 保证初始输出一致性
    当 B 矩阵为零时,无论 A 矩阵输出什么,LoRA 的整体调整量(B(Ax) * scaling)都为零。这意味着:

    插入LoRA后的输出 = 原始模型输出 + 0 = 原始模型输出
    
     

    确保微调开始时,模型行为与预训练模型完全一致,避免因初始化引入偏差。

  2. 实现 “渐进式学习”
    训练初期,B 矩阵从 0 开始缓慢更新,LoRA 的调整量从小到大逐渐累积,模型会先保留预训练知识,再逐步学习任务特异性信息,减少过拟合风险。

  3. 稳定梯度起点
    零初始化使 B 矩阵的初始梯度更稳定。若 B 矩阵初始值非零,可能导致训练初期梯度波动过大,影响收敛。

三、初始化对训练过程的影响(直观示例)

假设输入特征x经过 A 矩阵后得到低秩特征a = A x(因 A 是小随机值,a的量级较小),此时:

  • 训练开始时,B=0 → LoRA 调整量为 0 → 输出 = 原始模型输出。
  • 训练第一步,B 矩阵开始更新(假设更新为B=ΔB) → 调整量为ΔB * a * scaling(量级较小)。
  • 随着训练进行,B 矩阵逐渐累积有效信息,调整量逐渐增大,最终学到任务所需的特定模式。

这种 “从 0 开始逐步学习” 的方式,完美平衡了 “保留预训练知识” 和 “学习新任务” 的需求。

四、与全量微调初始化的本质区别

全量微调中,权重初始化通常使用 Xavier 或 Kaiming 等方法,目的是让各层输入输出的方差保持一致;而 LoRA 的 A、B 矩阵初始化核心目标是 **“最小化对原始模型的初始干扰”**,这是参数高效微调特有的设计思路 —— 因为原始模型已经是训练好的 “优质起点”,无需重新初始化,只需在此基础上 “小修小补”。

总结

LoRA 的初始化策略是其成功的关键细节之一:

  • A 矩阵用小随机值:提供探索空间,避免初始干扰过大;
  • B 矩阵用零初始化:保证初始输出与原始模型一致,实现渐进式学习。

这种设计使 LoRA 在微调时既能高效适配新任务,又能最大程度保留预训练模型的通用能力。

http://www.lryc.cn/news/614713.html

相关文章:

  • 2025年渗透测试面试题总结-07(题目+回答)
  • 【设计模式】访问者模式模式
  • Chrome DevTools Protocol 开启协议监视器
  • flutter开发(一)flutter命令行工具
  • SVM实战:从线性可分到高维映射再到实战演练
  • 【同余最短路】P2371 [国家集训队] 墨墨的等式|省选-
  • 在 Git 中,将本地分支的修改提交到主分支
  • 广东省省考备考(第七十天8.8)——言语、判断推理(强化训练)
  • ubuntu 22.04 使用yaml文件 修改静态ip
  • 开发板RK3568和stm32的异同:
  • Redis对象编码
  • 【Bellman-Ford】High Score
  • 荣耀秋招启动
  • Sum of Four Values(sorting and searching)
  • 两个函数 quantize() 和 dequantize() 可用于对不同的位数进行量化实验
  • 力扣-189.轮转数组
  • 优选算法 力扣 15. 三数之和 双指针降低时间复杂度 C++题解 每日一题
  • 深入解析 Seaborn:数据可视化的优雅利器
  • 自定义上传本地文件夹到七牛云
  • 虚拟机Ubuntu图形化界面root用户登录错误
  • 使用pybind11封装C++API
  • Shell、Python对比
  • 要写新项目了,运行老Django项目找找记忆先
  • C++中的继承:从基础到复杂
  • 飞算JavaAI深度解析:专为Java生态而生的智能引擎
  • 安全引导功能及ATF的启动过程(四)
  • 巧妙实现Ethercat转Profinet协议网关匹配光伏电站
  • 「ECG信号处理——(22)Pan-Tompkins Findpeak 阈值检测 差分阈值算法——三种R波检测算法对比分析」2025年8月8日
  • C语言编译流程讲解
  • 【Open3D】基础操作之三维数据结构的高效组织和管理