当前位置: 首页 > news >正文

【GPT-SOVITS-06】特征工程-HuBert原理

说明:该系列文章从本人知乎账号迁入,主要原因是知乎图片附件过于模糊。

知乎专栏地址:
语音生成专栏

系列文章地址:
【GPT-SOVITS-01】源码梳理
【GPT-SOVITS-02】GPT模块解析
【GPT-SOVITS-03】SOVITS 模块-生成模型解析
【GPT-SOVITS-04】SOVITS 模块-鉴别模型解析
【GPT-SOVITS-05】SOVITS 模块-残差量化解析
【GPT-SOVITS-06】特征工程-HuBert原理

1.概述

HuBert 模型目的在于提取音频自编码特征,其核心架构如下:

说明:代码主要参考 HuggingFace 的transformers 开源库

在这里插入图片描述

  • 输入原始音频数据,通过类似Bert原理的编码器形成隐变量,即在进入多头注意力模块前增加了随机的掩码
  • 训练时,第一轮比对原始音频的 MFCC 特征做 kmean 编码,类似残差向量量化网络。针对隐变量与编码做交叉熵损失
  • 训练时,第二轮比对编码器生成的隐变量(第6/9层)做 kmean 编码,再针对隐变量与编码做交叉熵损失

与论文中的截图做一下对比:
在这里插入图片描述
在这里插入图片描述

2.核心源码解析

2.1、特征提取:HubertFeatureEncoder

在这里插入图片描述
默认为 7层一维卷积,每层卷积参数,主要是 kernel 和 stride 不同

2.2、核心编码器:HubertEncoder

在这里插入图片描述

  • 默认为 12层编码器模块
  • 在输出时,包含了最终层的输出,以及中间各层的输出

2.3、有监督微调:HubertForCTC

在这里插入图片描述

  • 论文中同样给出了基于CTC损失的微调
  • 在微调时,特征提取编码器参数固定

CTC 损失的价值,主要是用于输出和标签的不一致性。举例:
假设 hello 这个单词在10秒内完成,则按秒分帧,每一秒对应一个字母的概率。即可能是 hhhhellooo。损失计算的时候是要对比 hhhhellooo 和 hello 的差异。

3、调试代码参考

from transformers import HubertModel, HubertConfig
import torch
import librosa
import torch.nn as nndef _test_pred_vec():config = HubertConfig()model = HubertModel(config)device = "cuda" if torch.cuda.is_available() else "cpu"model.to(device)wav_in = "../data/test.wav"audio, sr  = librosa.load(wav_in, sr=16000)audio = torch.from_numpy(audio).to(device)x = audio[None, :]vec = model.forward(x)print(vec)def _test_ctc_loss():ctc_loss        = nn.CTCLoss()log_probs       = torch.randn(50, 16, 20).log_softmax(2).requires_grad_()targets         = torch.randint(1, 20, (16, 30), dtype=torch.long)input_lengths   = torch.full((16,), 50, dtype=torch.long)target_lengths  = torch.randint(10, 30, (16,), dtype=torch.long)loss            = ctc_loss(log_probs, targets, input_lengths, target_lengths)print(loss)if __name__ == '__main__':#_test_pred_vec()_test_ctc_loss()
http://www.lryc.cn/news/320037.html

相关文章:

  • ros小问题之差速轮式机器人轮子不显示(rviz gazebo)
  • 网络安全实训Day5
  • 【Unity入门】详解Unity中的射线与射线检测
  • 实验11-2-5 链表拼接(PTA)
  • Mybatis Plus + Spring 分包配置 ClickHouse 和 Mysql 双数据源
  • 27-3 文件上传漏洞 - 文件类型绕过(后端绕过)
  • widget一些控件的使用
  • Python基础(七)之数值类型集合
  • 电脑充电器能充手机吗?如何给手机充电?
  • 矩阵中移动的最大次数
  • Linux:系统初始化,内核优化,性能优化(3)
  • 使用 GitHub Actions 通过 CI/CD 简化 Flutter 应用程序开发
  • 微软 CEO Satya Nadella 的访谈
  • 万界星空科技商业开源MES,技术支持+项目合作
  • Docker Mysql无root账户创建最高权限用户
  • 常用芯片学习——DS3231M芯片
  • 蓝桥杯单片机快速开发笔记——矩阵键盘
  • 每周一算法:双向深搜
  • 蓝桥杯刷题(十)
  • ioDraw:与 GitHub、gitee、gitlab、OneDrive 无缝对接,绘图文件永不丢失!
  • 利用 Python 处理遥感影像数据:计算年度平均影像
  • 【Leetcode-73.矩阵置零】
  • redis 常见的异常
  • npm包、全局数据共享、分包
  • UnityShader:IBL
  • 每日五道java面试题之mybatis篇(三)
  • C#开发五子棋游戏:从新手到高手的编程之旅
  • ELK日志管理实现的3种常见方法
  • 深度强化学习01
  • C++ 智能指针的使用