当前位置: 首页 > news >正文

深度学习中的三种Embedding技术详解

提纲

    • 背景介绍
    • 特征类型与Embedding方法
    • 1. ID类特征的Embedding处理
      • 1.1 标准Embedding方法
      • 1.2 IdHashEmbedding方法
    • 2. 数值型特征的Embedding处理
      • 2.1 RawEmbedding方法
    • 三种Embedding方法对比总结
    • 实践建议
    • 总结

背景介绍

在深度学习领域,Embedding(嵌入)技术是一种将高维稀疏数据转换为低维稠密向量表示的核心方法。它在推荐系统、自然语言处理、图像识别等多个领域中发挥着重要作用。

Embedding的主要目的是:

  1. 将离散的类别特征(如用户ID、商品类别)转换为连续的向量表示
  2. 在低维空间中捕获特征间的语义关系
  3. 提高模型的表达能力和泛化性能

正确使用Embedding技术对于模型性能至关重要。本文将详细解析三种不同的Embedding方法及其应用场景。

特征类型与Embedding方法

在实际项目中,我们通常会遇到两种主要的特征类型:

  1. ID类特征:如用户ID、骑手ID、商品类别等离散的标识符
  2. 数值型特征:如用户年龄、订单金额等连续数值

针对不同类型的特征,我们需要采用不同的Embedding策略来处理。

1. ID类特征的Embedding处理

1.1 标准Embedding方法

适用场景:适用于中低基数类别特征,如商品类别、用户性别、服务类型、地区编码等。

工作原理

  • 为每个唯一的特征值维护一个独立的向量表示
  • 直接使用特征值作为索引查找Embedding table中对应的Embedding向量
  • Embedding表的大小(行数)需要大于等于特征最大的特征值

Example
假设我们有一个item_category特征,有50个不同的类别,我们希望将其映射到16维的向量空间:

import torch
import torch.nn as nn
# 特征配置
feature_config = {'item_category': {'hash_bucket_size': 50, 'embedding_dimension': 16}
}# 创建Embedding层(table),维度为50X16的矩阵
embedding_layer = nn.Embedding(num_embeddings=50,     # 类别数量embedding_dim=16       # 目标维度
)# 输入样本:两个订单的类别分别为5和23
input_tensor = torch.tensor([5, 23])  # shape: [2]# Embedding过程就时vlookup过程,从embedding_layer 矩阵中选取第5行和第23行数据向量
embedded = embedding_layer(input_tensor)  # shape: [2, 16]print(f"输入形状: {input_tensor.shape}")    # [2]
print(f"输出形状: {embedded.shape}")       # [2, 16]

初始化方式

- PyTorch中nn.Embedding的默认初始化
- 权重矩阵形状为[50, 16],每个元素从标准正态分布N(0,1)中采样,weight[i, j] ~ N(0, 1) for all i in [0, 49] and j in [0, 15]

优缺点分析

  • 优点:无哈希冲突,每个特征值有独立表示,易于解释
  • 缺点:对于高基数特征内存消耗巨大,无法处理训练时未见过的特征值大于embedding table的情况,item_category例中,假设有一个item_category的数值为51,则运行时会报突破边界错,因为从embedding table中找不到第51行。

1.2 IdHashEmbedding方法

适用场景:适用于高基数类别特征,如用户ID、商品ID、骑手ID等具有大量唯一值的特征。

为什么需要IdHashEmbedding?
当面对高基数特征时,标准Embedding方法会遇到以下问题:

  1. 内存爆炸:如果用户ID有千万/亿级别,直接Embedding需要巨大的内存
  2. 未见特征值:遇到特征值大于embedding table行数时报错
  3. 训练效率低:大量参数需要更新,训练速度慢

工作原理

  • 通过哈希函数将原始特征值映射到固定大小的桶中
  • 使用哈希后的值作为索引从Embedding表中查找向量
  • 内存使用固定,不受原始特征基数影响

Example

class IdHashEmbedding(nn.Module):...def _build_embeddings(self):for feat_name, config in self.feature_config.items():self.embeddings[feat_name] = nn.Embedding(num_embeddings=config['hash_bucket_size'],embedding_dim=config['embedding_dimension'])def forward(self, features):embed_list = []for feat_name, config in self.feature_config.items():id_x = features[feat_name]hash_x = id_x.cpu().apply_(lambda x: hash_bucket(x, config['hash_bucket_size'])).long().to(self.device)     embed_list.append(self.embeddings[feat_name](hash_x))embeds = torch.cat(embed_list, dim=1)return embeds# 用户ID特征,原始ID可能达到千万级别
user_ids = torch.tensor([123456, 789012])# 使用IdHashEmbedding压缩到1000个桶中
id_hash_embedding = IdHashEmbedding({'user_id': {'hash_bucket_size': 1000, 'embedding_dimension': 32}
}, device)# 计算Id特征对应的IdHash值 
hash_bucket(123456, 1000) = 123456 % 1000 = 456, 
hash_bucket(789012, 1000) = 789012 % 1000 = 12
# 实际查找的是索引456和12,而非原始ID
embedded = id_hash_embedding({'user_id': user_ids})

优缺点分析

  • 优点:内存高效,可以处理任意基数的特征,能处理未见过的特征值
  • 缺点:可能产生哈希冲突,不同特征值可能映射到同一向量,比如对用户223456,其IdHash值也是456,和用户ID123456对应的Id Hash值相同,在模型训练时会当作相同的特征

2. 数值型特征的Embedding处理

2.1 RawEmbedding方法

适用场景:适用于连续数值特征,如用户年龄、注册时长、订单金额等。

工作原理

  • 使用线性变换将原始数值映射到目标维度
  • 相比查找表方式,更适合处理连续值

Example

class RawEmbedding(nn.Module):def __init__(self, input_dim, output_dim):super().__init__()self.input_dim = input_dimself.linear = nn.Linear(input_dim, output_dim)def forward(self, x):return self.linear(x)# 用户年龄特征
user_ages = torch.tensor([[25.0], [35.0]])  # shape: [2, 1]# 使用RawEmbedding将1维年龄映射到8维向量
raw_embedding = RawEmbedding(input_dim=1, output_dim=8)
embedded = raw_embedding(user_ages)  # shape: [2, 8]print(f"输入形状: {user_ages.shape}")   # [2, 1]
print(f"输出形状: {embedded.shape}")     # [2, 8]

优缺点分析

  • 优点:适合处理连续特征,计算简单高效,参数量少
  • 缺点:不适合处理未编码的类别特征,表达能力有限

三种Embedding方法对比总结

特征类型Embedding方法适用场景内存消耗处理未见特征哈希冲突表达能力计算复杂度
ID类特征Embedding中低基数类别特征(商品类别、服务类型等)
ID类特征IdHashEmbedding高基数类别特征(用户ID、骑手ID等)
数值特征RawEmbedding连续数值特征(年龄、服务时长等)

实践建议

在骑手等级转换预测项目中,我们建议:

  1. 高基数ID特征(如骑手ID、用户ID)使用IdHashEmbedding
  2. 中低基数ID特征(如服务类型、地区编码)使用Embedding
  3. 连续数值特征(如年龄、注册时长)使用RawEmbedding

合理选择和使用Embedding方法,可以显著提升模型性能,同时控制内存消耗。

总结

Embedding技术是深度学习模型中不可或缺的组成部分。通过合理选择IdHashEmbedding、Embedding和RawEmbedding,我们可以有效处理不同类型的特征,构建高性能的预测模型。理解每种方法的原理和适用场景,对于构建成功的机器学习系统具有重要意义。

在实际应用中,我们需要根据特征的类型、基数大小、内存限制以及业务需求来选择合适的Embedding方法,这样才能在模型效果和计算效率之间取得最佳平衡。

http://www.lryc.cn/news/609262.html

相关文章:

  • Maven - 依赖的生命周期详解
  • MySQL深度理解-MySQL锁机制
  • vllm0.8.5:思维链(Chain-of-Thought, CoT)微调模型的输出结果包括</think>,提供一种关闭思考过程的方法
  • Remix框架:高性能React全栈开发实战
  • 音视频学习(四十九):音频有损压缩
  • 数据结构-双链表
  • 网络通信与Socket套接字详解
  • Flink程序关键一步:触发环境执行
  • 13-day10生成式任务
  • 全面解析 BGE Embedding 模型:训练方式、模型系列与实战用法
  • python批量gif图片转jpg
  • C++ vector容器详解:从基础使用到高效实践
  • 【GitHub探索】Agent开发平台CozeStudio开源版本踩坑体验
  • Obsidian结合CI/CD实现自动发布
  • 【设计模式】4.装饰器模式
  • 第二节 YOLOv5参数
  • 电商系统定制开发流程:ZKmall开源商城需求分析到上线全程可控
  • Linux命令基础(上)
  • 关于Web前端安全防御之内容安全策略(CSP)
  • 第八章:进入Redis的SET的核心
  • 基于 Spring Boot + Vue 实现人脸采集功能全流程
  • spring-ai-alibaba 之 graph 槽点
  • 无人机集群协同三维路径规划,采用冠豪猪优化器(Crested Porcupine Optimizer, CPO)实现,Matlab代码
  • 基于BiLSTM+CRF实现NER
  • JavaWeb学习------SpringCloud入门
  • 最小半径覆盖问题【C++解法+二分+扫描线】
  • 研报复现|史蒂夫·路佛价值选股法则
  • [硬件电路-138]:模拟电路 - 什么是正电源?什么是负电源?集成运放为什么有VCC+和VCC-
  • 【RH124 问答题】第 8 章 监控和管理 Linux 进程
  • MySQL学习之MVCC多版本并发控制