当前位置: 首页 > news >正文

102005

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # 设定使用的 GPUimport tensorflow as tf
from dataset import generate_data
import numpy as np
from model import enhancednet# 检查 TensorFlow 是否可以识别 GPU
gpus = tf.config.list_physical_devices('GPU')
if gpus:try:# 限制 TensorFlow 只使用第0号 GPUtf.config.set_visible_devices(gpus[0], 'GPU')tf.config.experimental.set_memory_growth(gpus[0], True)print(f"Using GPU: {gpus[0].name}")except RuntimeError as e:print(e)
else:print("No GPU available, using CPU.")# 列出所有物理设备
print("All physical devices:")
for device in tf.config.list_physical_devices():print(device)print("\nAvailable GPU devices:")
for gpu in tf.config.list_physical_devices('GPU'):print(gpu)# 设置参数
image_rows = 128
image_cols = 256
filename = 'detached_data.mat'# 生成和准备数据
train_data, train1_data, label_data = generate_data(filename)
train_data = np.array(train_data, dtype=float).reshape(-1, image_rows, image_cols, 1)
train1_data = np.array(train1_data, dtype=float).reshape(-1, image_rows, image_cols, 1)# 确保数据类型为 float32
train_data = train_data.astype('float32')
train1_data = train1_data.astype('float32')# 打印数据形状以确认
print("Train data shape:", train_data.shape)
print("Train1 data shape:", train1_data.shape)# 创建模型
model = enhancednet()# 编译模型(假设 enhancednet 已经包含编译逻辑,可根据需要调整)
model.compile(optimizer='adam',loss='mean_squared_error',  # 示例损失函数,依据具体任务调整metrics=['accuracy'])# 训练模型,并指定设备
print("\nStarting training on GPU...")
history = model.fit(train_data, train1_data,batch_size=32,epochs=100,verbose=2,shuffle=True,validation_split=0.1)# 保存模型
model.save('enhanced_model.h5')
print("\nModel saved to 'enhanced_model.h5'")
http://www.lryc.cn/news/464769.html

相关文章:

  • Cisco ACI环境给Leaf配置OOB带外管理IP方法
  • 免费送源码:Java+B/S+MySQL springboot电影推荐系统 计算机毕业设计原创定制
  • 数据清洗(脚本)
  • jmeter中发送post请求遇到的问题
  • Java中使用protobuf
  • 2020款Macbook Pro A2251无法充电无法开机定位及修复
  • Spring Cloud --- 引入Gateway网关
  • ESP32-C3实现定时器的启停(Arduino IDE)
  • centos升级g++使其支持c++17
  • Pytest日志收集器配置
  • Morris算法(大数据作业)
  • TCP/IP协议 【三次握手】过程简要描述
  • docker 数据管理,数据持久化详解 二 数据卷容器
  • Logrotate:Linux系统日志轮转和管理的实用指南
  • 八股面试3(自用)
  • 【微服务】springboot3 集成 Flink CDC 1.17 实现mysql数据同步
  • 【Android】浅析OkHttp(1)
  • Generate-on-Graph
  • 学习笔记——交换——STP(生成树)简介
  • 【Linux从入门到精通一】操作系统概述与Linux初识
  • Git 深度解析 —— 从基础到进阶
  • PCIE-变量总结
  • 【iOS】AFNetworing初步学习
  • 【数据结构】堆的创建
  • Linux下Git操作
  • 缺失d3dx9_42.dll如何修复,d3dx9_42.dll故障的6种修复方法分享
  • 深入理解Android WebView的加载流程与事件回调
  • 机器视觉相机自动对焦算法
  • StarTowerChain:开启去中心化创新篇章
  • SpringCloudStream使用StreamBridge实现延时队列