当前位置: 首页 > news >正文

GPT-SoVITS

文章目录

  • model arch
  • S1 Model
  • S2 model

model arch

VITS

  • S1 model: AR model–ssl tokens
  • S2 model: VITS,ssl 已经是mel 长度线性相关,MRTE(ssl_codes_embs, text, global_mel_emb)模块,将文本加强相关,学到一个参考结果

S1 Model

class Text2SemanticDecoder()def forward_old(self, x, x_lens, y, y_lens, bert_feature):"""x: phoneme_idsy: semantic_idsbert_feature: 已经根据word2phn 扩展成和x等长train : y+EOS,已知长度;infer : AR 预测,预测EOS 终止;如果没有,到预设最大长度,终止;"""# phn torch.Size([20, 99]) bert_feature torch.Size([20, 1024, 99])x = self.ar_text_embedding(x)x = x + self.bert_proj(bert_feature.transpose(1, 2))x = self.ar_text_position(x)x_mask = make_pad_mask(x_lens)y_mask = make_pad_mask(y_lens)y_mask_int = y_mask.type(torch.int64)codes = y.type(torch.int64) * (1 - y_mask_int)# Training# AR Decoder: SinePositionalEmbeddingy, targets = self.pad_y_eos(codes, y_mask_int, eos_id=self.EOS)x_len = x_lens.max()y_len = y_lens.max()y_emb = self.ar_audio_embedding(y)y_pos = self.ar_audio_position(y_emb)xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)ar_xy_padding_mask = xy_padding_maskx_attn_mask = F.pad(torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),(0, y_len),value=True,)y_attn_mask = F.pad(torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),diagonal=1,),(x_len, 0),value=False,)xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)bsz, src_len = x.shape[0], x_len + y_len_xy_padding_mask = (ar_xy_padding_mask.view(bsz, 1, 1, src_len).expand(-1, self.num_head, -1, -1).reshape(bsz * self.num_head, 1, src_len))xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))xy_attn_mask = new_attn_mask# x 和完整的 y 一次性输入模型xy_pos = torch.concat([x, y_pos], dim=1)xy_dec, _ = self.h((xy_pos, None),mask=xy_attn_mask,)logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)# loss# from feiteng: 每次 duration 越多, 梯度更新也应该更多, 所以用 sumloss = F.cross_entropy(logits, targets, reduction="sum")acc = self.ar_accuracy_metric(logits.detach(), targets).item()return loss, acc

S2 model

class Encoder()def forward(self, ssl, y_lengths, text, text_lengths, speed=1,test=None):'''y_lengths: mel_lengthge : ref_encoder_outputs'''ge = self.ref_enc(y * y_mask, y_mask)ssl = self.ssl_proj(ssl)quantized, codes, commit_loss, quantized_list = self.quantizer(ssl, layers=[0])if self.semantic_frame_rate == "25hz":quantized = F.interpolate(quantized, size=int(quantized.shape[-1] * 2), mode="nearest")y = self.encoder_ssl(y * y_mask, y_mask)text_mask = torch.unsqueeze(commons.sequence_mask(text_lengths, text.size(1)), 1).to(y.dtype)if test == 1:text[:, :] = 0text = self.text_embedding(text).transpose(1, 2)text = self.encoder_text(text * text_mask, text_mask)y = self.mrte(y, y_mask, text, text_mask, ge)# self.encoder_ssl, self.encoder_text, self.encoder2 结构一样y = self.encoder2(y * y_mask, y_mask)if(speed!=1):y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")stats = self.proj(y) * y_maskm, logs = torch.split(stats, self.out_channels, dim=1)return y, m, logs, y_mask
http://www.lryc.cn/news/427092.html

相关文章:

  • linux高级编程——文件IO(常用函数大全)
  • matplotlib画图
  • Jetpack 各种框架简介
  • 海康VisionMaster使用学习笔记5-开机自启动
  • 驾驭数据之序:SQL序列的奥秘与实现
  • 【LeetCode】148. 排序链表
  • 阿里云-java调用短信服务,第三方接口的开启(傻瓜式教程)
  • 以node / link文件表征的道路网络-----基于南京公路公开数据做路径规划(下)------dijkstra算法的一些简单花样
  • 计算机操作员中级理论知识试题
  • Redis主从同步配置
  • 输出重定向
  • ubuntu20.04挂载机械硬盘
  • Python轻量级 NoSQL 数据库之tinydb使用详解
  • 【数据结构】二叉树(二)遍历
  • NGINX 常用内置变量
  • Windows采用VS2019实现Open3D的C++应用
  • 冒泡排序、选择排序、插入排序,三种简单排序算法的区别?
  • Docker 日志管理
  • JavaScript初级——基础知识
  • 0817(持久层框架:JDBC,MyBatis)
  • 在亚马逊云科技上安全、合规地创建AI大模型训练基础设施并开发AI应用服务
  • 无人机模拟训练室技术详解
  • 【Spring框架】
  • uniapp 日常业务 随便写写 源码
  • 【软件测试】单元测试20套练习题
  • 8.16 day bug
  • 《Nginx核心技术》第11章:实现MySQL数据库的负载均衡
  • 使用 Gnosis Safe 创建多签名钱包
  • LeetCode 算法:前 K 个高频元素 c++
  • MySQL的SQL语句更新某个字段的值在原来值的基础上随机增加100~600