当前位置: 首页 > news >正文

mediapipe 训练自有图像数据分类

参考:
https://developers.google.com/mediapipe/solutions/customization/image_classifier
https://colab.research.google.com/github/googlesamples/mediapipe/blob/main/examples/customization/image_classifier.ipynb#scrollTo=plvO-YmcQn5g

安装:

pip install mediapipe-model-maker  -i http://mirrors.aliyun.com/pypi/simple --trusted-host mirrors.aliyun.com --use-pep517

版本错误情况

1)RuntimeError: File loading is not yet supported on Windows

其中mediapipe版本要大于等于0.10.0;下图中的要升级;不然后续用python 加载文件会报:

2)ImportError: cannot import name ‘array_record_module’ from ‘array_record.python’ ;参考:https://blog.csdn.net/LQ_001/article/details/130991571;原因:包依赖关系出现问题,原来版本 tensorflow-datasets==4.9.0

pip install tensorflow-datasets==4.8.3

在这里插入图片描述

在这里插入图片描述

1、训练代码

import os
import tensorflow as tf
assert tf.__version__.startswith('2')from mediapipe_model_maker import image_classifierimport matplotlib.pyplot as pltimage_path = os.path.join(os.path.dirname(r"C:\Users\loong\Downloads\mediapipe\flower_photos\flower_photos"), 'flower_photos')   ## down data  :https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz#Review datalabels = []
for i in os.listdir(image_path):if os.path.isdir(os.path.join(image_path, i)):labels.append(i)
print(labels)##plt 
NUM_EXAMPLES = 5for label in labels:label_dir = os.path.join(image_path, label)example_filenames = os.listdir(label_dir)[:NUM_EXAMPLES]fig, axs = plt.subplots(1, NUM_EXAMPLES, figsize=(10,2))for i in range(NUM_EXAMPLES):axs[i].imshow(plt.imread(os.path.join(label_dir, example_filenames[i])))axs[i].get_xaxis().set_visible(False)axs[i].get_yaxis().set_visible(False)fig.suptitle(f'Showing {NUM_EXAMPLES} examples for {label}')plt.show()

在这里插入图片描述

#Create dataset;训练集、测试集data = image_classifier.Dataset.from_folder(image_path)
train_data, remaining_data = data.split(0.8)
test_data, validation_data = remaining_data.split(0.5)## retrain model 训练模型spec = image_classifier.SupportedModels.MOBILENET_V2    ##有几个预训练模型,需要联网下载
hparams = image_classifier.HParams(export_dir="exported_model")  ##指定模型保存位置
options = image_classifier.ImageClassifierOptions(supported_model=spec, hparams=hparams)
model = image_classifier.ImageClassifier.create(train_data = train_data,validation_data = validation_data,options=options,
)## 验证模型
loss, acc = model.evaluate(test_data)
print(f'Test loss:{loss}, Test accuracy:{acc}')##保存模型
model.export_model()

在这里插入图片描述

在这里插入图片描述
默认训练是10epcos
在这里插入图片描述

查看训练tebsorboard:
注意ValueError: Duplicate plugins for name projector错误,参考https://blog.csdn.net/weixin_44966641/article/details/123292034;我这里是换了个conda环境重新安装个新的tensorflow解决

tensorboard --logdir=.

日志存放默认地址
在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

##模型压缩
from mediapipe_model_maker import quantizationquantization_config = quantization.QuantizationConfig.for_int8(train_data)
model.export_model(model_name="model_int8.tflite", quantization_config=quantization_config)

从8M缩小到3M左右
在这里插入图片描述

2、加载推理

参考:https://blog.csdn.net/weixin_42357472/article/details/131322076

import mediapipe as mpBaseOptions = mp.tasks.BaseOptions
ImageClassifier = mp.tasks.vision.ImageClassifier
ImageClassifierOptions = mp.tasks.vision.ImageClassifierOptions
VisionRunningMode = mp.tasks.vision.RunningModeoptions = ImageClassifierOptions(base_options=BaseOptions(model_asset_path=r"C:\User**ediapipe\model.tflite"),max_results=5,running_mode=VisionRunningMode.IMAGE)   ##加载模型classifier = ImageClassifier.create_from_options(options)# Load the input image from an image file.
mp_image = mp.Image.create_from_file(r"C:\Users\loong\Downloads\sun2.jpg")# Perform image classification on the provided single image.
classification_result = classifier.classify(mp_image)
classification_result

在这里插入图片描述
在这里插入图片描述

http://www.lryc.cn/news/211768.html

相关文章:

  • 【pytorch】torch.gather()函数
  • Mac 安装psycopg2,报错Error: pg_config executable not found.
  • 域名系统 DNS
  • Vue $nextTick 模板解析后在执行的函数
  • VBA技术资料MF76:将自定义颜色添加到调色板
  • zilong-20231030
  • 目标检测算法发展史
  • React 生成传递给无障碍属性的唯一 ID
  • 十种排序算法(1) - 准备测试函数和工具
  • IRF联动 BFD-MAD
  • 双向链表的初步练习
  • IDE的组成
  • 项目解读_v2
  • 杀毒软件哪个好,杀毒软件有哪些
  • Ubuntu上安装配置Nginx
  • C++之string
  • 多线程---单例模式
  • SpringBoot相比于Spring的优点(自动配置和依赖管理)
  • SAP SPAD新建打印纸张
  • C# 图解教程 第5版 —— 第11章 结构
  • 车载电子电器架构 —— 基于AP定义车载HPC
  • Redis原理-IO模型和持久化
  • PID控制示例
  • GoLand GC(垃圾回收机制)简介及调优
  • AI:40-基于深度学习的森林火灾识别
  • 37基于MATLAB平台的图像去噪,锐化,边缘检测,程序已调试通过,可直接运行。
  • 通过Metasploit+Ngrok穿透内网长期维持访问外网Android设备
  • STM32 CubeMX配置USB HID功能,及安装路径
  • 【错误解决方案】ModuleNotFoundError: No module named ‘transformers‘
  • Mac 配置环境变量