当前位置: 首页 > news >正文

每日Attention学习23——KAN-Block

模块出处

[SPL 25] [link] [code] KAN See In the Dark


模块名称

Kolmogorov-Arnold Network Block (KAN-Block)


模块作用

用于vision的KAN结构


模块结构

在这里插入图片描述


模块代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class Swish(nn.Module):
    """Swish activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        return x * torch.sigmoid(x)


class KANLinear(torch.nn.Module):
    """Linear-style layer combining a base branch with a B-spline branch.

    The output is ``F.linear(base_activation(x), base_weight)`` plus a linear
    map of the B-spline bases of ``x`` weighted by ``scaled_spline_weight``.

    Args:
        in_features: Size of the last dimension of the 2-D input.
        out_features: Size of the last dimension of the output.
        grid_size: Number of grid intervals inside ``grid_range``.
        spline_order: Order of the B-spline bases.
        scale_noise: Amplitude of the random noise used to initialise the
            spline coefficients.
        scale_base: Factor applied to ``a`` in the Kaiming init of
            ``base_weight``.
        scale_spline: Factor applied to ``a`` in the Kaiming init of
            ``spline_scaler`` (or to the spline weights directly when the
            standalone scaler is disabled).
        enable_standalone_scale_spline: When true, a learnable per-edge
            ``spline_scaler`` multiplies the spline weights.
        base_activation: Module class instantiated for the base branch.
        grid_eps: Blend factor between uniform and adaptive grids in
            ``update_grid``.
        grid_range: Two-element sequence ``(low, high)`` for the initial grid.
    """

    def __init__(
        self,
        in_features,
        out_features,
        grid_size=5,
        spline_order=3,
        scale_noise=0.1,
        scale_base=1.0,
        scale_spline=1.0,
        enable_standalone_scale_spline=True,
        base_activation=torch.nn.SiLU,
        grid_eps=0.02,
        grid_range=(-1, 1),
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # These two parameters are not read by forward(); they are kept so
        # existing state_dict keys still load. They are zero-initialised
        # instead of being left as uninitialised memory, which could hold
        # arbitrary (even non-finite) values and leak into checkpoints.
        self.weight = nn.Parameter(torch.zeros(out_features, in_features))
        self.bias = nn.Parameter(torch.zeros(out_features))

        # Uniform knot vector extended by `spline_order` knots on each side,
        # replicated for every input feature.
        h = (grid_range[1] - grid_range[0]) / grid_size
        grid = (
            (
                torch.arange(-spline_order, grid_size + spline_order + 1) * h
                + grid_range[0]
            )
            .expand(in_features, -1)
            .contiguous()
        )
        self.register_buffer("grid", grid)

        # Allocated empty here; filled in by reset_parameters() below.
        self.base_weight = torch.nn.Parameter(
            torch.empty(out_features, in_features)
        )
        self.spline_weight = torch.nn.Parameter(
            torch.empty(out_features, in_features, grid_size + spline_order)
        )
        if enable_standalone_scale_spline:
            self.spline_scaler = torch.nn.Parameter(
                torch.empty(out_features, in_features)
            )

        self.scale_noise = scale_noise
        self.scale_base = scale_base
        self.scale_spline = scale_spline
        self.enable_standalone_scale_spline = enable_standalone_scale_spline
        self.base_activation = base_activation()
        self.grid_eps = grid_eps

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``base_weight``, ``spline_weight`` and ``spline_scaler``."""
        torch.nn.init.kaiming_uniform_(
            self.base_weight, a=math.sqrt(5) * self.scale_base
        )
        with torch.no_grad():
            noise = (
                (
                    torch.rand(
                        self.grid_size + 1, self.in_features, self.out_features
                    )
                    - 1 / 2
                )
                * self.scale_noise
                / self.grid_size
            )
            # Fit spline coefficients so the curve passes through small random
            # values at the (non-extended) grid points.
            self.spline_weight.data.copy_(
                (
                    self.scale_spline
                    if not self.enable_standalone_scale_spline
                    else 1.0
                )
                * self.curve2coeff(
                    self.grid.T[self.spline_order : -self.spline_order],
                    noise,
                )
            )
            if self.enable_standalone_scale_spline:
                torch.nn.init.kaiming_uniform_(
                    self.spline_scaler, a=math.sqrt(5) * self.scale_spline
                )

    def b_splines(self, x: torch.Tensor):
        """Compute the B-spline bases for the given input tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).

        Returns:
            torch.Tensor: B-spline bases tensor of shape
            (batch_size, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        # (in_features, grid_size + 2 * spline_order + 1)
        grid: torch.Tensor = self.grid
        x = x.unsqueeze(-1)
        # Order-0 bases: indicator of the knot interval containing x.
        bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
        # Raise the order one step at a time; each step drops one basis.
        for k in range(1, self.spline_order + 1):
            bases = (
                (x - grid[:, : -(k + 1)])
                / (grid[:, k:-1] - grid[:, : -(k + 1)])
                * bases[:, :, :-1]
            ) + (
                (grid[:, k + 1 :] - x)
                / (grid[:, k + 1 :] - grid[:, 1:(-k)])
                * bases[:, :, 1:]
            )

        assert bases.size() == (
            x.size(0),
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return bases.contiguous()

    def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
        """Compute the coefficients of the curve that interpolates the given points.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_features).
            y (torch.Tensor): Output tensor of shape
                (batch_size, in_features, out_features).

        Returns:
            torch.Tensor: Coefficients tensor of shape
            (out_features, in_features, grid_size + spline_order).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        assert y.size() == (x.size(0), self.in_features, self.out_features)

        # (in_features, batch_size, grid_size + spline_order)
        A = self.b_splines(x).transpose(0, 1)
        # (in_features, batch_size, out_features)
        B = y.transpose(0, 1)
        # (in_features, grid_size + spline_order, out_features)
        solution = torch.linalg.lstsq(A, B).solution
        # (out_features, in_features, grid_size + spline_order)
        result = solution.permute(2, 0, 1)

        assert result.size() == (
            self.out_features,
            self.in_features,
            self.grid_size + self.spline_order,
        )
        return result.contiguous()

    @property
    def scaled_spline_weight(self):
        """Spline weights multiplied by ``spline_scaler`` when it is enabled."""
        return self.spline_weight * (
            self.spline_scaler.unsqueeze(-1)
            if self.enable_standalone_scale_spline
            else 1.0
        )

    def forward(self, x: torch.Tensor):
        """Apply the layer to ``x`` of shape (batch_size, in_features).

        Returns:
            torch.Tensor: Tensor of shape (batch_size, out_features).
        """
        assert x.dim() == 2 and x.size(1) == self.in_features

        base_output = F.linear(self.base_activation(x), self.base_weight)
        spline_output = F.linear(
            self.b_splines(x).view(x.size(0), -1),
            self.scaled_spline_weight.view(self.out_features, -1),
        )
        return base_output + spline_output

    @torch.no_grad()
    def update_grid(self, x: torch.Tensor, margin=0.01):
        """Re-fit the grid to the distribution of ``x`` and refit the splines.

        The new grid blends a uniform grid over the observed range (widened by
        ``margin``) with a quantile-based grid, weighted by ``grid_eps``.
        Spline weights are then re-solved so the spline output on ``x``
        matches the output computed with the old grid.

        Args:
            x (torch.Tensor): Samples of shape (batch_size, in_features).
            margin: Amount added on both ends of the observed range.
        """
        assert x.dim() == 2 and x.size(1) == self.in_features
        batch = x.size(0)

        splines = self.b_splines(x)  # (batch, in, coeff)
        splines = splines.permute(1, 0, 2)  # (in, batch, coeff)
        orig_coeff = self.scaled_spline_weight  # (out, in, coeff)
        orig_coeff = orig_coeff.permute(1, 2, 0)  # (in, coeff, out)
        unreduced_spline_output = torch.bmm(splines, orig_coeff)  # (in, batch, out)
        unreduced_spline_output = unreduced_spline_output.permute(
            1, 0, 2
        )  # (batch, in, out)

        # Sort each channel individually to collect the data distribution.
        x_sorted = torch.sort(x, dim=0)[0]
        grid_adaptive = x_sorted[
            torch.linspace(
                0,
                batch - 1,
                self.grid_size + 1,
                dtype=torch.int64,
                device=x.device,
            )
        ]

        uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
        grid_uniform = (
            torch.arange(
                self.grid_size + 1, dtype=torch.float32, device=x.device
            ).unsqueeze(1)
            * uniform_step
            + x_sorted[0]
            - margin
        )

        grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
        # Extend by `spline_order` knots on each side. torch.cat is used
        # because torch.concatenate is a newer alias of the same function.
        grid = torch.cat(
            [
                grid[:1]
                - uniform_step
                * torch.arange(
                    self.spline_order, 0, -1, device=x.device
                ).unsqueeze(1),
                grid,
                grid[-1:]
                + uniform_step
                * torch.arange(
                    1, self.spline_order + 1, device=x.device
                ).unsqueeze(1),
            ],
            dim=0,
        )

        self.grid.copy_(grid.T)
        self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))

    def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
        """Compute the regularization loss.

        This is a simplified stand-in for the original L1 regularization as
        stated in the paper, since the original one requires computing
        absolutes and entropy from the expanded
        (batch, in_features, out_features) intermediate tensor, which is
        hidden behind the F.linear function if we want a memory-efficient
        implementation.

        The L1 regularization is now computed as the mean absolute value of
        the spline weights. The authors' implementation also includes this
        term in addition to the sample-based regularization.
        """
        l1_fake = self.spline_weight.abs().mean(-1)
        regularization_loss_activation = l1_fake.sum()
        p = l1_fake / regularization_loss_activation
        regularization_loss_entropy = -torch.sum(p * p.log())
        return (
            regularize_activation * regularization_loss_activation
            + regularize_entropy * regularization_loss_entropy
        )


class DW_bn_relu(nn.Module):
    """Depthwise 3x3 convolution, BatchNorm and ReLU applied to token sequences.

    Args:
        dim: Number of channels of the token features.
    """

    def __init__(self, dim=768):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
        self.bn = nn.BatchNorm2d(dim)
        self.relu = nn.ReLU()

    def forward(self, x, H, W):
        """Map ``x`` of shape (B, N, C) with ``N == H * W`` to (B, N, C)."""
        B, N, C = x.shape
        # reshape (rather than view) also accepts non-contiguous inputs.
        x = x.transpose(1, 2).reshape(B, C, H, W)
        x = self.dwconv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = x.flatten(2).transpose(1, 2)
        return x


class KANBlock(nn.Module):
    """Three KANLinear layers, each followed by a depthwise conv block.

    Args:
        in_features: Channel width of the input tokens.
        hidden_features: Output width of ``fc1``; defaults to ``in_features``.
        out_features: Output width of ``fc2``/``fc3``; defaults to
            ``in_features``.
        act_layer: Accepted for interface compatibility; not used.
        drop: Dropout probability for ``self.drop`` (the module is created
            but not applied in ``forward``).
        shift_size: Stored on the module; ``self.pad`` is ``shift_size // 2``.
        version: Accepted for interface compatibility; not used.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.,
        shift_size=5,
        version=4,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.dim = in_features

        kan_kwargs = dict(
            grid_size=5,
            spline_order=3,
            scale_noise=0.1,
            scale_base=1.0,
            scale_spline=1.0,
            base_activation=torch.nn.SiLU,
            grid_eps=0.02,
            grid_range=(-1, 1),
        )

        self.fc1 = KANLinear(in_features, hidden_features, **kan_kwargs)
        self.fc2 = KANLinear(hidden_features, out_features, **kan_kwargs)
        # fc3 consumes the output of fc2 (via dwconv_2), whose width is
        # out_features. Sizing it with hidden_features only worked when the
        # two widths happened to be equal.
        self.fc3 = KANLinear(out_features, out_features, **kan_kwargs)

        # Each depthwise block is sized for the tensor it actually receives.
        self.dwconv_1 = DW_bn_relu(hidden_features)
        self.dwconv_2 = DW_bn_relu(out_features)
        self.dwconv_3 = DW_bn_relu(out_features)

        self.drop = nn.Dropout(drop)

        self.shift_size = shift_size
        self.pad = shift_size // 2

    def forward(self, x, H, W):
        """Map ``x`` of shape (B, N, C) with ``N == H * W`` to (B, N, out)."""
        B, N, C = x.shape

        x = self.fc1(x.reshape(B * N, C))
        # -1 lets the channel width follow each layer's output size instead
        # of assuming it always equals the input width C.
        x = x.reshape(B, N, -1).contiguous()
        x = self.dwconv_1(x, H, W)

        x = self.fc2(x.reshape(B * N, -1))
        x = x.reshape(B, N, -1).contiguous()
        x = self.dwconv_2(x, H, W)

        x = self.fc3(x.reshape(B * N, -1))
        x = x.reshape(B, N, -1).contiguous()
        x = self.dwconv_3(x, H, W)

        return x


if __name__ == '__main__':
    x = torch.randn([1, 22*22, 128])
    kan = KANBlock(in_features=128)
    out = kan(x, H=22, W=22)
    print(out.shape)  # [1, 22*22, 128]

http://www.lryc.cn/news/537559.html

相关文章:

  • 今日写题04work
  • Managed Lustre 和 WEKA:高性能文件系统的对比与应用
  • LeetCode541 反转字符串2
  • MAC 系统关闭屏幕/睡眠 后被唤醒 Wake Requests
  • 论文笔记:Multi-Head Mixture-of-Experts
  • vue和Django快速创建项目
  • Java LinkedList(单列集合)
  • 多线程基础面试题剖析
  • .NET SixLabors.ImageSharp v1.0 图像实用程序控制台示例
  • EasyExcel提取excel文档
  • 第十五届蓝桥杯嵌入式省赛真题(满分)
  • ASP.NET Core Web应用(.NET9.0)读取数据库表记录并显示到页面
  • 【Sceneform-EQR】实现3D场景背景颜色的定制化(背景融合的方式、Filament材质定制)
  • LeetCode1706
  • 2517. 礼盒的最大甜蜜度(Maximum Tastiness of Candy Box)
  • Golang 的字符编码与 regexp
  • 利用ollama 与deepseek r1大模型搭建本地知识库
  • Java短信验证功能简单使用
  • CAS单点登录(第7版)21.可接受的使用政策
  • 53倍性能提升!TiDB 全局索引如何优化分区表查询?
  • Python 解决Pycharm 运行太慢
  • 库里存储的数据有大量回车时,该如何进行存取
  • 【devops】Github Actions Secrets | 如何在Github中设置CI的Secret供CI的yaml使用
  • 体验 DeepSeek-R1:解密 1.5B、7B、8B 版本的强大性能与应用
  • 一文说清楚什么是Token以及项目中使用Token延伸的问题
  • 大模型-Tool call、检索增强
  • 【算法】【区间和】acwing算法基础 802. 区间和 【有点复杂,但思路简单】
  • Ubuntu22.04通过Docker部署Jeecgboot
  • HTML4
  • STM32F10X 启动文件完整分析