
Table recognition with transformers and microsoft/table-transformer

一、Install the packages

pip install transformers
pip install torch
pip install SentencePiece
pip install timm
pip install accelerate
pip install pytesseract pillow pandas
pip install tesseract
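
Before moving on, a quick import check helps confirm the Python-side dependencies are in place (a minimal sketch; run it in the environment you just installed into):

# Sanity check: the main Python dependencies should import without errors.
import torch
import transformers
import timm
import pytesseract

print("transformers:", transformers.__version__)
print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())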

Download the models:

https://huggingface.co/microsoft/table-transformer-structure-recognition/tree/main

https://huggingface.co/microsoft/table-transformer-detection/tree/main
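
If you prefer to fetch the weights ahead of time rather than letting from_pretrained() download them on first use, the two repositories can be mirrored into the local Hugging Face cache (a sketch, assuming a recent huggingface_hub version):

# Pre-download both checkpoints into the local Hugging Face cache.
from huggingface_hub import snapshot_download

detection_path = snapshot_download(repo_id="microsoft/table-transformer-detection")
structure_path = snapshot_download(repo_id="microsoft/table-transformer-structure-recognition")
print(detection_path, structure_path)

The returned paths can then be passed to from_pretrained() in place of the repo ids.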

二、Install tesseract-ocr

I am using Windows here.

Download and install: tesseract-ocr-w64-setup-5.4.0.20240606.exe

https://tesseract-ocr.github.io/tessdoc/Downloads.html
https://digi.bib.uni-mannheim.de/tesseract/  【tesseract-ocr-w64-setup-5.4.0.20240606.exe】

Add the Tesseract install directory to the PATH environment variable:
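
If changing PATH is inconvenient, pytesseract can also be pointed at the executable directly (a sketch; the path below is the default Windows install location and may differ on your machine):

import pytesseract

# Tell pytesseract where tesseract.exe lives (assumed default install path; adjust as needed).
pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"

# If this prints a version number, the OCR engine is reachable.
print(pytesseract.get_tesseract_version())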

三、Prepare the images

Download: https://download.csdn.net/download/xiaoxionglove/90063200
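
The script below expects the test images in two folders, Detection_Images_Test/ for table detection and img_test/ for structure recognition. A quick listing (a sketch, assuming you unpack the download into those folder names) confirms the layout before running the full script:

import os

# Assumed layout: the downloaded images unpacked into these two folders.
for folder in ("Detection_Images_Test", "img_test"):
    print(folder, "->", len(os.listdir(folder)), "files")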

四、Write the code

The script below first runs table detection on page images, then runs structure recognition on cropped table images, and finally OCRs each cell into a pandas DataFrame.

from PIL import Image
from transformers import DetrImageProcessor
from transformers import TableTransformerForObjectDetection
import torch
import matplotlib.pyplot as plt
import os
import psutil
import time
from transformers import DetrFeatureExtractor
feature_extractor = DetrFeatureExtractor()
import pandas as pd
import pytesseract

# Table detection model: finds table regions in a page image
model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-detection")

COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

def plot_results(pil_img, scores, labels, boxes):
    # Draw every predicted box with its label name and confidence score
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for score, label, (xmin, ymin, xmax, ymax), c in zip(scores.tolist(), labels.tolist(), boxes.tolist(), colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        text = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()

def table_detection(file_path):
    # Detect table regions in a single image and return their bounding boxes
    image = Image.open(file_path).convert("RGB")
    width, height = image.size
    # note: resize() returns a new image, so this call has no effect as written
    image.resize((int(width * 0.5), int(height * 0.5)))
    feature_extractor = DetrImageProcessor()
    encoding = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)
    width, height = image.size
    results = feature_extractor.post_process_object_detection(outputs, threshold=0.7, target_sizes=[(height, width)])[0]
    plot_results(image, results['scores'], results['labels'], results['boxes'])
    return results['boxes']

ram_usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
print(f"ram usage : {ram_usage}")

count = 0
# Run table detection on a few sample images, tracking time and memory usage
root = "Detection_Images_Test/"
for file in os.listdir(root):
    file_path = os.path.join(root, file)
    start_time = time.time()
    pred_bbox = table_detection(file_path)
    count += 1
    end_time = time.time()
    time_usage = end_time - start_time
    ram_usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
    print(f"Iteration {count + 1} - RAM Usage: {ram_usage:.2f} MB, Time Usage: {time_usage:.2f} seconds")
    if count > 2:
        break

file = 'img_test/PMC1064078_table_0.jpg.png'
image = Image.open(file).convert("RGB")
image  # displays the loaded image when run in a notebook

from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import TableTransformerForObjectDetection

# Structure recognition model: finds rows, columns and headers inside a table
model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition")

def cell_detection(file_path):
    # Detect the internal structure (rows/columns) of a cropped table image
    image = Image.open(file_path).convert("RGB")
    width, height = image.size
    image.resize((int(width * 0.5), int(height * 0.5)))
    encoding = feature_extractor(image, return_tensors="pt")
    encoding.keys()
    with torch.no_grad():
        outputs = model(**encoding)
    target_sizes = [image.size[::-1]]
    results = feature_extractor.post_process_object_detection(outputs, threshold=0.6, target_sizes=target_sizes)[0]
    plot_results(image, results['scores'], results['labels'], results['boxes'])

model.config.id2label  # shows the label mapping when run in a notebook

ram_usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
print(f"ram usage : {ram_usage}")

count = 0
root = "img_test/"
# Run structure recognition on a few sample table images, tracking time and memory usage
for file in os.listdir(root):
    file_path = os.path.join(root, file)
    start_time = time.time()
    cell_detection(file_path)
    count += 1
    end_time = time.time()
    time_usage = end_time - start_time
    ram_usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
    print(f"Iteration {count + 1} - RAM Usage: {ram_usage:.2f} MB, Time Usage: {time_usage:.2f} seconds")
    if count > 2:
        break

def plot_results_specific(pil_img, scores, labels, boxes, lab):
    # Same as plot_results, but only draws boxes with the requested label id
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for score, label, (xmin, ymin, xmax, ymax), c in zip(scores.tolist(), labels.tolist(), boxes.tolist(), colors):
        if label == lab:
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                       fill=False, color=c, linewidth=3))
            text = f'{model.config.id2label[label]}: {score:0.2f}'
            ax.text(xmin, ymin, text, fontsize=15,
                    bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()

def draw_box_specific(image_path, labelnum):
    image = Image.open(image_path).convert("RGB")
    width, height = image.size
    encoding = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)
    results = feature_extractor.post_process_object_detection(outputs, threshold=0.7, target_sizes=[(height, width)])[0]
    plot_results_specific(image, results['scores'], results['labels'], results['boxes'], labelnum)

def compute_boxes(image_path):
    # Run structure recognition and return the raw boxes and label ids
    image = Image.open(image_path).convert("RGB")
    width, height = image.size
    encoding = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)
    results = feature_extractor.post_process_object_detection(outputs, threshold=0.7, target_sizes=[(height, width)])[0]
    boxes = results['boxes'].tolist()
    labels = results['labels'].tolist()
    return boxes, labels

def extract_table(image_path):
    # Intersect detected rows (label 2) and columns (label 1) to get cell boxes,
    # OCR each cell with Tesseract, and assemble the result into a DataFrame
    image = Image.open(image_path).convert("RGB")
    boxes, labels = compute_boxes(image_path)

    cell_locations = []
    for box_row, label_row in zip(boxes, labels):
        if label_row == 2:
            for box_col, label_col in zip(boxes, labels):
                if label_col == 1:
                    cell_box = (box_col[0], box_row[1], box_col[2], box_row[3])
                    cell_locations.append(cell_box)
    cell_locations.sort(key=lambda x: (x[1], x[0]))

    # Count how many cells share the first row to get the number of columns
    num_columns = 0
    box_old = cell_locations[0]
    for box in cell_locations[1:]:
        x1, y1, x2, y2 = box
        x1_old, y1_old, x2_old, y2_old = box_old
        num_columns += 1
        if y1 > y1_old:
            break
        box_old = box

    # OCR the first row as column headers
    headers = []
    for box in cell_locations[:num_columns]:
        x1, y1, x2, y2 = box
        cell_image = image.crop((x1, y1, x2, y2))
        new_width = cell_image.width * 4
        new_height = cell_image.height * 4
        cell_image = cell_image.resize((new_width, new_height), resample=Image.LANCZOS)
        cell_text = pytesseract.image_to_string(cell_image)
        headers.append(cell_text.rstrip())

    df = pd.DataFrame(columns=headers)

    # OCR the remaining cells row by row and append each completed row to the DataFrame
    row = []
    for box in cell_locations[num_columns:]:
        x1, y1, x2, y2 = box
        cell_image = image.crop((x1, y1, x2, y2))
        new_width = cell_image.width * 4
        new_height = cell_image.height * 4
        cell_image = cell_image.resize((new_width, new_height), resample=Image.LANCZOS)
        cell_text = pytesseract.image_to_string(cell_image)
        if len(cell_text) > num_columns:
            cell_text = cell_text[:num_columns]
        row.append(cell_text.rstrip())
        if len(row) == num_columns:
            df.loc[len(df)] = row
            row = []

    return df

image_path = 'img_test/PMC1112589_table_0.jpg'
draw_box_specific(image_path,1)
df = extract_table(image_path)
df.to_csv('data.csv', index=False)
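
Note that extract_table keys off the structure-recognition model's numeric label ids: in that checkpoint's id2label mapping, 1 corresponds to "table column" and 2 to "table row". Printing the mapping is a quick way to confirm this on your installed version:

# extract_table assumes label 1 = "table column" and label 2 = "table row".
print(model.config.id2label)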

This recognizes the table in the image and saves it to a CSV file.
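
To verify the export, the CSV can be read back with pandas:

import pandas as pd

# Read the exported table back and inspect the first few rows.
print(pd.read_csv('data.csv').head())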
