当前位置: 首页 > news >正文

yolov8-pos/yolov11-pos 训练

一、前言

        最近需要做草莓识别以及摘取,需要对草莓识别,并获取草莓头和草莓尖这2个关键点,故关注到可使用yolov8-pos / yolov11-pos 解决,对 yolov11-pos 进行了实验。

        假设有 c个 类别, n个特征点 yolov11-pos 单张图在640尺度下输出特征图是 1*8400*(4+c+n*3),相当于有8400个长度为(4+c+n*3)的特征, 4表示预测框的中心点x,y坐标,宽、高;c表示 c个类的预测概率,n*3是 n个点 , 每个点 有 x y  属性(无效0,不可见1,可见2) , 这3个值。

二、数据集制作

       (1)识别对象

        我的识别对象主要有草莓熟果、果梗,这2类对象,关键点数量为2,草莓的头和尖,果梗虽然不识别关键点,但是也必须有2个点,这里我就取边框的  (中心点x坐标,边框的y坐标top坐标)和(中心点x坐标,边框的y坐标bottom坐标 ) 像是下面的

        

        这个是labelme 画出来的,梗的2个点可以直接生成,所以在标注的时候可以不画出来

        (2)labelme 软件画标注

         采集足够数据后,需要对这2类的矩形框和关键点画出来,使用labelme 软件打开这个文件夹。显示如下

        然后画框的时候注意,为了使框和关键点是对应的,便于后续转换json文件为 txt文件,需要每次画一个框后,把2个关键点也画一画;

        画框是:编辑-》创建矩形

        画点是:编辑-》创建控制点

        画一个框和对应2个点,在后边可显示出顺序关系,caomei矩形框之后是p1和p2两个点

     

        标注完成后,labelme软件是生成和图片同名字的json文件。

        (3)把所有json文件给转换成yolo格式的txt文件

        yolo格式的txt文件中一行是一个识别对象,一个对象是 1个框2个点,共有 (1+4+n*3)个数据,也即是 (1+4+2*3)=11个数据,跟前面讲输出维度(4+c+n*3)有些点点不同,但本质一致。

        (1+4+2*3),也即 第1个数是类别号,1个类别的类别值就是0 ,2个类别就是0或1 ;3个类别就是0 或1或2,以此类推;4就是边框归一化的中心点 x ,y坐标 宽 和 高 坐标 ;2*3, 也即2个点,每个点有3个数值, x , y , 属性(无效0,不可见1,可见2),无效就是这个点不在图像中了,不可见就是点在图像中,但是被遮挡,可见就是点在图像中,且未被遮挡。

        在我这都是默认可见 2 了。现在需要编写脚本转换一下,我把图片和相应json文件放在 all文件夹,需要生成 txt 文件在 pos_txt  , 脚本如下:

import json
import os
import copyfrom tool import getXyXydef read_jsonfile(  path):with open(path, "r", encoding='utf-8') as f:return json.load(f)def check_lists(a, b, c):# 确保所有列表长度一致if len(a) != len(b) or len(a) != len(c):raise ValueError("所有列表的长度必须相同")# 检查每个位置上的数字是否符合要求 (例如,数字之间相差1)for i in range(len(a)):if not (b[i] == a[i] + 1 and c[i] == a[i] + 2):raise ValueError(f"第 {i + 1} 个位置的元素不符合规律:a={a[i]}, b={b[i]}, c={c[i]}")def checkShapeData(  shape_data ):caomei_id = []for  d_i ,  d in  enumerate(shape_data  ) :if d["label"] == 'caomei':caomei_id.append( d_i  )p1_id = []for  d_i ,  d in  enumerate(shape_data  ) :if d["label"] == 'p1':p1_id.append( d_i  )p2_id = []for  d_i ,  d in  enumerate(shape_data  ) :if d["label"] == 'p2':p2_id.append( d_i  )try:check_lists(caomei_id, p1_id, p2_id)except ValueError as e:print(f"错误: {e}")return caomei_id,p1_id,p2_iddef genPosTxt(pos_all, txt_path, im_h, im_w):f_txt = open(txt_path, 'w')for e in pos_all:if len(e) == 3:x1, y1, x2, y2 = e[0]x_center = (x1 + x2) / 2 / im_wy_center = (y1 + y2) / 2 / im_hwidth = (x2 - x1) / im_wheight = (y2 - y1) / im_hif width < 0 or height < 0:print("error!")f_txt.write("%s %s %s %s %s " % (0, round(x_center, 6), round(y_center, 6), round(width, 6), round(height, 6)))for p in e[1:]:p1_x, p1_y = pp1_xx = p1_x / im_wp1_yy = p1_y / im_hf_txt.write("%s %s %s " % (round(p1_xx, 6), round(p1_yy, 6), 2))f_txt.write("\n")else:x1, y1, x2, y2 = e[0]x_center = (x1 + x2) / 2 / im_wy_center = (y1 + y2) / 2 / im_hwidth = (x2 - x1) / im_wheight = (y2 - y1) / im_hf_txt.write("%s %s %s %s %s " % (1, round(x_center, 6), round(y_center, 6), round(width, 6), round(height, 6)))for p in e[1:-1]:p1_x, p1_y = pp1_xx = p1_x / im_wp1_yy = p1_y / im_hf_txt.write("%s %s %s " % (round(p1_xx, 6), round(p1_yy, 6), 2))f_txt.write("\n")f_txt.close()def processJsonFile(  json_path  , txt_path):obj = read_jsonfile(  json_path )shape_data = obj["shapes"]im_h = obj["imageHeight"]im_w = obj["imageWidth"]caomei_id, p1_id, p2_id = checkShapeData(  shape_data )pos_all = []for  d_i ,  shape_dict in  enumerate(shape_data  ) :if d_i in caomei_id:pt_temp = shape_dict["points"]cm = getXyXy( pt_temp )temp = [cm, ]temp.append( shape_data[ d_i+1 ]["points"][0]  )temp.append(shape_data[d_i + 2]["points"][0] )pos_all.append( temp )   #这里是一条记录elif  d_i not in caomei_id and d_i not in p1_id and d_i not in p2_id:stem_pt = shape_data[ d_i  ]["points"]stem_box = getXyXy( stem_pt )temp = [stem_box, ]pt1 = [  (stem_box[0]+stem_box[2])/2-1 , stem_box[1]+1   ]  # x 中心点, 和 toppt2 = [  (stem_box[0]+stem_box[2])/2+1 , stem_box[3]-1   ]  # x 中心点, 和 bottomtemp.append(  pt1 )temp.append(  pt2  )   #这里也要保持9个temp.append( 1 )pos_all.append( temp )genPosTxt(pos_all, txt_path, im_h, im_w)if  __name__=="__main__":img_r = r"all"save_r = "pos_txt"for json_n in os.listdir(  img_r ):if "json" not in json_n:continuejson_path = os.path.join( img_r ,  json_n )save_path = os.path.join( save_r ,  json_n.replace("json" , "txt") )processJsonFile(  json_path  , save_path )

        在 pos_txt 文件夹中生成所有图片对应的yolo格式 txt 标注文件

(4)制作 images 和 labels 文件夹

        在all 同级目录,创建  images 和 labels 文件夹 ,每个文件夹里分别创建 train 文件夹和 val文件夹。运行下面脚本可以把all 中图片 和  pos_txt 中标注文件 ,数量按8比2拷贝过来,完成数据集制作

# -*- coding: utf-8 -*-
"""
Created on Tue Feb 27 17:10:01 2024@author: WIN10
"""import os
import shutil
import randompic_src = r"all"
txt_src = r"pos_txt"thr = 0.8
pic_format = "jpg"pic_dst = r"./"pic_dst_train = os.path.join(  pic_dst,  "images\\train" )
pic_dst_val = os.path.join(  pic_dst,  "images\\val" )txt_dst_train = os.path.join(  pic_dst, "labels\\train" )
txt_dst_val = os.path.join(  pic_dst, "labels\\val" )def main():txt_name_l = []# 遍历文件夹中的所有文件 , 是根据标注文件 txt 去找图片for filename in os.listdir(txt_src):if "class" in filename:continueif filename.endswith('.txt'):  # 找到以 .txt 结尾的文件txt_name_l.append(  filename )#打乱排序random.shuffle(  txt_name_l )train_n = int( thr * len( txt_name_l ))for txt_n in txt_name_l[:train_n]:pic_src_p0 = os.path.join(  pic_src , txt_n.replace("txt" ,  pic_format )  )txt_src_p0 = os.path.join( txt_src , txt_n  )pic_train = os.path.join(  pic_dst_train , txt_n.replace("txt" ,  pic_format )  )txt_train = os.path.join(  txt_dst_train , txt_n  )shutil.copyfile(pic_src_p0  ,  pic_train )shutil.copyfile(txt_src_p0, txt_train)for txt_n in txt_name_l[train_n:]:pic_src_p0 = os.path.join(  pic_src , txt_n.replace("txt" ,  pic_format )  )txt_src_p0 = os.path.join( txt_src , txt_n  )pic_val = os.path.join(  pic_dst_val , txt_n.replace("txt" ,  pic_format )  )txt_val = os.path.join(  txt_dst_val , txt_n  )shutil.copyfile(pic_src_p0  ,  pic_val )shutil.copyfile(txt_src_p0, txt_val )if __name__ == "__main__":main()

三、训练

        (1)下载好官方yolo11代码文件,修改  ultralytics/cfg/models/11/yolo11-pose.yaml 

nc 类别改2  kpt_shape  改为 [2,3]

      (2)修改 ultralytics/cfg/datasets/caomei_pose.yaml , 自己创建个

内容是:

# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license# Dogs dataset http://vision.stanford.edu/aditya86/ImageNetDogs/ by Stanford
# Documentation: https://docs.ultralytics.com/datasets/pose/dog-pose/
# Example usage: yolo train data=dog-pose.yaml
# parent
# ├── ultralytics
# └── datasets
#     └── dog-pose  ← downloads here (337 MB)# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: D:/Project/StrawberryRecog/datasets/PlasticStrawberriesPose3 # dataset root dir
train: images/train # train images (relative to 'path') 6773 images
val: images/val # val images (relative to 'path') 1703 images# Keypoints
kpt_shape: [2,3] # number of keypoints, number of dims (2 for x,y or 3 for x,y,visible)# Classes
names:0: caomei1: stem# Download script/URL (optional)
#download: https://github.com/ultralytics/assets/releases/download/v0.0.0/dog-pose.zip

(3)训练脚本

在官方代码目录下创建 train_pos.py 

import warningswarnings.filterwarnings('ignore')
from ultralytics import YOLOif __name__ == '__main__':# 注意 名字上要是yolo11s-pose.yaml 才表示使用的是s , 否则认为用的是nmodel = YOLO('ultralytics/cfg/models/11/yolo11-pose.yaml')   model.load('weights/yolo11n.pt') # loading pretrain weights ,需要对应好 n,s m l x model.train(data='ultralytics/cfg/datasets/caomei_pose.yaml',cache=False,imgsz=640,epochs=300,batch=16,device='0',name='exp',)path = model.export(format="onnx", opset=12, dynamic=False, half=False)

        最后一行是转换为 onxx 文件,便于部署到C++端或C+端等等

四、训练好的权值文件

        训练完成后,将得到pt格式权值文件,onnx格式权值文件

五、模型openvino部署

        这部分请参考下一篇《yolov8-pos/yolov11-pos openvino C++部署

http://www.lryc.cn/news/591047.html

相关文章:

  • 6、docker network
  • UE5 lumen
  • Linux搭建LAMP环境(CentOS 7 与 Ubuntu 双系统教程)
  • FastAdmin系统框架通用操作平滑迁移到新服务器的详细步骤-优雅草卓伊凡
  • lua(xlua)基础知识点记录二
  • STM32上移植Lua解析器
  • Android15系统实现刷机防呆功能
  • 【JVM】深入理解 JVM 类加载器
  • MySQL如何解决事务并发的幻读问题
  • JVM 内存分配与垃圾回收策略
  • macOS 字体管理全攻略:如何查看已安装字体及常见字体格式区
  • 网络编程7.17
  • JAVA中的Collection集合及ArrayList,LinkedLIst,HashSet,TreeSet和其它实现类的常用方法
  • MyBatis延迟加载(Lazy Loading)之“关联查询”深度解析与实践
  • 【44】MFC入门到精通——MFC 通过Button按钮添加控件变量实现:按下 按钮变色 (比如开关 打开关闭状态) MFC更改button控颜色
  • 数据结构-2(链表)
  • 基于STM32闭环步进电机控制系统设计说明
  • Leaflet地图交互:实现图形与点的同时高亮效果
  • PyTorch生成式人工智能(18)——循环神经网络详解与实现
  • 【Linux基础知识系列】第五十一篇 - Linux文件命名规范与格式
  • Mac 安装及使用sdkman指南
  • Java 大视界 -- Java 大数据在智能交通智能公交站台乘客流量预测与服务优化中的应用(349)
  • Flask+LayUI开发手记(十一):选项集合的数据库扩展类
  • Java 集合框架详解:Collection 接口全解析,从基础到实战
  • 【LeetCode 热题 100】108. 将有序数组转换为二叉搜索树
  • 【Redis 】看门狗:分布式锁的自动续期
  • 如何用Kaggle免费GPU
  • [yotroy.cool] Git 历史迁移笔记:将 Git 项目嵌入另一个仓库子目录中(保留提交记录)
  • 语雀编辑器内双击回车插入当前时间js脚本
  • 【WRFDA第六期】WRFDA 输出文件详述