当前位置: 首页 > news >正文

pytorch基于GloVe实现的词嵌入

PyTorch 实现 GloVe(Global Vectors for Word Representation) 的完整代码,使用 中文语料 进行训练,包括 共现矩阵构建、模型定义、训练和测试


 1. GloVe 介绍

基于词的共现信息(不像 Word2Vec 使用滑动窗口预测)
 适合较大规模的数据(比 Word2Vec 更稳定)
学习出的词向量能捕捉语义信息(如类比关系)

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import jieba
from collections import Counter
from scipy.sparse import coo_matrix# ========== 1. 数据预处理 ==========
corpus = ["我们 喜欢 深度 学习","自然 语言 处理 是 有趣 的","人工智能 改变 了 世界","深度 学习 是 人工智能 的 重要 组成部分"
]# 分词
tokenized_corpus = [list(jieba.cut(sentence)) for sentence in corpus]
vocab = set(word for sentence in tokenized_corpus for word in sentence)
word2idx = {word: idx for idx, word in enumerate(vocab)}
idx2word = {idx: word for word, idx in word2idx.items()}# 计算共现矩阵
window_size = 2
co_occurrence = Counter()for sentence in tokenized_corpus:indices = [word2idx[word] for word in sentence]for center_idx in range(len(indices)):center_word = indices[center_idx]for offset in range(-window_size, window_size + 1):context_idx = center_idx + offsetif 0 <= context_idx < len(indices) and context_idx != center_idx:context_word = indices[context_idx]co_occurrence[(center_word, context_word)] += 1# 转换为稀疏矩阵
rows, cols, values = zip(*[(c[0], c[1], v) for c, v in co_occurrence.items()])
X = coo_matrix((values, (rows, cols)), shape=(len(vocab), len(vocab)))# ========== 2. 定义 GloVe 模型 ==========
class GloVe(nn.Module):def __init__(self, vocab_size, embedding_dim):super(GloVe, self).__init__()self.w_embeddings = nn.Embedding(vocab_size, embedding_dim)  # 中心词嵌入self.c_embeddings = nn.Embedding(vocab_size, embedding_dim)  # 上下文词嵌入self.w_bias = nn.Embedding(vocab_size, 1)  # 中心词偏置self.c_bias = nn.Embedding(vocab_size, 1)  # 上下文词偏置nn.init.xavier_uniform_(self.w_embeddings.weight)nn.init.xavier_uniform_(self.c_embeddings.weight)def forward(self, center, context, co_occur):w_emb = self.w_embeddings(center)c_emb = self.c_embeddings(context)w_bias = self.w_bias(center).squeeze()c_bias = self.c_bias(context).squeeze()dot_product = (w_emb * c_emb).sum(dim=1)loss = (dot_product + w_bias + c_bias - torch.log(co_occur + 1e-8)) ** 2return loss.mean()# 初始化模型
embedding_dim = 10
model = GloVe(len(vocab), embedding_dim)# ========== 3. 训练 GloVe ==========
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 100# 转换数据
co_occurrence_tensor = torch.tensor(X.data, dtype=torch.float)
pairs = list(zip(X.row, X.col, co_occurrence_tensor))for epoch in range(num_epochs):total_loss = 0np.random.shuffle(pairs)for center, context, co_occur in pairs:optimizer.zero_grad()loss = model(torch.tensor([center], dtype=torch.long),torch.tensor([context], dtype=torch.long),torch.tensor([co_occur], dtype=torch.float)  # 修正数据类型)loss.backward()optimizer.step()total_loss += loss.item()if (epoch + 1) % 10 == 0:print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {total_loss:.4f}")# ========== 4. 获取词向量 ==========
word_vectors = model.w_embeddings.weight.data.numpy()# ========== 5. 计算相似度 ==========
def most_similar(word, top_n=3):if word not in word2idx:return "单词不在词汇表中"word_vec = word_vectors[word2idx[word]].reshape(1, -1)similarities = np.dot(word_vectors, word_vec.T).squeeze()similar_idx = similarities.argsort()[::-1][1:top_n + 1]return [(idx2word[idx], similarities[idx]) for idx in similar_idx]# 测试
test_words = ["深度", "学习", "人工智能"]
for word in test_words:print(f"【{word}】的相似单词:", most_similar(word))

数据预处理

  • 分词(使用 jieba.cut()
  • 构建共现矩阵(计算窗口内的单词共现频率)
  • 使用稀疏矩阵存储(提高计算效率)

GloVe 模型

  • Embedding 训练词向量(中心词和上下文词分开)
  • Bias 变量 用于调整预测值
  • 损失函数 最小化 log(共现次数) 与词向量点积的差值

 计算词向量相似度

  • 使用 cosine similarity
  • 找出 top_n 最相似的单词
http://www.lryc.cn/news/531037.html

相关文章:

  • SpringCloud篇 微服务架构
  • 背包问题和单调栈
  • Java | CompletableFuture详解
  • 【背包问题】二维费用的背包问题
  • Golang 并发机制-5:详解syn包同步原语
  • 实验六 项目二 简易信号发生器的设计与实现 (HEU)
  • 如何用微信小程序写春联
  • LabVIEW无人机航线控制系统
  • C++哈希表深度解析:从原理到实现,全面掌握高效键值对存储
  • Vue.js组件开发-实现字母向上浮动
  • 自研有限元软件与ANSYS精度对比-Bar2D2Node二维杆单元模型-四连杆实例
  • 04树 + 堆 + 优先队列 + 图(D1_树(D11_伸展树))
  • c语言练习题【数据类型、递归、双向链表快速排序】
  • SliverAppBar的功能和用法
  • 五、定时器实现呼吸灯
  • Elasticsearch的索引生命周期管理
  • 【大模型理论篇】最近大火的DeepSeek-R1初探系列1
  • 【数据结构】(4) 线性表 List
  • 【C++ STL】vector容器详解:从入门到精通
  • OpenAI推出Deep Research带给我们怎样的启示
  • 洛谷[USACO08DEC] Patting Heads S
  • CSS 溢出内容处理:从基础到实战
  • Spring Boot项目如何使用MyBatis实现分页查询
  • 飞行汽车中的无刷外转子电机、人形机器人中的无框力矩电机技术解析与应用
  • FreeRTOS学习 --- 队列集
  • 【R语言】R语言安装包的相关操作
  • 15.[前端开发]Day15-HTML+CSS阶段练习(网易云音乐四)
  • 【基于SprintBoot+Mybatis+Mysql】电脑商城项目之用户登录
  • 测试方案和测试计划相同点和不同点
  • c++提取矩形区域图像的梯度并拟合直线