当前位置: 首页 > news >正文

Youtube双塔模型

1. 引言

在大规模推荐系统中,如何从海量候选物品中高效检索出用户可能感兴趣的物品是一个关键问题。传统的矩阵分解方法在处理稀疏数据和长尾分布时面临挑战。本文介绍了一种基于双塔神经网络的建模框架,通过采样偏差校正技术提升推荐质量,并成功应用于YouTube视频推荐系统。

2. 建模框架

论文的亮点不在于提出了新的架构,而是针对训练时负采样的处理。

  • 模型架构
    在这里插入图片描述

2.1 问题定义

给定查询(用户和上下文)和物品的特征表示:

  • 查询特征: x i ∈ X x_i \in \mathcal{X} xiX
  • 物品特征: y j ∈ Y y_j \in \mathcal{Y} yjY

目标是通过双塔神经网络学习嵌入函数:
u : X × R d → R k v : Y × R d → R k \begin{aligned} u &: \mathcal{X} \times \mathbb{R}^d \rightarrow \mathbb{R}^k \\ v &: \mathcal{Y} \times \mathbb{R}^d \rightarrow \mathbb{R}^k \end{aligned} uv:X×RdRk:Y×RdRk
其中模型参数 θ ∈ R d \theta \in \mathbb{R}^d θRd,输出为内积得分:
s ( x , y ) = ⟨ u ( x , θ ) , v ( y , θ ) ⟩ s(x,y) = \langle u(x,\theta), v(y,\theta) \rangle s(x,y)=u(x,θ),v(y,θ)⟩

2.2 损失函数

将推荐视为带连续奖励的多分类问题,采用softmax概率:
P ( y ∣ x ; θ ) = e s ( x , y ) ∑ j ∈ [ M ] e s ( x , y j ) \mathcal{P}(y|x;\theta) = \frac{e^{s(x,y)}}{\sum_{j\in[M]} e^{s(x,y_j)}} P(yx;θ)=j[M]es(x,yj)es(x,y)

加权对数似然损失:
L T ( θ ) = − 1 T ∑ i ∈ [ T ] r i ⋅ log ⁡ ( P ( y i ∣ x i ; θ ) ) L_T(\theta) = -\frac{1}{T}\sum_{i\in[T]} r_i \cdot \log(\mathcal{P}(y_i|x_i;\theta)) LT(θ)=T1i[T]rilog(P(yixi;θ))

2.3 批处理softmax

当物品集 M M M极大时,计算完整softmax不可行。一种常见的方法是使用一批项目的一个子集,特别是对于来自同一批次的所有查询,使用批次内的项目作为负样本。给定一个包含 B 对 ( x i , y i , r i ) i = 1 B {(x_i,y_i,r_i)}_{i=1}^{B} (xi,yi,ri)i=1B 的小批次,批次 softmax 为:

P B ( y i ∣ x i ; θ ) = e s ( x i , y i ) ∑ j ∈ [ B ] e s ( x i , y j ) \mathcal{P}_B(y_i|x_i;\theta) = \frac{e^{s(x_i,y_i)}}{\sum_{j\in[B]} e^{s(x_i,y_j)}} PB(yixi;θ)=j[B]es(xi,yj)es(xi,yi)

2.4 采样偏差校正

流行物品因高频出现在批次中会被过度惩罚,引入对数校正项:
s c ( x i , y j ) = s ( x i , y j ) − log ⁡ ( p j ) s^c(x_i,y_j) = s(x_i,y_j) - \log(p_j) sc(xi,yj)=s(xi,yj)log(pj)
其中 p j p_j pj为物品 j j j的采样概率。

校正后的损失函数:
L B ( θ ) = − 1 B ∑ i ∈ [ B ] r i ⋅ log ⁡ ( e s c ( x i , y i ) e s c ( x i , y i ) + ∑ j ≠ i e s c ( x i , y j ) ) L_B(\theta) = -\frac{1}{B}\sum_{i\in[B]} r_i \cdot \log\left(\frac{e^{s^c(x_i,y_i)}}{e^{s^c(x_i,y_i)} + \sum_{j\neq i}e^{s^c(x_i,y_j)}}\right) LB(θ)=B1i[B]rilog(esc(xi,yi)+j=iesc(xi,yj)esc(xi,yi))

3. 流式频率估计算法

3.1 核心思想

通过全局步长(global step)估计物品采样间隔 δ \delta δ,进而计算采样概率:
p = 1 δ p = \frac{1}{\delta} p=δ1

3.2 算法实现

在这里插入图片描述

数据结构:

  • 哈希数组 A A A:记录物品最后一次出现的步长
  • 哈希数组 B B B:估计物品的采样间隔 δ \delta δ

更新规则(SGD形式):
B [ h ( y ) ] ← ( 1 − α ) ⋅ B [ h ( y ) ] + α ⋅ ( t − A [ h ( y ) ] ) B[h(y)] \leftarrow (1-\alpha) \cdot B[h(y)] + \alpha \cdot (t - A[h(y)]) B[h(y)](1α)B[h(y)]+α(tA[h(y)])

数学性质:

  • 偏差分析:
    E ( δ t ) − δ = ( 1 − α ) t δ 0 − ( 1 − α ) t − 1 δ \mathbb{E}(\delta_t) - \delta = (1-\alpha)^t\delta_0 - (1-\alpha)^{t-1}\delta E(δt)δ=(1α)tδ0(1α)t1δ

  • 方差上界:
    E [ ( δ t − E [ δ t ] ) 2 ] ≤ ( 1 − α ) 2 t ( δ 0 − δ ) 2 + α E [ ( Δ 1 − δ ) 2 ] \mathbb{E}[(\delta_t - \mathbb{E}[\delta_t])^2] \leq (1-\alpha)^{2t}(\delta_0 - \delta)^2 + \alpha\mathbb{E}[(\Delta_1 - \delta)^2] E[(δtE[δt])2](1α)2t(δ0δ)2+αE[(Δ1δ)2]

  • 各学习率误差的结果
    在这里插入图片描述

3.3 多哈希优化

使用 m m m个哈希函数减少碰撞误差:
p ^ = max ⁡ 1 ≤ i ≤ m 1 B i [ h i ( y ) ] \hat{p} = \max_{1\leq i\leq m} \frac{1}{B_i[h_i(y)]} p^=1immaxBi[hi(y)]1

4. 在YouTube推荐系统中的应用

4.1 系统架构

  • 查询塔:融合用户观看历史和种子视频特征
  • 候选塔:处理候选视频内容特征
  • 共享特征嵌入提升训练效率

4.2 关键创新

  1. 流式训练:按天顺序消费数据,适应分布变化
  2. 哈希桶技术:处理新出现的内容ID
  3. 索引管道:量化哈希技术加速最近邻搜索

5.Trick

在相似度得分上添加温度参数 τ \tau τ
此外,为了使预测结果更加尖锐,我们在每个 logit 上添加了一个温度参数 τ。具体来说,我们使用以下公式计算查询和候选项目之间的相似度得分:
s ( x , y ) = ⟨ u ( x , θ ) , v ( y , θ ) ⟩ τ s(x,y)= \frac{⟨u(x,θ),v(y,θ)⟩}{\tau} s(x,y)=τu(x,θ),v(y,θ)⟩

Youtube实验结果
在这里插入图片描述

6. 结论

本文提出的采样偏差校正方法通过:

  1. 理论保证的无偏频率估计
  2. 适应动态变化的流式环境
  3. 可扩展的分布式实现

在十亿级物品的推荐场景中显著提升了检索质量,为大规模内容推荐提供了新的解决方案。

引用

Yi, X., Yang, J., Hong, L., Cheng, D. Z., Heldt, L., Kumthekar, A., Zhao, Z., Wei, L., & Chi, E. (2019). Sampling-bias-corrected neural modeling for large corpus item recommendations. Proceedings of the 13th ACM Conference on Recommender Systems, 269–277.

http://www.lryc.cn/news/576703.html

相关文章:

  • C++共享型智能指针std::shared_ptr使用介绍
  • cocos creator 3.8 - 精品源码 - 挪车超人(挪车消消乐)
  • Neo4j无法建立到 localhost:7474 服务器的连接出现404错误
  • Linux基本命令篇 —— less命令
  • springboot+Vue驾校管理系统
  • matplotlib 绘制水平柱状图
  • 基于LQR控制器的六自由度四旋翼无人机模型simulink建模与仿真
  • 使用deepseek制作“喝什么奶茶”随机抽签小网页
  • 我的世界模组开发进阶教程——机械动力的数据生成(2)
  • 【C++进阶】--- 继承
  • 基于WOA鲸鱼优化算法的圆柱体容器最大体积优化设计matlab仿真
  • 人大金仓数据库jdbc连接jar包kingbase8-8.6.0.jar驱动包最新版下载(不需要积分)
  • C++泛型编程2 - 类模板
  • C# 委托(为委托添加方法和从委托移除方法)
  • 13-StringBuilder类的使用
  • Linux内核网络协议栈深度解析:面向连接的INET套接字实现
  • 8. 【Vue实战--孢子记账--Web 版开发】-- 账户账本管理
  • Uni-App 小程序面试题高频问答汇总
  • 【Docker基础】Docker容器管理:docker top及其参数详解
  • Ubuntu 主机通过 `enp4s0` 向开发板共享网络的完整步骤
  • Flutter基础(控制器)
  • 广外计算机网络期末复习
  • 大模型之提示词工程入门——解锁与AI高效沟通的“钥匙”
  • WOE值:风险建模中的“证据权重”量化术——从似然比理论到FICO评分卡实践
  • python学习打卡day57
  • Python基础(吃洋葱小游戏)
  • 如何让ChatGPT模仿人类写作,降低AIGC率?
  • SpringBoot3.x整合Knife4j接口文档
  • cocos creator 3.8 - 精品源码 - 六边形消消乐(六边形叠叠乐、六边形堆叠战士)
  • 阿里 Qwen3 模型更新,吉卜力风格get