使用 Python 进行图片识别的项目开发
使用 Python 进行图片识别的项目开发流程
图片识别(Image Recognition)是人工智能与计算机视觉领域的重要应用,广泛应用于人脸识别、物体检测、图像分类、智能安防、医疗影像分析等场景。Python 以其强大的生态库和简洁语法,成为实现图片识别的首选语言之一。
本文将带你从零开始,构建一个完整的 “图片分类识别系统”,能够识别输入图片中的物体类别(如猫、狗、汽车、人等),并输出结果。我们将使用深度学习模型(预训练卷积神经网络)来实现高精度识别,同时兼顾性能与可扩展性。
一、项目目标
我们希望构建一个基于 Python 的图片识别系统,具备以下功能:
- 能够加载本地图片文件;
- 使用预训练模型(如 ResNet)进行图像分类;
- 输出识别结果(类别标签 + 置信度);
- 支持批量处理多张图片;
- 可视化识别结果(绘制边框与标签);
- 项目结构清晰,便于后续扩展为 Web 或移动端应用。
最终效果示例:
识别结果:cat(置信度:92.3%)
识别结果:dog(置信度:87.1%)
二、技术选型与核心库介绍
我们将使用以下 Python 第三方库:
库名 | 作用 |
---|---|
torch + torchvision | PyTorch 深度学习框架及其视觉模型库,提供 ResNet、AlexNet 等预训练模型 |
Pillow (PIL ) | 图像加载与基本处理 |
matplotlib / cv2 | 图像显示与可视化 |
numpy | 数值计算与张量操作 |
os , glob | 文件路径操作与批量读取 |
✅ 安装依赖
在终端中运行以下命令安装所需库:
pip install torch torchvision pillow matplotlib opencv-python numpy
⚠️ 注意:PyTorch 安装可能因系统不同而异,建议访问 https://pytorch.org 获取对应系统的安装命令。
三、项目结构设计
建议创建如下项目目录结构:
image_recognition_project/
│
├── input_images/ # 存放待识别的图片
│ ├── cat.jpg
│ ├── dog.jpg
│ └── car.jpg
│
├── output_results/ # 保存识别结果图像
│
├── models/ # 可选:保存下载的模型权重(自动缓存也可)
│
├── labels/imagenet_classes.txt # ImageNet 1000 类标签文件
│
├── image_classifier.py # 主识别程序
│
└── utils.py # 工具函数模块
四、步骤详解与代码实现
第一步:准备 ImageNet 类别标签
我们使用的预训练模型是在 ImageNet 数据集上训练的,包含 1000 个类别。我们需要一个文本文件来映射类别索引到人类可读的标签。
你可以从以下链接下载标准的 imagenet_classes.txt
文件:
👉 https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
将其保存到 labels/imagenet_classes.txt
。
第二步:编写工具函数模块 utils.py
# -*- coding: utf-8 -*-
"""
utils.py
功能:提供图像预处理、标签加载、结果可视化等通用工具函数
"""import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import osdef load_imagenet_labels(label_file="labels/imagenet_classes.txt"):"""加载 ImageNet 1000 个类别的标签返回:列表,索引对应类别编号"""if not os.path.exists(label_file):raise FileNotFoundError(f"未找到标签文件:{label_file}")with open(label_file, "r") as f:categories = [line.strip() for line in f.readlines()]return categoriesdef preprocess_image(image_path, image_size=224):"""对输入图像进行预处理,符合模型输入要求输入:图像路径输出:处理后的 tensor(batch_size=1)"""# 定义预处理变换preprocess = transforms.Compose([transforms.Resize(image_size), # 缩放transforms.CenterCrop(image_size), # 中心裁剪transforms.ToTensor(), # 转为 Tensortransforms.Normalize( # 归一化(ImageNet 均值与标准差)mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),])# 加载图像image = Image.open(image_path).convert("RGB")# 应用变换image_tensor = preprocess(image)# 增加 batch 维度 (1, C, H, W)image_tensor = image_tensor.unsqueeze(0)return image_tensordef display_image_with_prediction(image_path, predicted_label, confidence):"""显示原始图像,并在窗口标题中显示预测结果"""image = Image.open(image_path)plt.figure(figsize=(6, 6))plt.imshow(image)plt.title(f"预测: {predicted_label} (置信度: {confidence:.1f}%)", fontsize=14, color='green')plt.axis("off")plt.show()def get_top_predictions(output, categories, top_k=5):"""获取 top-k 个预测结果返回:列表 [(标签, 概率), ...]"""# 计算概率probabilities = torch.nn.functional.softmax(output[0], dim=0)top_probs, top_indices = torch.topk(probabilities, top_k)results = []for i in range(top_k):idx = top_indices[i].item()prob = top_probs[i].item() * 100 # 转为百分比label = categories[idx]results.append((label, prob))return results
📌 代码说明:
load_imagenet_labels()
:读取类别标签;preprocess_image()
:将图像缩放、裁剪、归一化,符合 ResNet 输入格式;display_image_with_prediction()
:使用matplotlib
显示图像与结果;get_top_predictions()
:返回前 k 个最可能的类别及其置信度。
第三步:主程序 image_classifier.py
# -*- coding: utf-8 -*-
"""
image_classifier.py
主程序:使用预训练 ResNet 模型进行图像识别
"""import torch
import torchvision.models as models
from PIL import Image
import os
import glob
from utils import (load_imagenet_labels,preprocess_image,display_image_with_prediction,get_top_predictions
)# 配置参数
INPUT_DIR = "input_images"
OUTPUT_DIR = "output_results"
LABEL_FILE = "labels/imagenet_classes.txt"
MODEL_NAME = "resnet50" # 可选: "resnet18", "resnet34", "mobilenet_v2" 等
TOP_K = 5
CONFIDENCE_THRESHOLD = 50.0 # 最低置信度阈值# 创建输出目录
os.makedirs(OUTPUT_DIR, exist_ok=True)def load_model(model_name="resnet50", pretrained=True):"""加载预训练图像分类模型"""print(f"📥 正在加载预训练模型 {model_name}...")model = getattr(models, model_name)(pretrained=pretrained)model.eval() # 设置为评估模式print(f"✅ 模型 {model_name} 加载成功!")return modeldef classify_image(model, image_path, categories):"""对单张图像进行分类识别"""print(f"\n🔍 正在识别图像:{image_path}")try:# 预处理图像input_tensor = preprocess_image(image_path)# 推理(无需梯度)with torch.no_grad():output = model(input_tensor)# 获取 top-k 预测结果top_results = get_top_predictions(output, categories, TOP_K)# 获取最高置信度的结果best_label, best_confidence = top_results[0]# 判断是否超过阈值if best_confidence < CONFIDENCE_THRESHOLD:print(f"⚠️ 识别置信度较低:{best_confidence:.1f}%,结果可能不可靠")else:print(f"🎉 识别成功:{best_label} (置信度: {best_confidence:.1f}%)")# 打印 top-5 结果print("📊 Top-5 预测结果:")for i, (label, prob) in enumerate(top_results, 1):mark = " ← 最佳匹配" if i == 1 else ""print(f" {i}. {label} ({prob:.1f}%) {mark}")# 显示图像与结果display_image_with_prediction(image_path, best_label, best_confidence)return best_label, best_confidenceexcept Exception as e:print(f"❌ 识别失败:{str(e)}")return "Error", 0.0def batch_classify(model, image_dir, categories):"""批量识别指定目录下的所有图片"""supported_formats = ("*.jpg", "*.jpeg", "*.png", "*.bmp")image_paths = []for ext in supported_formats:image_paths.extend(glob.glob(os.path.join(image_dir, ext)))if len(image_paths) == 0:print(f"❌ 在 {image_dir} 中未找到支持的图片文件")returnprint(f"📦 共发现 {len(image_paths)} 张图片,开始批量识别...")results = []for image_path in image_paths:label, conf = classify_image(model, image_path, categories)results.append({"filename": os.path.basename(image_path),"predicted_label": label,"confidence": conf})return resultsdef save_results_to_csv(results, output_file="recognition_results.csv"):"""将识别结果保存为 CSV 文件"""import csvwith open(output_file, 'w', encoding='utf-8', newline='') as f:writer = csv.DictWriter(f, fieldnames=["filename", "predicted_label", "confidence"])writer.writeheader()writer.writerows(results)print(f"💾 识别结果已保存至:{output_file}")def main():"""主函数"""# 1. 加载类别标签try:categories = load_imagenet_labels(LABEL_FILE)except Exception as e:print(f"❌ 标签加载失败:{e}")return# 2. 加载模型try:model = load_model(MODEL_NAME)except Exception as e:print(f"❌ 模型加载失败:{e}")return# 3. 执行批量识别results = batch_classify(model, INPUT_DIR, categories)# 4. 保存结果if results:save_results_to_csv(results)print("\n🔚 图片识别任务完成!")# 启动程序
if __name__ == "__main__":main()
📌 代码详细说明:
1. 模型加载
model = getattr(models, model_name)(pretrained=True)
动态加载 torchvision.models
中的预训练模型(如 ResNet50),无需自己训练。
2. 推理模式
model.eval()
with torch.no_grad():output = model(input_tensor)
关闭梯度计算以提升推理速度和内存效率。
3. Softmax 概率转换
probabilities = torch.nn.functional.softmax(output[0], dim=0)
将模型输出的 logits 转换为概率分布(0~1 之间)。
4. Top-k 预测
使用 torch.topk()
获取前 k 个最可能的类别。
5. 批量处理
通过 glob
模块遍历目录中所有图片,实现批量识别。
6. 结果持久化
将识别结果导出为 CSV 文件,便于后续分析。
五、完整项目运行流程
步骤 1:准备环境与文件
- 创建项目文件夹;
- 安装依赖库;
- 下载
imagenet_classes.txt
并放入labels/
目录; - 在
input_images/
中放入测试图片(如猫、狗、车、飞机等);
步骤 2:运行程序
python image_classifier.py
输出示例:
📥 正在加载预训练模型 resnet50...
✅ 模型 resnet50 加载成功!🔍 正在识别图像:input_images/cat.jpg
🎉 识别成功:Egyptian cat (置信度: 93.2%)
📊 Top-5 预测结果:1. Egyptian cat (93.2%) ← 最佳匹配2. tabby (4.1%)3. tiger cat (1.8%)...📊 图像将弹出显示...
💾 识别结果已保存至:recognition_results.csv🔚 图片识别任务完成!
同时会弹出图像窗口,显示识别结果。
六、常见问题与调试技巧
❌ 问题1:缺少 imagenet_classes.txt
错误提示:
FileNotFoundError: [Errno 2] No such file or directory: 'labels/imagenet_classes.txt'
解决方法:
手动创建 labels/
文件夹,并下载标签文件:
mkdir labels
curl -o labels/imagenet_classes.txt https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt
❌ 问题2:CUDA out of memory(GPU 内存不足)
原因: 模型较大(如 ResNet50),GPU 显存不够。
解决方案:
- 使用更轻量模型:
MODEL_NAME = "mobilenet_v2"
; - 在
load_model
中添加.cpu()
强制使用 CPU:model = load_model("resnet50").cpu()
- 减少批量大小(当前为单图识别,影响较小)。
❌ 问题3:图像格式不支持或损坏
建议处理方式:
在 preprocess_image
中增加异常捕获:
try:image = Image.open(image_path).convert("RGB")
except Exception as e:print(f"无法读取图像 {image_path}: {e}")return None
❌ 问题4:识别结果不准
可能原因:
- 图像模糊、角度偏斜;
- 物体太小或被遮挡;
- 模型未见过该类别(ImageNet 未覆盖所有现实场景)。
优化建议:
- 使用更高分辨率图像;
- 尝试其他模型(如 EfficientNet、ViT);
- 微调(Fine-tune)模型以适应特定任务。
七、代码优化与最佳实践
✅ 1. 使用类封装提升可维护性
class ImageClassifier:def __init__(self, model_name="resnet50", label_file="labels/imagenet_classes.txt"):self.model = self.load_model(model_name)self.categories = self.load_labels(label_file)def predict(self, image_path):# 封装识别逻辑pass
便于后期扩展为 API 服务或 GUI 应用。
✅ 2. 支持命令行参数
使用 argparse
支持自定义参数:
import argparseparser = argparse.ArgumentParser()
parser.add_argument("--input", default="input_images", help="输入图片目录")
parser.add_argument("--model", default="resnet50", help="模型名称")
args = parser.parse_args()
✅ 3. 日志记录替代 print
使用 logging
模块记录运行日志,便于生产环境排查问题。
import logging
logging.basicConfig(level=logging.INFO)
logging.info("模型加载完成")
八、扩展功能建议
🔹 功能1:Web 接口(Flask)
from flask import Flask, request, jsonify
app = Flask(__name__)
classifier = ImageClassifier()@app.route("/predict", methods=["POST"])
def predict():file = request.files["image"]result = classifier.predict(file.stream)return jsonify(result)
启动后可通过 POST 请求上传图片进行识别。
🔹 功能2:实时摄像头识别
结合 OpenCV 实现视频流识别:
cap = cv2.VideoCapture(0)
while True:ret, frame = cap.read()# 将 frame 转为 PIL.Image 进行识别label, conf = classify_frame(frame)
🔹 功能3:模型微调(Fine-tuning)
如果你有自己的数据集(如公司 Logo、特定产品),可以对预训练模型进行微调,提升特定任务准确率。
model.fc = torch.nn.Linear(2048, num_custom_classes) # 修改最后分类层
# 使用你的数据进行训练