基于TensorFlow 2.15的图像分类系统
下图所示的是一个图像分类系统,理论上也支持其他场景的图像分类需求。以花卉分类为例,可在界面上选择数据集、自动划分数据集,并配置训练时的迭代次数、学习率等超参数,然后即可进行训练;训练完成后可对模型进行测试并输出混淆矩阵,同时支持单张图片预测和批量预测。核心代码如下所示:
目前界面还在完善中,敬请各位看官谅解。
class TrainingWidget(QWidget):
    """Model-training panel of the image-classification GUI.

    The panel lets the user choose a model type (a small CNN or MobileNetV2
    transfer learning), set the number of epochs, the batch size, the
    learning rate and the dataset directory, and then trains a Keras model
    while updating a progress bar, a log pane and accuracy/loss curves.
    The trained model can be saved to disk afterwards.

    Training runs synchronously on the GUI thread; the progress callback
    pumps the Qt event loop so the window stays responsive and the stop
    button can be honoured.

    Attributes:
        model: The most recently built Keras model, or ``None``.
        history: The ``History`` returned by the last ``fit()``, or ``None``.
    """

    def __init__(self):
        super().__init__()
        # State is initialised before the UI is built so that any slot
        # connected in init_ui() always finds these attributes.
        self.model = None
        self.history = None
        self._stop_requested = False
        self.init_ui()

    def init_ui(self):
        """Build all child widgets and layouts of the training panel."""
        layout = QVBoxLayout()

        # --- Model configuration -------------------------------------
        model_group = QGroupBox("模型配置")
        model_layout = QVBoxLayout()

        # Model type selector
        type_layout = QHBoxLayout()
        self.model_type = QComboBox()
        self.model_type.addItems(["CNN模型", "MobileNetV2迁移学习"])
        type_layout.addWidget(QLabel("模型类型:"))
        type_layout.addWidget(self.model_type)
        type_layout.addStretch()
        model_layout.addLayout(type_layout)

        # Training hyper-parameters
        param_layout = QGridLayout()
        self.epochs_spin = QSpinBox()
        self.epochs_spin.setRange(1, 100)
        self.epochs_spin.setValue(10)

        self.batch_size_spin = QSpinBox()
        self.batch_size_spin.setRange(1, 128)
        self.batch_size_spin.setValue(4)

        self.learning_rate = QDoubleSpinBox()
        # Decimals must be set FIRST: QDoubleSpinBox rounds range and value
        # to the current precision (default 2), which would otherwise turn
        # 0.0001 / 0.001 into 0.00.
        self.learning_rate.setDecimals(4)
        self.learning_rate.setRange(0.0001, 0.1)
        self.learning_rate.setSingleStep(0.0001)
        self.learning_rate.setValue(0.001)

        param_layout.addWidget(QLabel("训练轮数:"), 0, 0)
        param_layout.addWidget(self.epochs_spin, 0, 1)
        param_layout.addWidget(QLabel("批次大小:"), 0, 2)
        param_layout.addWidget(self.batch_size_spin, 0, 3)
        param_layout.addWidget(QLabel("学习率:"), 0, 4)
        param_layout.addWidget(self.learning_rate, 0, 5)
        model_layout.addLayout(param_layout)

        # Dataset path
        data_layout = QHBoxLayout()
        self.train_data_path = QLineEdit("./data/flower_photos")
        data_layout.addWidget(QLabel("训练数据:"))
        data_layout.addWidget(self.train_data_path)
        model_layout.addLayout(data_layout)

        model_group.setLayout(model_layout)
        layout.addWidget(model_group)

        # --- Training control ----------------------------------------
        control_group = QGroupBox("训练控制")
        control_layout = QHBoxLayout()
        self.train_btn = QPushButton("开始训练")
        self.train_btn.clicked.connect(self.start_training)
        self.stop_btn = QPushButton("停止训练")
        self.stop_btn.setEnabled(False)
        # The stop button previously had no slot and therefore did nothing.
        self.stop_btn.clicked.connect(self.stop_training)
        self.save_btn = QPushButton("保存模型")
        self.save_btn.clicked.connect(self.save_model)
        control_layout.addWidget(self.train_btn)
        control_layout.addWidget(self.stop_btn)
        control_layout.addWidget(self.save_btn)
        control_group.setLayout(control_layout)
        layout.addWidget(control_group)

        # --- Training progress ---------------------------------------
        progress_group = QGroupBox("训练进度")
        progress_layout = QVBoxLayout()
        self.progress_bar = QProgressBar()
        self.status_label = QLabel("准备就绪")
        progress_layout.addWidget(self.progress_bar)
        progress_layout.addWidget(self.status_label)
        progress_group.setLayout(progress_layout)
        layout.addWidget(progress_group)

        # --- Training curves -----------------------------------------
        curve_group = QGroupBox("训练曲线")
        curve_layout = QVBoxLayout()
        self.figure = Figure(figsize=(10, 6))
        self.canvas = FigureCanvas(self.figure)
        curve_layout.addWidget(self.canvas)
        curve_group.setLayout(curve_layout)
        layout.addWidget(curve_group)

        # --- Training log --------------------------------------------
        log_group = QGroupBox("训练日志")
        log_layout = QVBoxLayout()
        self.log_text = QTextEdit()
        self.log_text.setReadOnly(True)
        self.log_text.setMaximumHeight(150)
        log_layout.addWidget(self.log_text)
        log_group.setLayout(log_layout)
        layout.addWidget(log_group)

        self.setLayout(layout)

    def _log(self, message):
        """Append ``message`` to the log pane, prefixed with HH:MM:SS."""
        self.log_text.append(f"[{datetime.now().strftime('%H:%M:%S')}] {message}")

    def _build_progress_callback(self, epochs):
        """Return a Keras callback that mirrors training progress in the UI.

        The callback pumps the Qt event loop after every batch and epoch and
        forwards a pending stop request to the model being trained.
        """
        widget = self

        class ProgressCallback(tf.keras.callbacks.Callback):
            def __init__(self, total_epochs):
                super().__init__()
                self.total_epochs = total_epochs

            def _pump_events(self):
                # Let Qt handle clicks/repaints, then honour a stop request.
                QApplication.processEvents()
                if widget._stop_requested:
                    self.model.stop_training = True

            def on_train_batch_end(self, batch, logs=None):
                self._pump_events()

            def on_epoch_end(self, epoch, logs=None):
                logs = logs or {}
                loss = logs.get('loss', float('nan'))
                accuracy = logs.get('accuracy', float('nan'))
                progress = int((epoch + 1) / self.total_epochs * 100)
                widget.progress_bar.setValue(progress)
                widget.status_label.setText(
                    f"Epoch {epoch + 1}/{self.total_epochs} - Loss: {loss:.4f} - Accuracy: {accuracy:.4f}"
                )
                widget._log(
                    f"Epoch {epoch + 1}/{self.total_epochs} - Loss: {loss:.4f} - Acc: {accuracy:.4f}"
                )
                self._pump_events()

        return ProgressCallback(epochs)

    def start_training(self):
        """Load the dataset, build the selected model and train it.

        Reads all hyper-parameters from the widgets. Errors are reported in
        the log pane and a message box; the start/stop buttons are always
        restored when training ends, fails or is stopped.
        """
        self._log("开始训练...")
        self._stop_requested = False
        self.train_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)
        self.progress_bar.setValue(0)

        data_dir = self.train_data_path.text()
        if not os.path.exists(data_dir):
            QMessageBox.warning(self, "错误", "训练数据路径不存在!")
            self.train_btn.setEnabled(True)
            self.stop_btn.setEnabled(False)
            return

        batch_size = self.batch_size_spin.value()
        epochs = self.epochs_spin.value()
        learning_rate = self.learning_rate.value()
        is_transfer = self.model_type.currentIndex() == 1

        try:
            train_ds, val_ds, class_names = self.data_load(data_dir, 224, 224, batch_size)

            # Size the classifier head from the dataset rather than assuming
            # five classes, and apply the learning rate chosen in the UI.
            self.model = self.model_load(
                is_transfer=is_transfer,
                num_classes=len(class_names),
                learning_rate=learning_rate,
            )

            self._log(f"模型类型: {'MobileNetV2迁移学习' if is_transfer else 'CNN模型'}")
            self._log(f"训练轮数: {epochs}, 批次大小: {batch_size}")
            self._log(f"学习率: {learning_rate:.4f}, 类别数: {len(class_names)}")

            callback = self._build_progress_callback(epochs)
            self.history = self.model.fit(
                train_ds,
                validation_data=val_ds,
                epochs=epochs,
                callbacks=[callback],
            )

            self.show_training_curves()
            if self._stop_requested:
                self._log("训练已停止")
                self.status_label.setText("训练已停止")
            else:
                self._log("训练完成!")
                self.status_label.setText("训练完成")
        except Exception as e:
            # Top-level GUI boundary: surface any failure to the user
            # instead of crashing the application.
            self._log(f"训练出错: {str(e)}")
            QMessageBox.critical(self, "错误", f"训练失败: {str(e)}")
        finally:
            self.train_btn.setEnabled(True)
            self.stop_btn.setEnabled(False)

    def stop_training(self):
        """Request that the running training loop stop as soon as possible.

        The request is recorded in a flag that the progress callback checks
        after each batch, because Keras resets ``model.stop_training`` when
        ``fit()`` starts.
        """
        self._stop_requested = True
        if self.model is not None:
            self.model.stop_training = True
        self.stop_btn.setEnabled(False)
        self._log("正在停止训练...")

    def data_load(self, data_dir, img_height, img_width, batch_size):
        """Load training and validation datasets from a class-per-folder tree.

        Args:
            data_dir: Directory passed to ``image_dataset_from_directory``.
            img_height: Target image height in pixels.
            img_width: Target image width in pixels.
            batch_size: Number of images per batch.

        Returns:
            ``(train_ds, val_ds, class_names)`` using an 80/20 split with a
            fixed seed and one-hot (``categorical``) labels.
        """
        # Both subsets must share the same seed and split so they do not overlap.
        common = dict(
            label_mode='categorical',
            validation_split=0.2,
            seed=123,
            image_size=(img_height, img_width),
            batch_size=batch_size,
        )
        train_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir, subset="training", **common)
        val_ds = tf.keras.utils.image_dataset_from_directory(
            data_dir, subset="validation", **common)
        class_names = train_ds.class_names
        return train_ds, val_ds, class_names

    def model_load(self, IMG_SHAPE=(224, 224, 3), is_transfer=False,
                   num_classes=5, learning_rate=0.001):
        """Build and compile the classifier.

        Args:
            IMG_SHAPE: Input shape ``(height, width, channels)``.
            is_transfer: If true, use a frozen ImageNet MobileNetV2 backbone;
                otherwise build a small two-convolution CNN.
            num_classes: Number of output classes (softmax units).
            learning_rate: Learning rate for the Adam optimizer.

        Returns:
            A compiled ``tf.keras`` Sequential model using categorical
            cross-entropy loss and the ``accuracy`` metric.
        """
        if is_transfer:
            base_model = tf.keras.applications.MobileNetV2(
                input_shape=IMG_SHAPE,
                include_top=False,
                weights='imagenet',
            )
            base_model.trainable = False  # train only the new head
            model = tf.keras.models.Sequential([
                # Maps pixel values from [0, 255] to [-1, 1].
                tf.keras.layers.Rescaling(1. / 127.5, offset=-1, input_shape=IMG_SHAPE),
                base_model,
                tf.keras.layers.GlobalAveragePooling2D(),
                tf.keras.layers.Dense(num_classes, activation='softmax'),
            ])
        else:
            model = tf.keras.models.Sequential([
                # Maps pixel values from [0, 255] to [0, 1].
                tf.keras.layers.Rescaling(1. / 255, input_shape=IMG_SHAPE),
                tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
                tf.keras.layers.MaxPooling2D(2, 2),
                tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
                tf.keras.layers.MaxPooling2D(2, 2),
                tf.keras.layers.Flatten(),
                tf.keras.layers.Dense(128, activation='relu'),
                tf.keras.layers.Dense(num_classes, activation='softmax'),
            ])

        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
            loss='categorical_crossentropy',
            metrics=['accuracy'],
        )
        return model

    def show_training_curves(self):
        """Plot accuracy and loss curves of the last training run.

        Does nothing when no history is available or when no epoch finished
        (for example, training was stopped during the first epoch).
        """
        if self.history is None:
            return
        metrics = self.history.history
        if not metrics.get('accuracy') or not metrics.get('loss'):
            return

        self.figure.clear()

        # NOTE(review): the Chinese labels need a CJK-capable matplotlib
        # font to render; confirm the font configuration elsewhere.
        # Accuracy curves
        ax1 = self.figure.add_subplot(1, 2, 1)
        ax1.plot(metrics['accuracy'], label='训练准确率')
        if 'val_accuracy' in metrics:
            ax1.plot(metrics['val_accuracy'], label='验证准确率')
        ax1.set_title('模型准确率')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('准确率')
        ax1.legend()
        ax1.grid(True)

        # Loss curves
        ax2 = self.figure.add_subplot(1, 2, 2)
        ax2.plot(metrics['loss'], label='训练损失')
        if 'val_loss' in metrics:
            ax2.plot(metrics['val_loss'], label='验证损失')
        ax2.set_title('模型损失')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('损失')
        ax2.legend()
        ax2.grid(True)

        self.figure.tight_layout()
        self.canvas.draw()

    def save_model(self):
        """Ask for a target path and save the current model as an H5 file."""
        if self.model is None:
            QMessageBox.warning(self, "警告", "没有可保存的模型!")
            return

        file_path, _ = QFileDialog.getSaveFileName(
            self, "保存模型", "./models", "H5 Files (*.h5)")
        if not file_path:
            return  # dialog cancelled

        # Some platforms return the name without the filter's extension;
        # without a suffix Keras would write a SavedModel directory instead
        # of the H5 file the dialog promises.
        if not os.path.splitext(file_path)[1]:
            file_path += ".h5"

        try:
            self.model.save(file_path)
            self._log(f"模型已保存到: {file_path}")
            QMessageBox.information(self, "成功", "模型保存成功!")
        except Exception as e:
            QMessageBox.critical(self, "错误", f"模型保存失败: {str(e)}")
测试时输出的log:
测试结果:
损失值: 0.3617
准确率: 0.8965 (89.65%)
测试样本数: 734
各类别准确率:
daisy: 93.80% (正确: 121/129)
dandelion: 96.02% (正确: 169/176)
roses: 74.17% (正确: 89/120)
sunflowers: 88.82% (正确: 135/152)
tulips: 91.72% (正确: 144/157)
模型架构摘要:
总参数量: 2264389
层数: 4