当前位置: 首页 > news >正文

基于Python的自然语言处理系列(2):Word2Vec(负采样)

        在本系列的第二篇文章中,我们将继续探讨Word2Vec模型,这次重点介绍负采样(Negative Sampling)技术。负采样是一种优化Skip-gram模型训练效率的技术,它能在大规模语料库中显著减少计算复杂度。接下来,我们将通过详细的代码实现和理论讲解,帮助你理解负采样的工作原理及其在Word2Vec中的应用。

1. Word2Vec(负采样)原理

1.1 负采样的背景

        在Word2Vec的Skip-gram模型中,我们的目标是通过给定的中心词预测其上下文词。然而,当词汇表非常大时,计算所有词的预测概率会变得非常耗时。为了解决这个问题,负采样技术被引入。

1.2 负采样的工作原理

        负采样通过从词汇表中随机选择一些词作为负样本来简化训练过程。具体来说,除了正样本(即真实的上下文词),我们还为每个正样本选择若干个负样本。模型的目标是最大化正样本的预测概率,同时最小化负样本的预测概率。这样,训练过程只需要考虑部分词汇,从而减少了计算量。

2. Word2Vec(负采样)实现

        我们将通过以下步骤来实现带有负采样的Word2Vec模型:

2.1 定义简单数据集

        首先,我们定义一个简单的语料库来演示负采样的应用。

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F# 定义语料库
corpus = ["apple banana fruit", "banana apple fruit", "banana fruit apple","dog cat animal", "cat animal dog", "cat dog animal"]corpus = [sent.split(" ") for sent in corpus]
print(corpus)

2.2 数据预处理

        获取词序列和唯一词汇,并进行数值化处理。

# 获取词汇表
flatten = lambda l: [item for sublist in l for item in sublist]
vocab = list(set(flatten(corpus)))
print(vocab)# 数值化
word2index = {w: i for i, w in enumerate(vocab)}
print(word2index)# 词汇表大小
voc_size = len(vocab)
print(voc_size)# 添加UNK标记
vocab.append('<UNK>')
word2index['<UNK>'] = 0
index2word = {v: k for k, v in word2index.items()}

2.3 准备训练数据

        定义一个函数用于生成Skip-gram模型的训练数据。

def random_batch(batch_size, word_sequence):skip_grams = []for sequence in word_sequence:for i, word in enumerate(sequence):context = [sequence[j] for j in range(max(0, i - 1), min(len(sequence), i + 2)) if j != i]for ctx_word in context:skip_grams.append((word, ctx_word))return skip_grams

2.4 负采样

        实现负采样的训练过程。

class Word2Vec(nn.Module):def __init__(self, vocab_size, embedding_dim):super(Word2Vec, self).__init__()self.in_embed = nn.Embedding(vocab_size, embedding_dim)self.out_embed = nn.Embedding(vocab_size, embedding_dim)self.in_embed.weight.data.uniform_(-1, 1)self.out_embed.weight.data.uniform_(-1, 1)def forward(self, center_word, context_word):in_embeds = self.in_embed(center_word)out_embeds = self.out_embed(context_word)scores = torch.matmul(in_embeds, out_embeds.t())return scores# Initialize model
embedding_dim = 10
model = Word2Vec(voc_size, embedding_dim)
optimizer = optim.SGD(model.parameters(), lr=0.01)

2.5 训练模型

        进行模型训练,并应用负采样技术来优化模型。

def train_word2vec(model, skip_grams, epochs=10):for epoch in range(epochs):total_loss = 0for center, context in skip_grams:center_idx = torch.tensor([word2index[center]], dtype=torch.long)context_idx = torch.tensor([word2index[context]], dtype=torch.long)optimizer.zero_grad()scores = model(center_idx, context_idx)target = torch.tensor([1], dtype=torch.float32)loss = F.binary_cross_entropy_with_logits(scores.squeeze(), target)loss.backward()optimizer.step()total_loss += loss.item()print(f'Epoch {epoch + 1}, Loss: {total_loss}')# Prepare skip-gram pairs
skip_grams = random_batch(10, corpus)
train_word2vec(model, skip_grams)

结语

        在本篇文章中,我们详细探讨了Word2Vec模型中的负采样技术,并通过代码实现展示了如何在Python中应用这一技术来优化Skip-gram模型。负采样通过减少计算量,提高了模型的训练效率,使得在大规模数据集上的训练变得可行。

        在下一篇文章中,我们将继续探讨另一种词向量表示方法——GloVe(Global Vectors for Word Representation)。敬请期待!

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

http://www.lryc.cn/news/433348.html

相关文章:

  • 每日一题|牛客竞赛|四舍五入|字符串+贪心+模拟
  • 大数据之Flink(六)
  • 设计模式学习[5]---装饰模式
  • 3.C_数据结构_栈
  • Debian11安装DolphinScheduler
  • C语言深度剖析--不定期更新的第五弹
  • python之事务
  • 文件加密软件都有哪些?推荐6款文件加密工具
  • Docker中的容器内部无法使用vi命令怎么办?
  • 【Linux系统编程】TCP实现--socket
  • 企业微信hook协议接口,聚合群聊客户管理工具开发
  • Selenium集成Sikuli基于图像识别的自动化测试
  • 【STM32实物】基于STM32设计的智能仓储管理系统(程序代码电路原理图实物图讲解视频设计文档等)——文末资料下载
  • libtool 中的 .la 文件说明
  • NLP-transformer学习:(6)dataset 加载与调用
  • 数据库系统 第43节 数据库复制
  • LabVIEW FIFO详解
  • 如何验证VMWare WorkStation的安装?
  • 论文阅读:AutoDIR Automatic All-in-One Image Restoration with Latent Diffusion
  • C++ | Leetcode C++题解之第392题判断子序列
  • 操作系统概述(三、虚拟化)
  • 基于ARM芯片与OpenCV的工业分拣机器人项目设计与实现流程详解
  • UNITY UI简易反向遮罩
  • 牛客周赛59(A,B,C,D,E二维循环移位,F范德蒙德卷积)
  • C语言中的隐型计算
  • ffmpeg面向对象-待定
  • 大厂嵌入式数字信号处理器(DSP)面试题及参考答案
  • GC-分代收集器
  • C++从入门到起飞之——priority_queue(优先级队列) 全方位剖析!
  • [数据集][目标检测]西红柿缺陷检测数据集VOC+YOLO格式17318张3类别