transformers microsoft--table-transformer 表格识别
一、安装包
pip install transformers
pip install torch
pip install SentencePiece
pip install timm
pip install accelerate
pip install pytesseract pillow pandas
pip install tesseract
下载模型:
https://huggingface.co/microsoft/table-transformer-structure-recognition/tree/main
https://huggingface.co/microsoft/table-transformer-detection/tree/main
二、安装tesseract-ocr
我这里用的windows
下载:tesseract-ocr-w64-setup-5.4.0.20240606.exe 安装
https://tesseract-ocr.github.io/tessdoc/Downloads.html
https://digi.bib.uni-mannheim.de/tesseract/ 【tesseract-ocr-w64-setup-5.4.0.20240606.exe】
添加环境变量:
三、准备图片
下载:https://download.csdn.net/download/xiaoxionglove/90063200
四、编写代码
import os
import time

import matplotlib.pyplot as plt
import psutil
import torch
from PIL import Image
from transformers import DetrFeatureExtractor  # deprecated alias of DetrImageProcessor; import kept so existing references still resolve
from transformers import DetrImageProcessor
from transformers import TableTransformerForObjectDetection

# Shared image processor used by the inference helpers further down the script.
# DetrImageProcessor is the supported replacement for the deprecated
# DetrFeatureExtractor and exposes the same call / post-processing methods.
feature_extractor = DetrImageProcessor()
import pandas as pd
import pytesseract

# Table *detection* checkpoint: locates whole tables inside a page image.
model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-detection")

# Palette cycled through when drawing bounding boxes (RGB triples in 0..1).
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]


def plot_results(pil_img, scores, labels, boxes):
    """Show ``pil_img`` with one labelled rectangle per detection.

    Args:
        pil_img: Image passed to ``plt.imshow``.
        scores: Tensor of confidence scores (converted with ``.tolist()``).
        labels: Tensor of label ids, looked up in the global
            ``model.config.id2label``.
        boxes: Tensor of ``(xmin, ymin, xmax, ymax)`` boxes.
    """
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    detections = zip(scores.tolist(), labels.tolist(), boxes.tolist())
    for index, (score, label, (xmin, ymin, xmax, ymax)) in enumerate(detections):
        # Cycle through the palette so any number of boxes can be drawn.
        color = COLORS[index % len(COLORS)]
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=color, linewidth=3))
        text = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()


def table_detection(file_path):
    """Run the global detection ``model`` on one image file and plot the result.

    The image is downscaled to half size before inference; the returned boxes
    are expressed in the coordinates of that downscaled image.

    Args:
        file_path: Path of the image to analyse.

    Returns:
        Tensor of ``(xmin, ymin, xmax, ymax)`` boxes with score above 0.7.
    """
    image = Image.open(file_path).convert("RGB")
    width, height = image.size
    # Image.resize returns a new image; the result must be re-assigned,
    # otherwise the downscale is silently lost.
    image = image.resize((max(1, int(width * 0.5)), max(1, int(height * 0.5))))

    processor = DetrImageProcessor()
    encoding = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)

    # post_process_object_detection expects (height, width); PIL gives (width, height).
    target_sizes = [image.size[::-1]]
    results = processor.post_process_object_detection(
        outputs, threshold=0.7, target_sizes=target_sizes)[0]
    plot_results(image, results['scores'], results['labels'], results['boxes'])
    return results['boxes']


# Resident memory of this process (in MB) after the detection model is loaded.
ram_usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
print(f"ram usage : {ram_usage}")
count = 0
# Benchmark table detection on the first three files of the test folder.
root = "Detection_Images_Test/"
for file in os.listdir(root):
    file_path = os.path.join(root, file)
    start_time = time.time()
    pred_bbox = table_detection(file_path)
    count += 1
    end_time = time.time()
    time_usage = end_time - start_time
    ram_usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
    # count was already incremented above, so it is the 1-based iteration number.
    print(f"Iteration {count} - RAM Usage: {ram_usage:.2f} MB, Time Usage: {time_usage:.2f} seconds")
    if count > 2:
        break

# Sample image used by the next step of the walkthrough.
file = 'img_test/PMC1064078_table_0.jpg.png'
image = Image.open(file).convert("RGB")
# Notebook leftover: a bare name renders the image in Jupyter and is a no-op in a script.
image
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import TableTransformerForObjectDetection

# Table *structure recognition* checkpoint. This rebinds the global ``model``,
# so every helper that reads ``model`` from here on uses this checkpoint.
model = TableTransformerForObjectDetection.from_pretrained("microsoft/table-transformer-structure-recognition")


def cell_detection(file_path):
    """Run the global structure ``model`` on one table image and plot the result.

    The image is downscaled to half size before inference and the detections
    (score above 0.6) are drawn on that downscaled image via ``plot_results``.

    Args:
        file_path: Path of the cropped table image to analyse.
    """
    image = Image.open(file_path).convert("RGB")
    width, height = image.size
    # Image.resize returns a new image; the result must be re-assigned,
    # otherwise the downscale is silently lost.
    image = image.resize((max(1, int(width * 0.5)), max(1, int(height * 0.5))))

    encoding = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)

    # post_process_object_detection expects (height, width); PIL gives (width, height).
    target_sizes = [image.size[::-1]]
    results = feature_extractor.post_process_object_detection(
        outputs, threshold=0.6, target_sizes=target_sizes)[0]
    plot_results(image, results['scores'], results['labels'], results['boxes'])


# Notebook leftover: shows the id -> label mapping in Jupyter, no-op in a script.
model.config.id2label

# Resident memory of this process (in MB) after the structure model is loaded.
ram_usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
print(f"ram usage : {ram_usage}")
count = 0
# Benchmark structure recognition on the first three files of the test folder.
root = "img_test/"
for file in os.listdir(root):
    file_path = os.path.join(root, file)
    start_time = time.time()
    cell_detection(file_path)
    count += 1
    end_time = time.time()
    time_usage = end_time - start_time
    ram_usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024
    # count was already incremented above, so it is the 1-based iteration number.
    print(f"Iteration {count} - RAM Usage: {ram_usage:.2f} MB, Time Usage: {time_usage:.2f} seconds")
    if count > 2:
        break

# Label ids used to build the cell grid.
# NOTE(review): 1 = column and 2 = row is inferred from how extract_table pairs
# the boxes; confirm against model.config.id2label of the loaded checkpoint.
_COLUMN_LABEL = 1
_ROW_LABEL = 2


def plot_results_specific(pil_img, scores, labels, boxes, lab):
    """Show ``pil_img`` with rectangles only for detections whose label is ``lab``.

    Args:
        pil_img: Image passed to ``plt.imshow``.
        scores: Tensor of confidence scores.
        labels: Tensor of label ids, looked up in the global
            ``model.config.id2label``.
        boxes: Tensor of ``(xmin, ymin, xmax, ymax)`` boxes.
        lab: Label id to keep; every other detection is skipped.
    """
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    detections = zip(scores.tolist(), labels.tolist(), boxes.tolist())
    for index, (score, label, (xmin, ymin, xmax, ymax)) in enumerate(detections):
        if label != lab:
            continue
        # Colour is chosen by detection index (skipped ones included), cycling the palette.
        color = COLORS[index % len(COLORS)]
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=color, linewidth=3))
        text = f'{model.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()


def _detect_structure(image, threshold=0.7):
    """Run the global ``model`` on ``image`` and return the post-processed detections."""
    encoding = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)
    # post_process_object_detection expects (height, width); PIL gives (width, height).
    return feature_extractor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=[image.size[::-1]])[0]


def draw_box_specific(image_path, labelnum):
    """Detect structure in the image at ``image_path`` and plot only label ``labelnum``."""
    image = Image.open(image_path).convert("RGB")
    results = _detect_structure(image)
    plot_results_specific(image, results['scores'], results['labels'], results['boxes'], labelnum)


def compute_boxes(image_path):
    """Detect structure in the image at ``image_path``.

    Returns:
        ``(boxes, labels)`` as plain Python lists: boxes are
        ``[xmin, ymin, xmax, ymax]`` and labels are integer label ids,
        for detections with score above 0.7.
    """
    image = Image.open(image_path).convert("RGB")
    results = _detect_structure(image)
    return results['boxes'].tolist(), results['labels'].tolist()


def _ocr_cell(image, box, scale=4):
    """Crop ``box`` from ``image``, enlarge it ``scale`` times and return the OCR text."""
    cell_image = image.crop(box)
    if cell_image.width == 0 or cell_image.height == 0:
        # Degenerate box: nothing to read, and resizing to a zero size would fail.
        return ""
    enlarged = cell_image.resize(
        (cell_image.width * scale, cell_image.height * scale),
        resample=Image.LANCZOS)
    return pytesseract.image_to_string(enlarged).rstrip()


def extract_table(image_path):
    """OCR the table in the image at ``image_path`` into a DataFrame.

    Every detected row box is intersected with every detected column box to
    form a cell. Cells are read top-to-bottom, left-to-right; the first row of
    cells supplies the column headers and the remaining rows become data rows.

    Args:
        image_path: Path of the cropped table image.

    Returns:
        A ``pandas.DataFrame`` with one column per detected table column.
        An empty DataFrame is returned when no row/column pair was detected.
    """
    image = Image.open(image_path).convert("RGB")
    boxes, labels = compute_boxes(image_path)

    row_boxes = [box for box, label in zip(boxes, labels) if label == _ROW_LABEL]
    column_boxes = [box for box, label in zip(boxes, labels) if label == _COLUMN_LABEL]

    # A cell takes its horizontal extent from the column and its vertical extent from the row.
    cell_locations = [
        (column[0], row[1], column[2], row[3])
        for row in row_boxes
        for column in column_boxes
    ]
    if not cell_locations:
        return pd.DataFrame()
    cell_locations.sort(key=lambda cell: (cell[1], cell[0]))

    # All cells of the first row share its top edge; counting them gives the
    # column count even when the table has only a single row.
    first_row_top = cell_locations[0][1]
    num_columns = sum(1 for cell in cell_locations if cell[1] == first_row_top)

    headers = [_ocr_cell(image, cell) for cell in cell_locations[:num_columns]]
    # Cell text is kept in full: its length is unrelated to the column count,
    # so it must not be clipped to num_columns characters.
    body = [_ocr_cell(image, cell) for cell in cell_locations[num_columns:]]

    # Group the flat cell list into complete rows of num_columns entries.
    rows = [
        body[start:start + num_columns]
        for start in range(0, len(body) - num_columns + 1, num_columns)
    ]
    return pd.DataFrame(rows, columns=headers)


image_path = 'img_test/PMC1112589_table_0.jpg'
# Visualise the detections carrying label id 1, then OCR the table and
# write it next to the script as a CSV file without the index column.
draw_box_specific(image_path, labelnum=1)
df = extract_table(image_path)
df.to_csv(path_or_buf='data.csv', index=False)
我们将图片中的表格识别并存到csv中