
Adding a MobileViTv3 Module to YOLOv8 (code + free)

Contents

1. Rationale

2. Method

(1) Import the MobileViTv3 module

(2) Modify the parse_model function in ultralytics/nn/tasks.py

(3) Add the module to the YAML config file

(4) Start training: freeze all other gradients, keeping only those of the newly added module

The code has been uploaded to GitHub, link: yolov8_vit


1. Rationale

MobileViTv3 is a lightweight vision Transformer architecture optimized for mobile devices. It combines the strengths of convolutional neural networks (CNNs) and vision Transformers (ViTs) to build compact models well suited to mobile vision tasks.

2. Method

(1) Import the MobileViTv3 module

Create a vit folder under ultralytics/nn, and place the MobileViTv3 module and the packages it depends on inside it. The MobileViTv3 module is as follows:

import torch
from torch import nn, Tensor
from typing import Optional, Union, Sequence
from mobilevit_v2_block import MobileViTBlockv2 as MbViTBkV2


class MbViTV3(MbViTBkV2):
    def __init__(
        self,
        in_channels: int,
        attn_unit_dim: int,
        patch_h: Optional[int] = 2,
        patch_w: Optional[int] = 2,
        ffn_multiplier: Optional[Union[Sequence[Union[int, float]], int, float]] = 2.0,
        n_attn_blocks: Optional[int] = 2,
        attn_dropout: Optional[float] = 0.0,
        dropout: Optional[float] = 0.0,
        ffn_dropout: Optional[float] = 0.0,
        conv_ksize: Optional[int] = 3,
        attn_norm_layer: Optional[str] = "layer_norm_2d",
        enable_coreml_compatible_fn: Optional[bool] = False,
    ) -> None:
        super(MbViTV3, self).__init__(in_channels, attn_unit_dim)
        self.enable_coreml_compatible_fn = enable_coreml_compatible_fn
        if self.enable_coreml_compatible_fn:
            # persistent=False so these weights are not part of the model's state_dict
            self.register_buffer(
                name="unfolding_weights",
                tensor=self._compute_unfolding_weights(),
                persistent=False,
            )
        cnn_out_dim = attn_unit_dim
        self.conv_proj = nn.Conv2d(2 * cnn_out_dim, in_channels, 1, 1)

    def forward_spatial(self, x: Tensor, *args, **kwargs) -> Tensor:
        x = self.resize_input_if_needed(x)
        fm_conv = self.local_rep(x)

        # convert feature map to patches
        if self.enable_coreml_compatible_fn:
            patches, output_size = self.unfolding_coreml(fm_conv)
        else:
            patches, output_size = self.unfolding_pytorch(fm_conv)

        # learn global representations on all patches
        patches = self.global_rep(patches)

        # [B x Patch x Patches x C] --> [B x C x Patches x Patch]
        if self.enable_coreml_compatible_fn:
            fm = self.folding_coreml(patches=patches, output_size=output_size)
        else:
            fm = self.folding_pytorch(patches=patches, output_size=output_size)

        # MobileViTv3: fuse local + global features instead of only global
        fm = self.conv_proj(torch.cat((fm, fm_conv), dim=1))

        # MobileViTv3: skip connection
        fm = fm + x

        return fm


if __name__ == '__main__':
    model = MbViTV3(320, 160, enable_coreml_compatible_fn=False)
    input = torch.randn(1, 320, 44, 84)
    output = model.forward_spatial(input)
    print(output.shape)
    # from thop import profile                        # uncomment to measure FLOPs/params
    # flops, params = profile(model, inputs=(input,))
    # print('flops', flops, 'params', params)
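As a quick sanity check (a minimal sketch; it assumes the file above is saved as ultralytics/nn/vit/Vit.py, the path the training script in step (4) imports from), the block should return a tensor of the same shape as its input, because conv_proj maps the fused local + global features back to in_channels before the skip connection:

import torch
from ultralytics.nn.vit.Vit import MbViTV3

block = MbViTV3(in_channels=320, attn_unit_dim=160)
x = torch.randn(1, 320, 44, 84)  # H and W divisible by the 2x2 patch size
y = block.forward_spatial(x)
assert y.shape == x.shape        # fm = conv_proj(cat(fm, fm_conv)) + x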

(2) Modify the parse_model function in ultralytics/nn/tasks.py

def parse_model(d, ch, verbose=True):  # model_dict, input_channels(3)
    # Parse a YOLO model.yaml dictionary
    if verbose:
        LOGGER.info(f"\n{'':>3}{'from':>20}{'n':>3}{'params':>10}  {'module':<45}{'arguments':<30}")
    nc, gd, gw, act = d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')
    if act:
        Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()
        if verbose:
            LOGGER.info(f"{colorstr('activation:')} {act}")  # print
    layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch out
    for i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, args
        ...
        elif m in {MbViTV3}:
            c2 = args[0]
        ...
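For parse_model to resolve the MbViTV3 name, tasks.py must also import it. Matching the folder layout from step (1) and the import path used by the training script in step (4), the import at the top of ultralytics/nn/tasks.py would be:

from ultralytics.nn.vit.Vit import MbViTV3

Setting c2 = args[0] records the block's output channel count (320 for the YAML row in step (3)) in parse_model's channel bookkeeping; since MbViTV3 is outside the built-in module sets, it is then instantiated as m(*args), i.e. MbViTV3(320, 160).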

(3) Add the module to the YAML config file

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect

# Parameters
nc: 2  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'
  # [depth, width, max_channels]
  n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPs
  s: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPs
  m: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPs
  l: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPs
  x: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs

# YOLOv8.0n backbone
backbone:
  # [from, repeats, module, args]
  - [-1, 1, Conv, [64, 3, 2]]    # 0-P1/2      320*320*64
  - [-1, 1, Conv, [128, 3, 2]]   # 1-P2/4      160*160*128
  - [-1, 3, C2f, [128, True]]
  - [-1, 1, Conv, [256, 3, 2]]   # 3-P3/8      80*80*256
  - [-1, 6, C2f, [256, True]]
  - [-1, 1, Conv, [512, 3, 2]]   # 5-P4/16     40*40*512
  - [-1, 6, C2f, [512, True]]
  - [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32     20*20*1024
  - [-1, 3, C2f, [1024, True]]
  - [-1, 1, SPPF, [1024, 5]]     # 9           20*20*1024

# YOLOv8.0n head
head:
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 10
  - [[-1, 6], 1, Concat, [1]]                  # 11
  - [-1, 3, C2f, [512]]                        # 12                40*40*512
  - [-1, 1, nn.Upsample, [None, 2, 'nearest']] # 13
  - [[-1, 4], 1, Concat, [1]]                  # 14
  - [-1, 3, C2f, [256]]                        # 15 (P3/8-small)   44*84*320
  - [-1, 1, MbViTV3, [320, 160]]               # 16
  - [-1, 1, Conv, [256, 3, 2]]                 # 17
  - [[-1, 12], 1, Concat, [1]]                 # 18
  - [-1, 3, C2f, [512]]                        # 19 (P4/16-medium) 40*40*512
  - [-1, 1, Conv, [512, 3, 2]]                 # 20
  - [[-1, 9], 1, Concat, [1]]                  # 21
  - [-1, 3, C2f, [1024]]                       # 22 (P5/32-large)  20*20*1024
  - [[16, 19, 22], 1, Detect, [nc]]            # 23
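A note on the [320, 160] arguments (my reading, assuming the x scale, which is what the training script in step (4) loads via yolov8x_vit.yaml): layer 15 is C2f, [256], and the x width multiple of 1.25 scales it to 256 × 1.25 = 320 output channels, so the block receives in_channels=320 and attn_unit_dim=160. parse_model does not width-scale the arguments of modules outside its built-in sets, so the YAML carries the already-scaled values. The block's output (layer 16) then feeds both the downsampling Conv at layer 17 and the P3 input of Detect via the index 16.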

(4) Start training: freeze the gradients of all other parameters, keeping only those of the newly added module.

import os
from ultralytics import YOLO
from ultralytics.nn.vit.Vit import MbViTV3
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'


def add_vit(model):
    for name, param in model.model.named_parameters():
        stand = name[6:8]  # layer index from names like "model.16.conv.weight"
        vit_ls = ['16']    # layer 16 is the newly added MbViTV3 block
        if stand in vit_ls:
            param.requires_grad = True
        else:
            param.requires_grad = False
    for name, param in model.model.named_parameters():
        if param.requires_grad:
            print(name)  # list the parameters that remain trainable
    return model


def main():
    # model = YOLO(r'ultralytics/cfg/models/v8/yolov8x.yaml').load('/root/autodl-tmp/yolov8x.pt')
    model = YOLO(r'yolov8x_vit.yaml').load('runs/detect/vit/weights/vit.pt')
    model = add_vit(model)
    model.train(data="data.yaml", imgsz=640, epochs=50, batch=10, device=0, workers=0)


if __name__ == '__main__':
    main()
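The name[6:8] slice works only because every parameter name here happens to look like model.NN.…; a slightly more defensive variant (my own sketch, not from the original post) keys on the full dotted layer index instead:

def add_vit(model, vit_layers=('16',)):
    # parameter names look like "model.16.conv.weight"; element 1 is the layer index
    for name, param in model.model.named_parameters():
        param.requires_grad = name.split('.')[1] in vit_layers
    return model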
