
Notes on Reproducing Pyraformer

Reference

Liu, Shizhan, et al. "Pyraformer: Low-Complexity Pyramidal Attention for Long-Range Time Series Modeling and Forecasting." International Conference on Learning Representations (ICLR), 2022.

Code Walkthrough

def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
    enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :]                          # B, d_model*3
    dec_out = self.projection(enc_out).view(enc_out.size(0), self.pred_len, -1)  # B, pred_len, N
    return dec_out

The forecasting part really is just this short.

x_dec, x_mark_dec, and mask=None are never used.
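The whole forecast therefore comes down to self.projection mapping the feature vector of the last pyramid position onto the prediction window. A minimal sketch with made-up sizes (the 3*d_model input width and the pred_len/c_out values are illustrative assumptions, not the repo's configuration):

import torch
import torch.nn as nn

B, d_model, pred_len, c_out = 4, 16, 24, 7              # illustrative sizes only
enc_out = torch.randn(B, 3 * d_model)                   # B, d_model*3: features of the last position
projection = nn.Linear(3 * d_model, pred_len * c_out)   # assumed shape of self.projection
dec_out = projection(enc_out).view(B, pred_len, -1)     # B, pred_len, c_out
print(dec_out.shape)                                    # torch.Size([4, 24, 7])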

enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :]
# B, d_model*3
  • The input goes straight into the encoder
def forward(self, x_enc, x_mark_enc):
    seq_enc = self.enc_embedding(x_enc, x_mark_enc)
  • The encoder and decoder are restructured and look very different from the vanilla Transformer
x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
return self.dropout(x)
  • The embedding scheme is the same as in the other *former models
mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device)

The PAM (pyramidal attention) mask is obtained in Pyraformer's own way:

import math
import torch


def get_mask(input_size, window_size, inner_size):
    """Get the attention mask of PAM-Naive"""
    # Get the size of all layers
    all_size = []
    all_size.append(input_size)
    for i in range(len(window_size)):
        layer_size = math.floor(all_size[i] / window_size[i])
        all_size.append(layer_size)
    seq_length = sum(all_size)
    mask = torch.zeros(seq_length, seq_length)

    # get intra-scale mask
    inner_window = inner_size // 2
    for layer_idx in range(len(all_size)):
        start = sum(all_size[:layer_idx])
        for i in range(start, start + all_size[layer_idx]):
            left_side = max(i - inner_window, start)
            right_side = min(i + inner_window + 1, start + all_size[layer_idx])
            mask[i, left_side:right_side] = 1

    # get inter-scale mask
    for layer_idx in range(1, len(all_size)):
        start = sum(all_size[:layer_idx])
        for i in range(start, start + all_size[layer_idx]):
            left_side = (start - all_size[layer_idx - 1]) + \
                (i - start) * window_size[layer_idx - 1]
            if i == (start + all_size[layer_idx] - 1):
                right_side = start
            else:
                right_side = (start - all_size[layer_idx - 1]) + (i - start + 1) * window_size[layer_idx - 1]
            mask[i, left_side:right_side] = 1
            mask[left_side:right_side, i] = 1

    mask = (1 - mask).bool()
    return mask, all_size
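To make the pyramid sizes concrete, here is a quick run of get_mask with toy arguments; the numbers are chosen purely for illustration.

mask, all_size = get_mask(input_size=16, window_size=[4, 4], inner_size=3)
print(all_size)      # [16, 4, 1] -> number of nodes at each pyramid scale
print(mask.shape)    # torch.Size([21, 21]); True marks pairs that are NOT allowed to attend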

Next, the embedded sequence enters the convolution layers:

seq_enc = self.conv_layers(seq_enc)

First, the CSCM (Coarser-Scale Construction Module) convolutions are built:

class Bottleneck_Construct(nn.Module):
    """Bottleneck convolution CSCM"""

# in forward():
temp_input = self.down(enc_input).permute(0, 2, 1)
all_inputs = []

# self.down is defined in __init__():
self.down = Linear(d_model, d_inner)

Down-projection of the feature dimension (d_model → d_inner).
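Shape-wise, the Linear down-projection plus permute puts the channels in the middle dimension, which is what Conv1d expects (the sizes below are my own, not the repo's defaults):

import torch
import torch.nn as nn

down = nn.Linear(512, 128)                    # d_model=512, d_inner=128 chosen for illustration
enc_input = torch.randn(2, 96, 512)           # B, L, d_model
temp_input = down(enc_input).permute(0, 2, 1)
print(temp_input.shape)                       # torch.Size([2, 128, 96]) -> B, d_inner, L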

for i in range(len(self.conv_layers)):
    temp_input = self.conv_layers[i](temp_input)
    all_inputs.append(temp_input)

The convolutions are stacked several times, exactly as in the other *former models.

import torch.nn as nn

class ConvLayer(nn.Module):
    def __init__(self, c_in, window_size):
        super(ConvLayer, self).__init__()
        self.downConv = nn.Conv1d(in_channels=c_in,
                                  out_channels=c_in,
                                  kernel_size=window_size,
                                  stride=window_size)
        self.norm = nn.BatchNorm1d(c_in)
        self.activation = nn.ELU()

    def forward(self, x):
        x = self.downConv(x)
        x = self.norm(x)
        x = self.activation(x)
        return x
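A quick sanity check of the downsampling behaviour, using the same toy sizes as above:

import torch

layer = ConvLayer(c_in=128, window_size=4)    # c_in matches d_inner from the down-projection sketch
x = torch.randn(2, 128, 96)                   # B, d_inner, L
print(layer(x).shape)                         # torch.Size([2, 128, 24]); length shrinks by window_size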

The results of the N convolutions are then concatenated:

all_inputs = torch.cat(all_inputs, dim=2).transpose(1, 2)
all_inputs = self.up(all_inputs)            # self.up = Linear(d_inner, d_model)
all_inputs = torch.cat([enc_input, all_inputs], dim=1)
all_inputs = self.norm(all_inputs)          # self.norm = nn.LayerNorm(d_model)
return all_inputs

Afterwards, the coarse-scale features are concatenated back with the original (finest-scale) input, as the reassembled forward pass below shows.
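Putting the snippets above back together, the Bottleneck_Construct forward pass reads roughly as follows (reassembled from the pieces shown here, so treat it as a sketch rather than the repo's verbatim code):

def forward(self, enc_input):                                  # enc_input: B, L, d_model
    temp_input = self.down(enc_input).permute(0, 2, 1)         # B, d_inner, L
    all_inputs = []
    for i in range(len(self.conv_layers)):
        temp_input = self.conv_layers[i](temp_input)           # one coarser scale per pass
        all_inputs.append(temp_input)
    all_inputs = torch.cat(all_inputs, dim=2).transpose(1, 2)  # B, sum of coarse lengths, d_inner
    all_inputs = self.up(all_inputs)                           # project back to d_model
    all_inputs = torch.cat([enc_input, all_inputs], dim=1)     # prepend the finest scale
    all_inputs = self.norm(all_inputs)
    return all_inputs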

  • After the convolution layers comes the encoder layer
def forward(self, enc_input, slf_attn_mask=None):
    attn_mask = RegularMask(slf_attn_mask)
    enc_output, _ = self.slf_attn(enc_input, enc_input, enc_input, attn_mask=attn_mask)

Stepping inside the encoder, we reach the familiar *former framework:

def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
    # the last two arguments are presumably framework-specified
    B, L, _ = queries.shape          # B, seq, d_model
    _, S, _ = keys.shape
    H = self.n_heads
    # L and S are actually the same number here
    queries = self.query_projection(queries).view(B, L, H, -1)  # B, L, H, d_model/H
    keys = self.key_projection(keys).view(B, S, H, -1)          # computed the same way
    values = self.value_projection(values).view(B, S, H, -1)
    # H is the number of heads; -1 lets view() infer that dimension
  • The encoder's attention is FullAttention, and the mask is applied inside it
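Conceptually, the mask enters the score computation exactly as in any masked full attention. The generic sketch below is not the repo's FullAttention class, just a compact illustration of what it computes:

import torch

def masked_full_attention(queries, keys, values, attn_mask):
    # queries/keys/values: (B, L, H, d_head); attn_mask: (B, L, L) bool, True = "do not attend"
    scale = queries.shape[-1] ** -0.5
    scores = torch.einsum("blhd,bshd->bhls", queries, keys) * scale
    scores = scores.masked_fill(attn_mask.unsqueeze(1), float("-inf"))
    attn = torch.softmax(scores, dim=-1)
    return torch.einsum("bhls,bshd->blhd", attn, values)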

Back in Pyraformer's encoder:

self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout, normalize_before=normalize_before)
def forward(self, x):
    residual = x
    if self.normalize_before:
        x = self.layer_norm(x)
    x = F.gelu(self.w_1(x))
    x = self.dropout(x)
    x = self.w_2(x)
    x = self.dropout(x)
    x = x + residual
    if not self.normalize_before:
        x = self.layer_norm(x)
    return x
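For completeness, a constructor consistent with that forward() could look like the sketch below; the attribute names are taken from the snippet, but the concrete layer choices are my assumptions:

import torch.nn as nn

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_inner, dropout=0.1, normalize_before=True):
        super().__init__()
        self.w_1 = nn.Linear(d_model, d_inner)      # expand to the inner FFN width
        self.w_2 = nn.Linear(d_inner, d_model)      # project back to d_model
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.normalize_before = normalize_before    # pre-norm vs post-norm switch used in forward()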
indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to(seq_enc.device)
# B, seq, 3, d_model
indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2))
# B, seq*3, d_model
all_enc = torch.gather(seq_enc, 1, indexes)
# B, seq*3, d_model
seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1)
# B, seq, d_model*3
return seq_enc
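The gather step is easiest to follow with concrete shapes. In the sketch below the index tensor is a random stand-in for self.indexes, which (per the comments above) picks one node per scale for every time step; the sizes reuse the toy pyramid [16, 4, 1]:

import torch

B, seq, n_scales, d_model = 2, 16, 3, 8
total_nodes = 16 + 4 + 1                        # all pyramid nodes stacked along dim 1
seq_enc = torch.randn(B, total_nodes, d_model)
indexes = torch.randint(0, total_nodes, (B, seq * n_scales, d_model))  # stand-in for self.indexes
all_enc = torch.gather(seq_enc, 1, indexes)     # B, seq*3, d_model
seq_out = all_enc.view(B, seq, -1)              # B, seq, d_model*3
print(seq_out.shape)                            # torch.Size([2, 16, 24])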

Summary

  • x_dec, x_mark_dec, and mask=None are never used.
  • The input goes straight into the encoder.
  • The encoder and decoder are restructured and look very different from the vanilla Transformer.
  • The embedding scheme is the same as in the other *former models.
  • The encoder's attention is FullAttention, and the mask is applied inside it.

