import os

# Restrict the process to GPU 0. This is set before TensorFlow (and any local
# module that may import it) is imported, so the setting is already in place
# when CUDA initializes.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np  # noqa: E402  (must follow the env-var assignment above)
import tensorflow as tf  # noqa: E402

from dataset import generate_data  # noqa: E402
from model import enhancednet  # noqa: E402
# Check whether TensorFlow can see a GPU.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # Make only the first visible GPU available to TensorFlow and let its
        # memory allocation grow on demand instead of reserving it all upfront.
        # (CUDA_VISIBLE_DEVICES="0" above already limits CUDA to one device,
        # so gpus[0] is that device.)
        tf.config.set_visible_devices(gpus[0], 'GPU')
        tf.config.experimental.set_memory_growth(gpus[0], True)
        print(f"Using GPU: {gpus[0].name}")
    except RuntimeError as e:
        # These settings must be applied before the GPU is initialized;
        # report the error and continue with the default configuration.
        print(e)
else:
    print("No GPU available, using CPU.")

# List all physical devices.
print("All physical devices:")
for device in tf.config.list_physical_devices():
    print(device)

print("\nAvailable GPU devices:")
for gpu in tf.config.list_physical_devices('GPU'):
    print(gpu)
# Set parameters: every sample is reshaped to (image_rows, image_cols, 1).
image_rows = 128
image_cols = 256
filename = 'detached_data.mat'

# Generate and prepare the data. Only train_data (model inputs) and
# train1_data (training targets) are used below; label_data is currently unused.
train_data, train1_data, label_data = generate_data(filename)

# Convert directly to float32. The previous float64 -> float32 round-trip
# built an extra full-size float64 copy of each array for no benefit.
train_data = np.asarray(train_data, dtype=np.float32).reshape(
    -1, image_rows, image_cols, 1
)
train1_data = np.asarray(train1_data, dtype=np.float32).reshape(
    -1, image_rows, image_cols, 1
)

# Print the shapes to confirm.
print("Train data shape:", train_data.shape)
print("Train1 data shape:", train1_data.shape)
# Create the model (defined in the local `model` module).
model = enhancednet()

# Compile the model. If enhancednet() already compiles it, this call replaces
# that configuration; adjust as needed.
# NOTE(review): the loss is a placeholder per the original author — choose one
# that fits the task. 'accuracy' is unlikely to be meaningful for an MSE
# image-to-image regression; consider 'mae' instead.
model.compile(
    optimizer='adam',
    loss='mean_squared_error',
    metrics=['accuracy'],
)

# Train: train_data -> train1_data, holding out 10% of the samples
# (validation_split=0.1) for validation.
device_name = "GPU" if gpus else "CPU"
print(f"\nStarting training on {device_name}...")
history = model.fit(
    train_data,
    train1_data,
    batch_size=32,
    epochs=100,
    verbose=2,
    shuffle=True,
    validation_split=0.1,
)

# Save the model; the .h5 extension selects Keras's HDF5 format.
model.save('enhanced_model.h5')
print("\nModel saved to 'enhanced_model.h5'")