当前位置: 首页 > news >正文

PyTorch生成式人工智能(18)——循环神经网络详解与实现

PyTorch生成式人工智能(18)——循环神经网络详解与实现

    • 0. 前言
    • 1. 文本生成的挑战
    • 2. 循环神经网络
      • 2.1 文本数据
      • 2.2 循环神经网络原理
    • 3. 长短期记忆网络
    • 3. 自然语言处理基础
      • 3.1 分词
      • 3.2 词嵌入
      • 3.3 词嵌入在自然语言处理中的应用
    • 小结
    • 系列链接

0. 前言

我们已经学习了如何生成数字和图像等内容。从本节开始,我们将主要聚焦于文本生成。人类语言极其复杂且充满细微差别,不仅仅涉及语法和词汇的理解,还包括上下文、语气和文化背景等。成功生成连贯且语境适当的文本是一项重大挑战,需要深入理解和处理语言。

1. 文本生成的挑战

人类主要通过语言进行交流,能够生成语言文本的人工智能可以更自然地与用户互动,使技术变得更加易于使用。文本生成有广泛的应用,包括自动化客户服务回复、创作文章和电影剧本创作、帮助创意写作,甚至构建个人助手。
在本节中,我们将学习如何解决文本生成建模中的三个主要挑战。首先,文本是序列数据,由按特定顺序排列的数据点组成,每个数据点按顺序排列,以反映数据内部的顺序和相互依赖性。由于序列的顺序敏感性,预测序列结果具有挑战性,改变元素的顺序会改变它们的含义。第二,文本存在长程依赖性,文本中某一部分的含义可能依赖于文本中更早出现的元素,理解和建模这些长程依赖性对于生成连贯的文本至关重要。最后,人类语言具有歧义性和上下文依赖性。训练模型理解语言的细微差别、习语和文化背景,生成上下文准确的文本非常具有挑战。
本节将介绍一种专门用于处理序列数据(如文本或时间序列)的神经网络:循环神经网络 (Recurrent Neural Network, RNN)。传统的神经网络,如全连接网络,会独立处理每个输入。这意味着网络处理每个输入时,并不考虑不同输入之间的关系或顺序。而 RNN 专门设计用于处理序列数据,在 RNN 中,给定时间步的输出不仅依赖于当前输入,还依赖于之前的输入。这使得 RNN 能够保持某种形式的记忆,捕捉之前时间步的信息,从而影响当前输入的处理。
这种序列处理使得 RNN 非常适用于顺序数据任务,如语言建模,目标是根据前面的单词预测句子中的下一个单词。我们将重点介绍 RNN 的变种——长短期记忆 (Long Short-Term Memory, LSTM) 网络,LSTM 能够识别文本等序列数据中的短期和长期数据模式。LSTM 模型使用隐藏状态来捕捉之前时间步的信息。

2. 循环神经网络

我们已经简要了解了生成文本的复杂性,特别是在需要保持连贯性和上下文相关性时。本节将更深入地探讨这些挑战,并探索循环神经网络 (Recurrent Neural Network, RNN) 架构。我们将介绍 RNN 在生成文本任务中的优势,以及它们的局限性(即它们被 Transformer 所取代的原因)。
RNN 的设计初衷就是用于处理序列数据,使其能够胜任文本生成这一本质上具有序列性质的任务。RNN 利用隐藏状态的记忆形式来捕捉并保留序列早期部分的信息,这对于在序列进行时保持上下文和理解依赖关系至关重要。

2.1 文本数据

文本是序列数据的典型例子,序列数据指元素的顺序至关重要的数据。这种结构意味着各个元素相互之间的位置具有重要意义,通常传达了理解数据所必需的关键信息。常见的序列数据包括时间序列(如股票价格)、文本内容(如句子)和音乐作品(一系列音符)。
生成文本的主要的挑战在于建模句子中单词的顺序,改变顺序会极大地改变句子的含义。例如,在句子“小明在足球比赛中击败了小华”中,若将“小明”和“小华”交换位置,尽管使用了相同的词语,但句子的含义完全颠倒了。此外,文本生成在处理长程依赖性和解决歧义问题时也面临挑战。
在本节中,我们将探讨如何使用 RNN 解决这些挑战。虽然这种方法并不完美,但为其他更先进技术奠定了基础。通过本节的学习,我们将了解如何处理单词顺序、应对长程依赖性以及应对文本的歧义性,从而掌握文本生成的基本技术。本节的学习是深入了解更复杂方法的基础,掌握自然语言处理 (Natuarl Language Processing, NLP) 中的关键技术,如文本分词、词嵌入和序列到序列预测。

2.2 循环神经网络原理

RNN 是一种专门设计用于识别数据序列中模式的人工神经网络,适用于处理文本、音乐或股票价格等数据序列。与传统的神经网络不同,传统网络独立地处理输入数据,RNN 网络内部包含循环结构,从而使信息得以持续传递。
生成文本的一个挑战是如何根据所有前面的单词预测下一个单词,以便捕捉长程依赖性和上下文意义。RNN 接受的不仅仅是独立的输入,而是作为一个序列(例如句子中的单词)。在每个时间步中,预测不仅依赖于当前输入,还通过隐藏状态以总结的形式考虑所有之前的输入。以短语 “a frog has four legs” 为例。在第一个时间步中,使用单词 “a” 预测第二个单词 “frog”。在第二个时间步中,利用 “a” 和 “frog” 来预测下一个单词。当我们预测最后一个单词时,需要使用前四个单词 “a frog has four”。
RNN 的一个关键特点是隐藏状态,它捕获了序列中所有先前元素的信息,这对于网络有效处理和生成序列数据至关重要。RNN 的工作原理及其序列处理方式如下图所示,图中展示了一个循环神经元层随时间展开的过程。
RNN架构
RNN 中的隐藏状态在捕捉所有时间步的信息中起着关键作用,这使得 RNN 不仅能够基于当前输入 x(t)x(t)x(t),还能够利用从所有先前输入 x0,x1,…,xt−1x_0,x_1,…,x_{t−1}x0,x1,,xt1 中积累的知识进行预测。这一特性使 RNN 能够理解时间依赖性,可以从输入序列中捕捉上下文,这对于语言建模等任务至关重要,因为句子中的前一个单词为预测下一个单词奠定了基础。
我们还可以通过方程式将 RNN 描述为计算图。时刻 ttt 处的 RNN 的内部状态由隐藏向量 hth_tht 的值给出,该隐藏向量是权重矩阵 www 和时刻 t−1t-1t1 处的隐藏状态 ht−1h_{t-1}ht1 的乘积,再加上权重矩阵 uuu 与时刻 ttt 处的输入 xtx_txt 的乘积,然后通过 tanh 激活函数传递。选择 tanh 而不是 sigmoid 等其他激活函数,是因为在实践中 tanh 更有效,并有助于解决梯度消失问题。
为了方便,在描述不同类型的 RNN 结构的方程中,我们省略了显式的偏置项,将其合并到矩阵中。例如,以下 n 维空间中的直线方程,其中 w1w_1w1wnw_nwn 指的是直线在每个维度上的系数,偏置 bbb 指的是每个维度上的 yyy 截距:
y=w1x1+w2x2+⋯+wnxn+by=w_1x_1+w_2x_2+\cdots+w_nx_n+b y=w1x1+w2x2++wnxn+b
我们可以将方程重写为矩阵形式如下:
y=wx+by=wx+b y=wx+b
其中,www 是形状为 (m,n)(m, n)(m,n) 的矩阵,bbb 是形状为 (m,1)(m, 1)(m,1) 的向量,mmm 是数据集中的样本数,nnn 是每个样本对应的特征数。等效地,我们可以通过将 bbb 向量视为 www 的单位特征列,将其折叠到矩阵 www 中来消除向量 bbb
y=w1x1+w2x2+⋯+wnxn+w0(1)=w′xy=w_1x_1+w_2x_2+\cdots+w_nx_n+w_0(1)=w'x y=w1x1+w2x2++wnxn+w0(1)=wx
其中,w′w'w 是形状为 (m,n+1)(m, n+1)(m,n+1) 的矩阵,最后一列包含 bbb 的值,这样得到的形式更为紧凑,也更容易理解和记忆。
在时刻 ttt 处的输出向量 oto_tot 是权重矩阵 vvv 和隐藏状态 hth_tht 的乘积经过 softmax 激活得到的,得到的输出向量为输出概率:
ht=tanh(wht−1+uxt)ot=softmax(vht)h_t= tanh(wh_{t-1} +ux_t)\\ o_t=softmax(vh_t) ht=tanh(wht1+uxt)ot=softmax(vht)

3. 长短期记忆网络

尽管标准的 RNN 可以处理短期依赖性,但在处理文本中的长期依赖性时却表现不佳。这源于梯度消失问题,即在长序列中,梯度逐渐减小,阻碍了模型学习长距离关系的能力。为了缓解这一问题,提出了一系列 RNN 的变体,如长短期记忆 (Long Short-Term Memory, LSTM) 网络。
LSTM 网络由 LSTM 单元组成,每个单元的结构比标准的 RNN 神经元更为复杂。单元状态是 LSTM 的关键创新,它贯穿整个 LSTM 单元链,能够在网络中传递相关信息。通过向单元状态中添加或删除信息,LSTM 能够捕捉长期依赖性,并长期记住信息。这使得它在语言建模和文本生成等任务中更为高效。下图展示了在时刻 ttt 时应用于隐藏状态的变换过程:
LSTM 单元

整体来看似乎很复杂,逐个组件分解便于理解。图顶部的横线代表单元的内部状态(也称记忆) ccc。底部的横线代表隐藏状态 hhh,而 iiifffooo 门则是 LSTM 处理梯度消失问题的机制。在训练期间,LSTM 学习这些门的参数。
理解 LSTM 单元内部各个门工作机制的另一种方式是利用数学方程,这些方程描述了如何从前一时时刻的隐藏状态 ht−1h_{t-1}ht1 计算时刻 ttt 的隐藏状态 hth_tht 的值。总的来说,基于方程的描述往往更清晰、更简洁。表示 LSTM 的方程组如下:
i=σ(wiht−1+uixt+vi)f=σ(wfht−1+ufxt+vf)o=σ(woht−1+uoxt+vo)g=tanhwght−1+ugxtct=(f∗ct−1)+(g∗i)ht=tanh(ct)∗oi=\sigma(w_ih_{t-1}+u_ix_t+v_i)\\ f=\sigma(w_fh_{t-1}+u_fx_t+v_f)\\ o=\sigma(w_oh_{t-1}+u_ox_t+v_o)\\ g=\text{tanh}w_gh_{t-1}+u_gx_t\\ c_t=(f\ast c_{t-1})+(g\ast i)\\ h_t=\text{tanh}(c_t)\ast o i=σ(wiht1+uixt+vi)f=σ(wfht1+ufxt+vf)o=σ(woht1+uoxt+vo)g=tanhwght1+ugxtct=(fct1)+(gi)ht=tanh(ct)o
其中,iiifffooo 分别表示输入门 (input gate)、遗忘门 (forget gate) 和输出门 (output gate)。它们使用相同的方程计算,但使用不同的参数矩阵 (wiw_iwiuiu_iuiwfw_fwfufu_fufwow_owouou_ouo)。sigmoid 函数将这些门的输出限制在 01 之间,因此生成的输出向量可以与另一个向量逐元素相乘,以定义第二个向量能通过第一个向量的程度。
遗忘门定义了之前隐藏状态 ht−1h_t-1ht1 通过的量,输入门定义了新计算的当前输入 xtx_txt 的状态通过的量,而输出门定义了将内部状态暴露给下一层的量。内部隐藏状态 ggg 是基于当前输入 xtx_txt 和前一个隐藏状态 ht−1h_{t-1}ht1 计算的。需要注意的是,ggg 的方程与 SimpleRNN 的方程相同,但在 LSTM 中会通过输入门 iii 的输出进行调节。
给定 iiifffoooggg,可以计算时间刻 ttt 的单元状态 ctc_tct,方法是将时刻 t−1t-1t1 处的单元状态 ct−1c_{t-1}ct1 乘以遗忘门 fff 的值,再加上状态 ggg 乘以输入门 iii 的值。这本质上是将前一个单元状态与新输入结合起来的方法,将遗忘门设为 0 可以忽略旧状态,将输入门设为 0 可以忽略新计算出的状态。最后,时刻 ttt 处的隐藏状态 hth_tht 计算为时刻 ttt 处的单元状态 ctc_tct 乘以输出门 ooo
然而,需要注意的是,即使是像 LSTM 这样的高级 RNN 变种,在捕捉序列数据中的极长距离依赖关系时也会遇到困难。我们将在下一节深入讨论这些挑战,并提供解决方案,继续探讨有效处理和生成序列数据的复杂模型。

3. 自然语言处理基础

深度学习模型,包括 LSTM 模型和 Transformer 模型,无法直接处理原始文本,因为它们是为处理数值数据而设计的,数据通常以向量或矩阵的形式存在。神经网络的处理和学习能力基于数学运算,如加法、乘法和激活函数。因此,首先需要将文本分解为更小、更易于管理的元素,这些元素称为 token (词元),token 可以是单个字符、单词或子词单元。
NLP 任务中的下一关键步骤是将这些词元转换为数值表示,这种转换是将它们输入深度神经网络的必要步骤,也是训练模型的基本部分。
在本节中,我们将讨论不同的分词方法,并分析它们的优缺点。此外,还将了解将词元转换为向量表示的过程,这一方法称为词嵌入 (word embedding),这种技术对于以深度学习模型能够有效利用的格式捕捉语言的意义至关重要。

3.1 分词

分词是将文本划分为较小部分的过程,这些部分被称为 token,可以是单词、字符、符号或其他有意义的单位。分词的主要目的是简化文本数据的分析和处理过程。
总体而言,有三种常见的分词方法。第一种是字符分词,即将文本划分为其组成字符。这种方法通常用于形态学结构复杂的语言,如土耳其语或芬兰语,在这些语言中,单词的意义可能随着字符的微小变化而发生显著变化。以英语短语 “It is unbelievably good!” 为例,可以分解成如下的单个字符:['I', 't', ' ', 'i', 's', ' ', 'u', 'n', 'b', 'e', 'l', 'i', 'e', 'v', 'a', 'b', 'l', 'y', ' ', 'g', 'o', 'o', 'd', '!']。字符分词的一个关键优点是词元的数量较少,能够显著减少深度学习模型中的参数,从而加快训练速度,提高效率。但主要的缺点是单个字符通常没有明确的语义,使得机器学习模型难以从字符序列中提取有意义的信息:

text="It is unbelievably good!"
tokens=list(text)
print(tokens)

第二种方法是单词分词,即将文本拆分成单个的单词和标点符号。通常用于唯一单词数量不是特别大的情况。例如,短语 “It is unbelievably good!” 会被分解成五个词元:['It', 'is', 'unbelievably', 'good', '!']。这种方法的主要优点是每个单词本身就携带语义信息,使得模型更容易理解文本。然而,缺点在于独特词元的数量显著增加,进而增加了深度学习模型中的参数数量,可能导致训练过程变慢,效率降低:

text="It is unbelievably good!"
text=text.replace("!"," !")
tokens=text.split(" ")
print(tokens)

第三种方法是子词分词。这种方法是 NLP 中的一个关键概念,它将文本分解为更小、更有意义的组成部分,称为子词。例如,短语 “It is unbelievably good!” 将被分解为 ['It', 'is', 'un', 'believ', 'ably', 'good', '!'] 等词元。包含 GPT 等在内的多数先进的语言模型,都使用子词分词。
子词分词在传统分词技术之间找到了平衡,传统的分词方法通常将文本分割为单独的单词或字符。基于单词的分词虽然可以捕捉更多的意义,但会导致词汇表非常庞大,而基于字符的分词会生成较小的词汇表,但每个词元的语义价值较低。子词分词通过将常用词保留为完整词元,同时将不常见或复杂的词分解为子词单元,有效缓解了这些问题,这种技术特别适用于词汇量大的语言或词形变化丰富的语言。通过采用子词分词,词汇表规模大幅缩小,从而提高了语言处理任务的效率和效果,尤其是在处理多样化的语言结构时。

3.2 词嵌入

词嵌入是一种将词元转化为紧凑向量表示的方法,能够捕捉语义信息及其相互关系。词嵌入技术在 NLP 中至关重要,尤其是因为深度神经网络(如 LSTMTransformer 等模型)需要数值输入。
传统上,词元会通过独热编码 (one-hot encoding) 转换为数字,然后输入到 NLP 模型中。在独热编码中,每个词元由一个向量表示,其中只有一个元素是 ‘1’,其余元素是 ‘0’。例如,小说《安娜·卡列尼娜》文本中有 12,778 个独特的单词词元,每个词元会表示为一个 12,778 维的向量。因此,短语 “happy families are all alike” 会表示为一个 5 × 12,778 的矩阵,其中 5 表示词元的数量。然而,由于维度过大,这种表示方式效率极低,导致参数数量大幅增加,从而可能影响训练速度和效率。
LSTMTransformer 和其他先进的 NLP 模型通过词嵌入解决了这一问题。词嵌入使用连续的、低维的向量(例如 128 维向量)取代了庞大的独热向量。因此,短语 “happy families are all alike” 经过词嵌入后,会被表示为一个更紧凑的 5 × 128 的矩阵。这种简化的表示大幅度减少了模型的复杂性,提高了训练效率。
词嵌入不仅通过将词元压缩到低维空间来降低复杂性,还能有效捕捉上下文和词元之间细微的语义关系,这是独热编码等简单表示方法无法实现的。这是因为,在独热编码中,所有词元在向量空间中的距离是相同的,无法体现语义相似性;而在词嵌入中,语义相似的标记在嵌入空间中的向量通常彼此接近。词嵌入是通过从训练数据的文本中学习得到的,由此产生的向量能够捕捉上下文信息,在相似上下文中的词元即使没有显式关联,也会有相似的嵌入表示。

3.3 词嵌入在自然语言处理中的应用

词嵌入是自然语言处理中表示词元的一种强大方法,相比传统的独热编码,它在捕捉上下文和词元之间的语义关系方面具有显著优势。
独热编码将词元表示为稀疏向量,其维度等于词汇表的大小,每个词元由一个向量表示,其中除了与该词元对应的索引位置上的元素为 1 外,其余位置上的元素都是 0。相比之下,词嵌入将词元表示为低维的稠密向量(例如,使用 128 维向量或 256 维向量)。这种稠密表示更加高效,并且能够捕捉更多的信息。
具体而言,在独热编码中,所有词元在向量空间中的距离是相同的,无法体现词元之间的相似性。然而,在词嵌入中,语义相似的词元在嵌入空间中的向量距离较近。例如,“king” (国王)和 “queen” (女王)这两个词具有有相似的嵌入,反映它们之间的语义关系。
词嵌入是从训练数据的文本中学习得到的。嵌入过程利用词元出现的上下文来学习其嵌入表示,这意味着生成的向量能够捕捉上下文信息。出现在相似上下文中的词元,即使它们没有显式关联,也会有相似的嵌入表示。
总体而言,词嵌入提供了更细致和高效的词元表示,能够捕捉语义关系和上下文信息,使其相比独热编码更适用于 NLP 任务。
PyTorch 框架中,词嵌入是通过将索引传递给线性层来实现的,线性层会将索引压缩到低维空间。也就是说,将索引传递给 nn.Embedding() 层时,它会在嵌入矩阵中查找对应的行,并返回该索引的嵌入向量,从而避免了创建可能非常庞大的独热向量。嵌入层的权重并不是预定义的,而是在训练过程中学习得到的,这个学习过程使得模型能够根据训练数据来优化对词元语义的理解,从而在神经网络中得到更细致和上下文感知的语言表示。这种方法显著提升了模型高效处理和解释语言数据的能力。

小结

  • 循环神经网络 (Recurrent Neural Network, RNN)是一种专门设计用于识别序列数据(如文本、音乐或股票价格)模式的人工神经网络。与传统神经网络独立处理输入不同,RNN 具有循环结构,使得信息能够持久传递。长短期记忆 (Long Short-Term Memory, LSTM) 网络是 RNN 的改进版本
  • 三种主要的分词方法包括:字符分词,将文本分解为单个字符;单词分词,将文本拆分为单个单词;子词分词,将单词分解为较小的、有意义的子词单元
  • 词嵌入是一种将单词转换为紧凑向量表示的方法,能够捕捉单词的语义信息及相互关系。这一技术在 NLP 中至关重要,因为在深度神经网络模型(如 LSTMTransformer )需要数值输入

系列链接

PyTorch生成式人工智能实战:从零打造创意引擎
PyTorch生成式人工智能(1)——神经网络与模型训练过程详解
PyTorch生成式人工智能(2)——PyTorch基础
PyTorch生成式人工智能(3)——使用PyTorch构建神经网络
PyTorch生成式人工智能(4)——卷积神经网络详解
PyTorch生成式人工智能(5)——分类任务详解
PyTorch生成式人工智能(6)——生成模型(Generative Model)详解
PyTorch生成式人工智能(7)——生成对抗网络实践详解
PyTorch生成式人工智能(8)——深度卷积生成对抗网络
PyTorch生成式人工智能(9)——Pix2Pix详解与实现
PyTorch生成式人工智能(10)——CyclelGAN详解与实现
PyTorch生成式人工智能(11)——神经风格迁移
PyTorch生成式人工智能(12)——StyleGAN详解与实现
PyTorch生成式人工智能(13)——WGAN详解与实现
PyTorch生成式人工智能(14)——条件生成对抗网络(conditional GAN,cGAN)
PyTorch生成式人工智能(15)——自注意力生成对抗网络(Self-Attention GAN, SAGAN)
PyTorch生成式人工智能(16)——自编码器(AutoEncoder)详解
PyTorch生成式人工智能(17)——变分自编码器详解与实现

http://www.lryc.cn/news/591028.html

相关文章:

  • 【Linux基础知识系列】第五十一篇 - Linux文件命名规范与格式
  • Mac 安装及使用sdkman指南
  • Java 大视界 -- Java 大数据在智能交通智能公交站台乘客流量预测与服务优化中的应用(349)
  • Flask+LayUI开发手记(十一):选项集合的数据库扩展类
  • Java 集合框架详解:Collection 接口全解析,从基础到实战
  • 【LeetCode 热题 100】108. 将有序数组转换为二叉搜索树
  • 【Redis 】看门狗:分布式锁的自动续期
  • 如何用Kaggle免费GPU
  • [yotroy.cool] Git 历史迁移笔记:将 Git 项目嵌入另一个仓库子目录中(保留提交记录)
  • 语雀编辑器内双击回车插入当前时间js脚本
  • 【WRFDA第六期】WRFDA 输出文件详述
  • R语言基础| 基本图形绘制(条形图、堆积图、分组图、填充条形图、均值条形图)
  • Spring AI之Prompt开发
  • Web攻防-PHP反序列化Phar文件类CLI框架类PHPGGC生成器TPYiiLaravel
  • Cursor开发步骤
  • 【C++指南】C++ list容器完全解读(四):反向迭代器的巧妙实现
  • 113:路径总和 II
  • Java学习--JVM(2)
  • 基于FPGA的IIC控制EEPROM读写(2)
  • AI算法之图像识别与分类
  • 深入理解Java中的Collections.max()方法
  • 贪心算法(排序)
  • GLM(General Language Model,通用语言模型)
  • 2020717零碎写写
  • 学习OpenCV---显示图片
  • Java集合框架中List常见问题
  • Python爬虫实战:Requests与Selenium详解
  • ESLint 完整功能介绍和完整使用示例演示
  • 产品经理如何描述用户故事
  • Rocky Linux 9 源码包安装php7