当前位置: 首页 > news >正文

u2net网络模型训练自己数据集

单分类
下载项目源码
项目源码

准备数据集
将json转为mask
json_to_dataset.py

import cv2
import json
import numpy as np
import os
import sys
import globdef func(file):with open(file, mode='r', encoding="utf-8") as f:configs = json.load(f)shapes = configs["shapes"]png_class = np.zeros((configs["imageHeight"], configs["imageWidth"], 1), np.uint8)png_other = np.zeros((configs["imageHeight"], configs["imageWidth"], 1), np.uint8)for shape in shapes:label = shape['label']if label == 'class':cv2.fillPoly(png_class, [np.array(shape["points"], np.int32)], (255))else:cv2.fillPoly(png_other, [np.array(shape["points"], np.int32)], (255))png = png_class - png_otherreturn pngif __name__ == "__main__":json_dir = "image"save_dir = 'image/masks'for file in os.listdir(json_dir):print('***************', file)if file.endswith(".json"):# 在这里添加后续步骤png = func(os.path.join(json_dir, file))print(png.shape)save_path = save_dir + '/' + os.path.splitext(file)[0] + ".png"cv2.imwrite(save_path, png)print('***************', save_path)

创建文件路径

train_data|__DUTS|__DUTS-TR|__im_aug			# mask图|__gt_aug			# 原图|__DUTS-TE|__im_aug			# mask图|__gt_aug			# 原图

训练
修改u2net_train.py

# ------- 2. set the directory of training dataset --------model_name = 'u2net' #'u2netp'											# 选择模型# 修改为对应的数据集路径
data_dir = os.path.join(os.getcwd(), 'train_data' + os.sep)
tra_image_dir = os.path.join('DUTS', 'DUTS-TR', 'im_aug' + os.sep)
tra_label_dir = os.path.join('DUTS', 'DUTS-TR', 'gt_aug' + os.sep)image_ext = '.jpg'
label_ext = '.png'model_dir = os.path.join(os.getcwd(), 'saved_models', model_name + os.sep)

alpha_export.py为转换模型,u2net_test.py模型预测

多分类
准备数据集
将json转为mask,由于是多分类,每一个分类一个rgb值,例如分类1(1,1,1),分类2(2,2,2),分类3(3,3,3);

import cv2
import json
import numpy as np
import os
import sys
import globdef func(file):with open(file, mode='r', encoding="utf-8") as f:configs = json.load(f)shapes = configs["shapes"]class1 = np.zeros((configs["imageHeight"], configs["imageWidth"], 3), np.uint8)class2 = np.zeros((configs["imageHeight"], configs["imageWidth"], 3), np.uint8)class3 = np.zeros((configs["imageHeight"], configs["imageWidth"], 3), np.uint8)for shape in shapes:label = shape['label']if label == 'class1':cv2.fillPoly(class1, [np.array(shape["points"], np.int32)], (1, 1, 1))elif label == 'class2':cv2.fillPoly(class2, [np.array(shape["points"], np.int32)], (2, 2, 2))elif label == 'class3':cv2.fillPoly(class3, [np.array(shape["points"], np.int32)], (3, 3, 3))return class1 + class2 + class3if __name__ == "__main__":json_dir = "./train_data/labels_json"save_dir = './train_data/masks'for file in os.listdir(json_dir):print('***************', file)if file.endswith(".json"):# 在这里添加后续步骤png = func(os.path.join(json_dir, file))print(png.shape)save_path = save_dir + '/' + os.path.splitext(file)[0] + ".png"cv2.imwrite(save_path, png)print('***************', save_path)

数据集文件夹

datasets|__train_data|__images		# 原图|__1.jpg|__2.jpg|__3.jpg|__masks		# mask图|__1.png|__2.png|__3.png|__test_data|__images|__1.jpg|__2.jpg|__3.jpg|__masks|__1.png|__2.png|__3.png|__my_results_2|__predeict_lables|__predict_masks

训练

# ------- 2. set the directory of training process --------
model_name = 'u2net'  # 'u2net'							# 选择模型
model_dir = os.path.join('saved_models', model_name + os.sep)# 图片的文件类型
image_ext = '.jpg'
label_ext = '.png'# train阶段
# 数据集路径
data_dir = os.path.join("datasets/train_data" + os.sep)
# 原始图片路径
tra_image_dir = os.path.join('images' + os.sep)
# 图片的标签路径
tra_label_dir = os.path.join('masks' + os.sep)# val阶段
# 数据集类别
num_classes = 4			# 背景+类别数
name_classes = ["background", "class1", "class2","class3"]
# 原始图片路径
images_path = "datasets/test_data/images/"
# 图片的标签路径
gt_dir = "datasets/test_data/masks/"
# 存放推理结果图片的路径
pred_dir = "datasets/test_data/predict_masks/"
predict_label = "datasets/test_data/predict_labels/"
# 存放 miou 计算结果的 图片
miou_out_path = "miou_out_tab"
# 预训练权重
seg_pretrain_u2netp_path = 'saved_models/pretrain_model/segm_u2net.pth'
if not os.path.exists(model_dir):os.makedirs(model_dir)epoch_num = 100
batch_size = 2
train_num = 0
val_num = 0

执行u2net_train.py,开始训练

预测

if __name__ == "__main__":#   miou_mode用于指定该文件运行时计算的内容#   miou_mode为0代表整个miou计算流程,包括获得预测结果、计算miou。#   miou_mode为1代表仅仅获得预测结果。#   miou_mode为2代表仅仅计算miou。#   分类个数+1、如2+1# num_classes = 3# name_classes = ["background", "green", "red"]num_classes = 4name_classes =["background", "class1", "class2","class3"]# 原始图片路径images_path = "datasets_ButtonCell/test_data/images/"# 图片的标签路径gt_dir = "datasets_ButtonCell/test_data/masks/"# 存放推理结果图片的路径pred_dir = "datasets_ButtonCell/test_data/predict_masks/"predict_label = "datasets_ButtonCell/test_data/predict_labels/"# 存放 miou 计算结果的 图片miou_out_path = "miou_out"# 模型路径model_dir = './saved_models/u2netp/u2netp.pth'eval_print_miou(num_classes, name_classes, images_path, gt_dir, pred_dir, predict_label, miou_out_path, model_dir)

执行u2net_val.py,开始训练

torch2onnx.py进行模型转换

http://www.lryc.cn/news/452105.html

相关文章:

  • 登录功能开发 P167重点
  • 数据架构图:从数据源到数据消费的全面展示
  • useEffect 与 useLayoutEffect 的区别
  • OPENCV判断图像中目标物位置及多目标物聚类
  • 分布式理论:拜占庭将军问题
  • 从零开始Ubuntu24.04上Docker构建自动化部署(三)Docker安装Nginx
  • 阿里云 SAE Web:百毫秒高弹性的实时事件中心的架构和挑战
  • 人口普查管理系统基于VUE+SpringBoot+Spring+SpringMVC+MyBatis开发设计与实现
  • 使用VBA快速将文本转换为Word表格
  • 力扣题解1870
  • D3.js数据可视化基础——基于Notepad++、IDEA前端开发
  • 在Robot Framework中Run Keyword If的用法
  • 虚拟机ip突然看不了了
  • LeetCode[中等] 763. 划分字母区间
  • Java LeetCode每日一题
  • 数据结构--集合框架
  • Win10鼠标总是频繁自动失去焦点-非常有效-重启之后立竿见影
  • 智能涌现|迎接智能时代,算力产业重构未来
  • 关于HTML 案例_个人简历展示01
  • 【前端开发入门】css快速入门
  • java中创建不可变集合
  • D25【 python 接口自动化学习】- python 基础之判断与循环
  • HTTP1.0和HTTP1.1有什么区别
  • 卡夫卡的理解
  • 基础算法之滑动窗口--Java实现(上)--LeetCode题解:长度最小的子数组-无重复字符的子串-最大连续1的个数III-将x减到0的最小操作数
  • Linux -- 文件系统(文件在磁盘中的存储)
  • 微服务(Microservices),服务网格(Service Mesh)以及无服务器运算Serverless简单介绍
  • 【AIGC】AI时代的数据安全:使用ChatGPT时的自查要点
  • 什么是区块链桥?
  • 机器学习框架