当前位置: 首页 > news >正文

机器学习(李宏毅)——Transformer

一、前言

本文章作为学习2023年《李宏毅机器学习课程》的笔记,感谢台湾大学李宏毅教授的课程,respect!!!
读这篇文章必须先了解self-attention,可参阅我上一篇。

二、大纲

  • Transformer问世
  • 原理剖析
  • 模型训练

三、Transformer问世

2017 年在文章《Attention Is All You Need》被提出的。应用于seq2seq模型,当时直接轰动。

四、原理剖析

两部分组成:Encoder 和 Decoder

  • Encoder 结构
    接下来从大到小一层层剥开:

剥一下:
输入一排向量,输出一排向量
在这里插入图片描述
剥两下:
Encoder 由多个Block组成,串联起来
在这里插入图片描述
剥三下:
Block装的是啥?原来是Self-attention!
在这里插入图片描述
剥四下:
Self-attention原来加入了residual和Layer Normal,至此剥完了。
在这里插入图片描述

说明:
上图自底向上看,关键点:
1、residual结构,输入接到输出送入下一层,残差结构;
2、Layer Normal,具体如下图:
在这里插入图片描述
算出标准差和均值后,套用公式计算即可。

以上就是Encoder的全部了!
论文中是这么画图表达的:
在这里插入图片描述
注:Positional Encoding是self-attention的位置资讯。

  • Decoder 结构
    有两种方法生成输出:Auto Regressive 和 Non Auto Regressive。

Auto Regressive
在这里插入图片描述
给个START符号,把本次输出当做是下一次的输入,依序进行下去。

Non Auto Regressive
在这里插入图片描述
输入是一排的START符号,一下子梭哈突出一排输出。

Encoder结构长啥样?

接下来看下结构长啥样,先遮住不一样的部分,其他部分结构基本一致,只不过这里用上了Masked Multi-Head Attention
在这里插入图片描述

Masked Multi-Head Attention
啥是Masked Multi-Head Attention?Masked有啥含义?
可以直接理解为单向的Multi-Head Attention,而且是从左边开始:
在这里插入图片描述
说明:这也很好理解,右边的字符都还没输出出来怎么做运算,因此只能是已经吐出来的左边的内容做self-attention,这就是masked的含义。

遮住的部分是啥?(cross attention)

最后这边遮住的部分到底是啥玩意?
别想太复杂,就还是self-attention。
corss的意思就是v,k来自Encoder,q来自Decoder,仅此而已。
在这里插入图片描述
在这里插入图片描述
其实也好理解,Decoder是去还原结果的,那可不得抽下Encoder编码时候的资讯和上下文语义信息才能还原,缺一不可。
比喻下,前者让输出紧扣题意,后者让其说人话。

小结
至此,Transformer的结构就阐述完了,无非就是Encoder + Decoder,建议自己在草稿纸上画画能够加深印象。

五、模型训练

transformer的模型训练用的还是cross entropy。
在这里插入图片描述
实战过程中的tips

  • copy mechanism
    例如:
    Machine Translation(机器翻译),可能使用原文复制这个技能对于模型而言比较容易,毕竟它不需要创造新词汇了嘛,这就是copy mechanism。
  • Guided Attention
    在这里插入图片描述

意思就是不要乱Attention,有的放矢地让模型做attention。

  • Beam Search
    在这里插入图片描述
    基本思想就是不要步步好,有可能短期不好但是长期更好。说的和人生似的。
    如果模型需要有点创造力,不适合用此方法,这是实做后的结论。

训练过程记得让模型看些负样本,不至于模型一步错步步错,(schedule sampling方法)。

五、小结

最基本的掌握好Encoder和Decoder就很可以了,其他的在实做过程中遇到问题再问问AI工具。

http://www.lryc.cn/news/534893.html

相关文章:

  • React进阶之React状态管理CRA
  • 攻克AWS认证机器学习工程师(AWS Certified Machine Learning Engineer) - 助理级别认证:我的成功路线图
  • 前端开发环境
  • Web自动化测试—测试用例流程设计
  • HTML全局属性与Meta元信息详解:优化网页的灵魂
  • day001 折半查找/二分查找
  • Linux 资源监控:优化与跟踪系统性能
  • java安全中的类加载
  • Node.js调用DeepSeek Api 实现本地智能聊天的简单应用
  • 分布式服务框架 如何设计一个更合理的协议
  • Unity使用iTextSharp导出PDF-02基础结构及设置中文字体
  • Kafka因文件句柄数过多导致挂掉的排查与解决
  • 【LeetCode Hot100 多维动态规划】最小路径和、最长回文子串、最长公共子序列、编辑距离
  • PRC框架-Dubbo
  • 智能检测摄像头模块在客流统计中的应用
  • [LLM面试题] 指示微调(Prompt-tuning)与 Prefix-tuning区别
  • 【CubeMX+STM32】SD卡 U盘文件系统 USB+FATFS
  • 在JVM的栈(虚拟机栈)中,除了栈帧(Stack Frame)还有什么?
  • # 解析Excel文件:处理Excel xlsx file not supported错误 [特殊字符]
  • 图片下载不下来?即便点了另存为也无法下载?两种方法教你百分之百下载下来
  • Unity项目实战-Player玩家控制脚本实现
  • CP AUTOSAR标准之ICUDriver(AUTOSAR_SWS_ICUDriver)(更新中……)
  • Python3 ImportError: cannot import name ‘XXX‘ from ‘XXX‘
  • [学习笔记] Kotlin Compose-Multiplatform
  • 【R语言】t检验
  • flutter ListView Item复用源码解析
  • Spring Boot 配置 Mybatis 读写分离
  • 网络初识-
  • DNS污染:网络世界的“隐形劫持”与防御
  • MQTT(Message Queuing Telemetry Transport)协议(三)