当前位置: 首页 > news >正文

深度学习基础模型之Mamba

Mamba模型简介

问题:许多亚二次时间架构(运行时间复杂度低于O(n^2),但高于O(n)的情况)(例如线性注意力、门控卷积和循环模型以及结构化状态空间模型(SSM))已被开发出来,以解决 Transformer 在长序列上的计算效率低下问题,但此类模型的一个关键弱点是它们无法执行基于内容的推理

1. 模型架构

模型简单理解(特殊的门控RNN网络):线性层+门控+选择性SSM的组合

在这里插入图片描述

2. 模型特点

2.1 选择性机制

在这里插入图片描述

Δ \Delta Δ 、A、B、C应该是SSM中的可学习参数

  • 根据输入参数化 SSM 参数来设计一种简单的选择机制,这使得模型能够过滤掉不相关的信息并无限期地记住相关信息。
    这里作者认为(研究动机):‘序列建模的一个基本问题是将上下文压缩成更小的状态。事实上,我们可以从这个角度来看待流行序列模型的权衡。例如,注意力既有效又低效,因为它明确地根本不压缩上下文。自回归推理需要显式存储整个上下文(即KV缓存),这直接导致Transformers的线性时间推理和二次时间训练缓慢。’
    在这里插入图片描述
  • 序列模型的效率与有效性权衡的特征在于它们压缩状态的程度:高效模型必须具有较小的状态,而有效模型必须具有包含上下文中所有必要信息的状态。反过来,我们提出构建序列模型的基本原则是选择性:或关注或过滤掉序列状态输入的上下文感知能力。

2.2 硬件算法

算法通过扫描而不是卷积来循环计算模型,但不会具体化扩展状态,计算速度比所有先前的 SSM 模型提升三倍。

代码调用

import torch
from mamba_ssm import Mambabatch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(# This module uses roughly 3 * expand * d_model^2 parametersd_model=dim, # Model dimension d_modeld_state=16,  # SSM state expansion factord_conv=4,    # Local convolution widthexpand=2,    # Block expansion factor
).to("cuda")
y = model(x)
print(x.shape)
print(y.shape)
assert y.shape == x.shape

总结

这项基础性模型研究旨在解决transformer模型的长序列数据计算效率低的问题,其解决方法的动机:利用选择性机制实现有效特征的提取。个人理解为通过有效特征信息的选择实现知识提取(信息压缩),这让我联想到,最初的VGG语义分割网络结构设计其实类似于模拟知识特征的压缩与抽取,但后来发现这种方式会损失边缘信息,因此提出了U-net架构,再进一步卷积的方式无法有效估计全局上下文信息的联系,进而提出注意力机制来解决这一问题。
从技术与文章写作的角度来看,问题的发展似乎从知识压缩->细节特征提取->全局信息整合,到Mamba貌似是在全局信息整合基础上在进行一次有效信息的抽取,进而使模型从数据中提取根据代表性的特征。整体突出一点:深度学习也是一个特征工程,利用模型来替换原有的手工设计的特征

  • 详细代码链接
  • 相关模型应用案例:U-Mamba
    在这里插入图片描述
http://www.lryc.cn/news/331048.html

相关文章:

  • Topaz Video AI for Mac v5.0.0激活版 视频画质增强软件
  • 解决WordPress文章的段落首行自动空两格的问题
  • RISC-V单板计算机模拟和FPGA板多核IP实现
  • Mojo编程语言案例及介绍
  • 【Python面试题收录】Python中有哪些方法交换两个变量的值?至少给出三种方法。
  • MySQL核心命令详解与实战,一文掌握MySQL使用
  • 基于Springboot + MySQL + Vue 大学新生宿舍管理系统 (含源码)
  • vulnhub pWnOS v2.0通关
  • leetcode热题100.数据流的中位数
  • C 从函数返回指针
  • (文章复现)考虑分布式电源不确定性的配电网鲁棒动态重构
  • 蓝桥杯第八届c++大学B组详解
  • 小于n的最大数 Leetcode 902 Numbers At Most N Given Digit Set
  • Leetcode刷题-数组(二分法、双指针法、窗口滑动)
  • STM32学习和实践笔记(4): 分析和理解GPIO_InitTypeDef GPIO_InitStructure (b)
  • 数据仓库——事实表
  • 人工智能常用的编程语言有哪些?
  • 【Leetcode每日一题】模拟 - 提莫攻击(难度⭐)(45)
  • OPPO云VPC网络实践
  • 力扣(数组)找到所有数组中消失的数字
  • 每日面经分享(Spring Boot: part3 Service层)
  • k8s的pod访问service的方式
  • shell脚本发布docker-nginx vue2 项目示例
  • 【THM】Nmap Basic Port Scans(基本端口扫描)-初级渗透测试
  • Groovy结合Java在生产中的落地实战
  • 达梦数据库 创建外部表 [-7082]:外部表数据错误.
  • XUbuntu22.04之激活Linux最新Typora版本(二百二十五)
  • JavaScript简介
  • 使用PaddleX实现的智慧农业病虫检测项目
  • 算法学习——LeetCode力扣图论篇1(797. 所有可能的路径、200. 岛屿数量、695. 岛屿的最大面积)