当前位置: 首页 > news >正文

第T8周:猫狗识别

  • >- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客**
    >- **🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)**

🍺 要求:

  1. 了解model.train_on_batch()并运用
  2. 了解tqdm,并使用tqdm实现可视化进度条

🏡 我的环境:

  • 语言环境:Python3.6.5
  • 编译器:Jupyter Notebook
  • 深度学习环境:TensorFlow2.4.1

1. 设置GPU

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]],"GPU")# 打印显卡信息,确认GPU可用
print(gpus)

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

2. 导入数据

3. 加载数据

使用image_dataset_from_directory方法将磁盘中的数据加载到tf.data.Dataset中 

Found 3400 files belonging to 2 classes.
Using 2720 files for training.

我们可以通过class_names输出数据集的标签。标签将按字母顺序对应于目录名称。

['cat', 'dog'] 

4.配置数据集

  • shuffle() : 打乱数据,关于此函数的详细介绍可以参考:https://zhuanlan.zhihu.com/p/42417456
  • prefetch() :预取数据,加速运行,其详细介绍可以参考我前两篇文章,里面都有讲解。
  • cache() :将数据集缓存到内存当中,加速运行

5.可视化数据

6.构建VG-16网络 

7.编译

model.compile(optimizer="adam",
              loss     ='sparse_categorical_crossentropy',
              metrics  =['accuracy'])

8.训练模型

from tqdm import tqdm
import tensorflow.keras.backend as Kepochs = 10
lr     = 1e-4# 记录训练数据,方便后面的分析
history_train_loss     = []
history_train_accuracy = []
history_val_loss       = []
history_val_accuracy   = []for epoch in range(epochs):train_total = len(train_ds)val_total   = len(val_ds)"""total:预期的迭代数目ncols:控制进度条宽度mininterval:进度更新最小间隔,以秒为单位(默认值:0.1)"""with tqdm(total=train_total, desc=f'Epoch {epoch + 1}/{epochs}',mininterval=1,ncols=100) as pbar:lr = lr*0.92K.set_value(model.optimizer.lr, lr)for image,label in train_ds:   """训练模型,简单理解train_on_batch就是:它是比model.fit()更高级的一个用法想详细了解 train_on_batch 的同学,可以看看我的这篇文章:https://www.yuque.com/mingtian-fkmxf/hv4lcq/ztt4gy"""history = model.train_on_batch(image,label)train_loss     = history[0]train_accuracy = history[1]pbar.set_postfix({"loss": "%.4f"%train_loss,"accuracy":"%.4f"%train_accuracy,"lr": K.get_value(model.optimizer.lr)})pbar.update(1)history_train_loss.append(train_loss)history_train_accuracy.append(train_accuracy)print('开始验证!')with tqdm(total=val_total, desc=f'Epoch {epoch + 1}/{epochs}',mininterval=0.3,ncols=100) as pbar:for image,label in val_ds:      history = model.test_on_batch(image,label)val_loss     = history[0]val_accuracy = history[1]pbar.set_postfix({"loss": "%.4f"%val_loss,"accuracy":"%.4f"%val_accuracy})pbar.update(1)history_val_loss.append(val_loss)history_val_accuracy.append(val_accuracy)print('结束验证!')print("验证loss为:%.4f"%val_loss)print("验证准确率为:%.4f"%val_accuracy)

 9.总结

tqdm是一个快速、可扩展的Python进度条库,它提供了一种简单而直观的方式来跟踪代码的执行进度。tqdm的主要功能是在长时间运行的循环中添加一个进度提示信息。用户只需将任意的迭代器封装为tqdm(iterator),即可实现进度可视化,这非常适合在数据处理、机器学习训练等需要长时间运行的任务中使用。

http://www.lryc.cn/news/470830.html

相关文章:

  • 第十七周:机器学习
  • 算法4之链表
  • 掌握未来技术:KVM虚拟化安装全攻略,开启高效云端之旅
  • 挖矿病毒的处理
  • JVM(HotSpot):GC之G1垃圾回收器
  • appium文本输入的多种形式
  • springboot095学生宿舍信息的系统--论文pf(论文+源码)_kaic
  • 使用SQL在PostGIS中创建各种空间数据
  • ArkTS 如何适配手机和平板,展示不同的 Tabs 页签
  • Docker下载途径
  • Windows: 如何实现CLIPTokenizer.from_pretrained`本地加载`stable-diffusion-2-1-base`
  • MySQL 9从入门到性能优化-慢查询日志
  • ARM学习(33)英飞凌(infineon)PSOC 6 板子学习
  • 华为原生鸿蒙操作系统的发布有何重大意义和影响:
  • API 接口:连接生活与商业的数字桥梁
  • IEC101 JAVA开发记录
  • 降压恒压150V供电 负载固定5V 持续0.6A电动车仪表供电芯片SL3150H
  • QT 从ttf文件中读取图标
  • JS动态调用变量
  • django restful API
  • 在xml 中 不等式 做转义处理的问题
  • python——文件存储与写入path
  • AI 提示词(Prompt)入门 :ChatGPT 4.0 高级功能指南
  • C++:模板
  • 假如浙江与福建合并为“浙福省”
  • AI图片生成3D物体和2D视频提取3D动画
  • Android 应用包名的定义 pm list packages查询的包名
  • 递归相关练习
  • 租房市场新动力:基于Spring Boot的管理系统
  • 基于Python的B站视频数据分析与可视化