diff --git a/.github/workflows/publish_whl.yml b/.github/workflows/publish_whl.yml index 99e137c..c2f6b44 100644 --- a/.github/workflows/publish_whl.yml +++ b/.github/workflows/publish_whl.yml @@ -6,7 +6,7 @@ on: - v* env: - RESOURCES_URL: https://github.com/RapidAI/RapidTable/releases/download/assets/rapid_table_models.zip + DEFAULT_MODEL: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/slanet-plus.onnx jobs: UnitTesting: @@ -26,16 +26,15 @@ jobs: - name: Unit testings run: | - wget $RESOURCES_URL - ZIP_NAME=${RESOURCES_URL##*/} - DIR_NAME=${ZIP_NAME%.*} - unzip $DIR_NAME - cp $DIR_NAME/*.onnx rapid_table/models/ + wget $DEFAULT_MODEL -P rapid_table/models pip install -r requirements.txt pip install rapidocr_onnxruntime + pip install torch + pip install torchvision + pip install tokenizers pip install pytest - pytest tests/test_table.py + pytest tests/test_main.py GenerateWHL_PushPyPi: needs: UnitTesting @@ -55,11 +54,8 @@ jobs: pip install -r requirements.txt python -m pip install --upgrade pip pip install wheel get_pypi_latest_version - wget $RESOURCES_URL - ZIP_NAME=${RESOURCES_URL##*/} - DIR_NAME=${ZIP_NAME%.*} - unzip $ZIP_NAME - mv $DIR_NAME/slanet-plus.onnx rapid_table/models/ + + wget $DEFAULT_MODEL -P rapid_table/models python setup.py bdist_wheel ${{ github.ref_name }} - name: Publish distribution 📦 to PyPI diff --git a/.gitignore b/.gitignore index edaddfb..eee50b9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +outputs/ +*.json + # Created by .ignore support plugin (hsz.mobi) ### Python template # Byte-compiled / optimized / DLL files @@ -156,3 +159,6 @@ long1.jpg *.pdmodel .DS_Store +*.pth +/rapid_table_torch/models/*.pth +/rapid_table_torch/models/*.json diff --git a/LICENSE b/LICENSE index a13bf2b..16c91c5 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2024 RapidAI + Copyright 2025 RapidAI Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md index 6a62216..cafcd4a 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,8 @@

📊 Rapid Table

- + + PyPI @@ -15,100 +16,90 @@ ### 简介 -RapidTable库是专门用来文档类图像的表格结构还原,结合RapidOCR,将给定图像中的表格转化对应的HTML格式。 - -目前支持两种类别的表格识别模型:中文和英文表格识别模型,具体可参见下面表格: +RapidTable库是专门用来文档类图像的表格结构还原,表格结构模型均属于序列预测方法,结合RapidOCR,将给定图像中的表格转化对应的HTML格式。 slanet_plus是paddlex内置的SLANet升级版模型,准确率有大幅提升 - | 模型类型 | 模型名称 | 模型大小 | - |:--------------:|:--------------------------------------:| :------: | - | 英文 | `en_ppstructure_mobile_v2_SLANet.onnx` | 7.3M | - | 中文 | `ch_ppstructure_mobile_v2_SLANet.onnx` | 7.4M | - | slanet_plus 中文 | `slanet-plus.onnx` | 6.8M | - +unitable是来源unitable的transformer模型,精度最高,暂仅支持pytorch推理,支持gpu推理加速,训练权重来源于 [OhMyTable项目](https://github.com/Sanster/OhMyTable) -模型来源:[PaddleOCR 表格识别](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md) -[PaddleX-SlaNetPlus 表格识别](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md) +### 最近动态 -模型下载地址为:[link](https://github.com/RapidAI/RapidTable/releases/tag/assets) +2025-01-09 update: 发布v1.0.2,全新接口升级。 \ +2024.12.30 update:支持Unitable模型的表格识别,使用pytorch框架 \ +2024.11.24 update:支持gpu推理,适配 rapidOCR 单字识别匹配,支持逻辑坐标返回及可视化 \ +2024.10.13 update:补充最新paddlex-SLANet-plus 模型(paddle2onnx原因暂不能支持onnx) ### 效果展示
- Demo + Demo
-### 更新日志 - -
- -#### 2024.11.24 update -- 支持gpu推理,适配 rapidOCR 单字识别匹配,支持逻辑坐标返回及可视化 - -#### 2024.10.13 update -- 补充最新paddlex-SLANet-plus 模型(paddle2onnx原因暂不能支持onnx) - -#### 2023-12-29 v0.1.3 update - -- 优化可视化结果部分 - -#### 2023-12-27 v0.1.2 update - -- 添加返回cell坐标框参数 -- 完善可视化函数 - -#### 2023-07-17 v0.1.0 update - -- 将`rapidocr_onnxruntime`部分从`rapid_table`中解耦合出来,给出选项是否依赖,更加灵活。 +### 模型列表 -- 增加接口输入参数`ocr_result`: - - 如果在调用函数时,事先指定了`ocr_result`参数值,则不会再走OCR。其中`ocr_result`格式需要和`rapidocr_onnxruntime`返回值一致。 - - 如果未指定`ocr_result`参数值,但是事先安装了`rapidocr_onnxruntime`库,则会自动调用该库,进行识别。 - - 如果`ocr_result`未指定,且`rapidocr_onnxruntime`未安装,则会报错。必须满足两个条件中一个。 +| `model_type` | 模型名称 | 推理框架 |模型大小 |推理耗时(单图 60KB)| + |:--------------|:--------------------------------------| :------: |:------ |:------ | +| `ppstructure_en` | `en_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.3M |0.15s | +| `ppstructure_zh` | `ch_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime |7.4M |0.15s | +| `slanet_plus` | `slanet-plus.onnx` | onnxruntime |6.8M |0.15s | +| `unitable` | `unitable(encoder.pth,decoder.pth)` | pytorch |500M |cpu(6s) gpu-4090(1.5s)| -#### 2023-07-10 v0.0.13 updata - -- 更改传入表格还原中OCR的实例接口,可以传入其他OCR实例,前提要与`rapidocr_onnxruntime`接口一致 +模型来源\ +[PaddleOCR 表格识别](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md)\ +[PaddleX-SlaNetPlus 表格识别](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md)\ +[Unitable](https://github.com/poloclub/unitable?tab=readme-ov-file) -#### 2023-07-06 v0.0.12 update - -- 去掉返回表格的html字符串中的``元素,便于后续统一。 -- 采用Black工具优化代码 - -
- - -### 与[TableStructureRec](https://github.com/RapidAI/TableStructureRec)关系 - -TableStructureRec库是一个表格识别算法的集合库,当前有`wired_table_rec`有线表格识别算法和`lineless_table_rec`无线表格识别算法的推理包。 - -RapidTable是整理自PP-Structure中表格识别部分而来。由于PP-Structure较早,这个库命名就成了`rapid_table`。 - -总之,RapidTable和TabelStructureRec都是表格识别的仓库。大家可以都试试,哪个好用用哪个。由于每个算法都不太同,暂时不打算做统一处理。 - -关于表格识别算法的比较,可参见[TableStructureRec测评](https://github.com/RapidAI/TableStructureRec#指标结果) +模型下载地址:[link](https://www.modelscope.cn/models/RapidAI/RapidTable/files) ### 安装 -由于模型较小,预先将slanet-plus表格识别模型(`slanet-plus.onnx`)打包进了whl包内。 +由于模型较小,预先将slanet-plus表格识别模型(`slanet-plus.onnx`)打包进了whl包内。其余模型在初始化`RapidTable`类时,会根据`model_type`来自动下载模型到安装包所在`models`目录下。当然也可以通过`RapidTableInput(model_path='')`来指定自己模型路径。注意仅限于我们现支持的`model_type`。 > ⚠️注意:`rapid_table>=v0.1.0`之后,不再将`rapidocr_onnxruntime`依赖强制打包到`rapid_table`中。使用前,需要自行安装`rapidocr_onnxruntime`包。 ```bash pip install rapidocr_onnxruntime pip install rapid_table -#pip install onnxruntime-gpu # for gpu inference + +# 基于torch来推理unitable模型 +pip install rapid_table[torch] # for unitable inference + +# onnxruntime-gpu推理 +pip uninstall onnxruntime +pip install onnxruntime-gpu # for onnx gpu inference ``` ### 使用方式 #### python脚本运行 -RapidTable类提供model_path参数,可以自行指定上述2个模型,默认是`slanet-plus.onnx`。举例如下: +> ⚠️注意:在`rapid_table>=1.0.0`之后,模型输入均采用dataclasses封装,简化和兼容参数传递。输入和输出定义如下: ```python -table_engine = RapidTable() +# 输入 +@dataclass +class RapidTableInput: + model_type: Optional[str] = ModelType.SLANETPLUS.value + model_path: Union[str, Path, None, Dict[str, str]] = None + use_cuda: bool = False + device: str = "cpu" + +# 输出 +@dataclass +class RapidTableOutput: + pred_html: Optional[str] = None + cell_bboxes: Optional[np.ndarray] = None + logic_points: Optional[np.ndarray] = None + elapse: Optional[float] = None + +# 使用示例 +input_args = RapidTableInput(model_type="unitable") +table_engine = RapidTable(input_args) + +img_path = 'test_images/table.jpg' +table_results = table_engine(img_path) + +print(table_results.pred_html) ``` 完整示例: @@ -120,20 +111,28 @@ from rapid_table import RapidTable, VisTable from rapidocr_onnxruntime import RapidOCR from rapid_table.table_structure.utils import trans_char_ocr_res - +# 默认是slanet_plus模型 table_engine = RapidTable() + # 开启onnx-gpu推理 -# table_engine = RapidTable(use_cuda=True) +# input_args = RapidTableInput(use_cuda=True) +# table_engine = RapidTable(input_args) + +# 使用torch推理版本的unitable模型 +# input_args = RapidTableInput(model_type="unitable", use_cuda=True, device="cuda:0") +# table_engine = RapidTable(input_args) + ocr_engine = RapidOCR() viser = VisTable() img_path = 'test_images/table.jpg' - ocr_result, _ = ocr_engine(img_path) + # 单字匹配 # ocr_result, _ = ocr_engine(img_path, return_word_box=True) # ocr_result = trans_char_ocr_res(ocr_result) -table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result) + +table_results = table_engine(img_path, ocr_result) save_dir = Path("./inference_results/") save_dir.mkdir(parents=True, exist_ok=True) @@ -141,50 +140,217 @@ save_dir.mkdir(parents=True, exist_ok=True) save_html_path = save_dir / f"{Path(img_path).stem}.html" save_drawed_path = save_dir / f"vis_{Path(img_path).name}" -viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path) - -# 返回逻辑坐标 -# table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result, return_logic_points=True) -# save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}" -# viser(img_path, table_html_str, save_html_path, table_cell_bboxes, 
save_drawed_path,logic_points, save_logic_path)
-
+save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}"
+
+viser(
+    img_path,
+    table_results.pred_html,
+    save_html_path,
+    table_results.cell_bboxes,
+    save_drawed_path,
+    table_results.logic_points,
+    save_logic_path,
+)
-print(table_html_str)
+print(table_results.pred_html)
```

#### 终端运行

-- 用法:
-
-   ```bash
-   $ rapid_table -h
-   usage: rapid_table [-h] [-v] -img IMG_PATH [-m MODEL_PATH]
-
-   optional arguments:
-     -h, --help            show this help message and exit
-     -v, --vis             Whether to visualize the layout results.
-     -img IMG_PATH, --img_path IMG_PATH
-                           Path to image for layout.
-     -m MODEL_PATH, --model_path MODEL_PATH
-                           The model path used for inference.
-   ```
-
-- 示例:
-
-   ```bash
-   rapid_table -v -img test_images/table.jpg
-   ```
+```bash
+rapid_table -v -img test_images/table.jpg
+```

### 结果

#### 返回结果
+
<div>

```html
-<html><body><table><tr><td>Methods</td><td></td><td></td><td></td><td>FPS</td></tr><tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td></td></tr><tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2</td></tr><tr><td>MSR[38]</td><td>76.7</td><td>87.8</td><td>7.4</td><td>81.7</td></tr><tr><td>FTSN [3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td></td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td></td></tr><tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN[16]</td><td>79</td><td>88</td><td>83</td><td></td></tr><tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td></td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG[41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td></td></tr><tr><td>Ours (SynText)</td><td>80.68</td><td>85</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></table></body></html>
+<html>
+<body>
+<table>
+    <tr><td>Methods</td><td></td><td></td><td></td><td>FPS</td></tr>
+    <tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr>
+    <tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td></td></tr>
+    <tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr>
+    <tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2</td></tr>
+    <tr><td>MSR[38]</td><td>76.7</td><td>87.8</td><td>7.4</td><td>81.7</td></tr>
+    <tr><td>FTSN [3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td></td></tr>
+    <tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td></td></tr>
+    <tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr>
+    <tr><td>MCN[16]</td><td>79</td><td>88</td><td>83</td><td></td></tr>
+    <tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td></td></tr>
+    <tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr>
+    <tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr>
+    <tr><td>DRRG[41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td></td></tr>
+    <tr><td>Ours (SynText)</td><td>80.68</td><td>85</td><td>82.97</td><td>12.68</td></tr>
+    <tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr>
+</table>
+</body>
+</html>
```

</div>

#### 可视化结果

<div align="center">
<table>
<tr><td>Methods</td><td></td><td></td><td></td><td>FPS</td></tr>
<tr><td>SegLink [26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr>
<tr><td>PixelLink [4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td></td></tr>
<tr><td>TextSnake [18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr>
<tr><td>TextField [37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2</td></tr>
<tr><td>MSR[38]</td><td>76.7</td><td>87.8</td><td>7.4</td><td>81.7</td></tr>
<tr><td>FTSN [3]</td><td>77.1</td><td>87.6</td><td>82.0</td><td></td></tr>
<tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td></td></tr>
<tr><td>CRAFT [2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr>
<tr><td>MCN[16]</td><td>79</td><td>88</td><td>83</td><td></td></tr>
<tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td></td></tr>
<tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr>
<tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr>
<tr><td>DRRG[41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td></td></tr>
<tr><td>Ours (SynText)</td><td>80.68</td><td>85</td><td>82.97</td><td>12.68</td></tr>
<tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr>
</table>
</div>
+
+### 与[TableStructureRec](https://github.com/RapidAI/TableStructureRec)关系
+
+TableStructureRec库是一个表格识别算法的集合库,当前有`wired_table_rec`有线表格识别算法和`lineless_table_rec`无线表格识别算法的推理包。
+
+RapidTable是整理自PP-Structure中表格识别部分而来。由于PP-Structure较早,这个库命名就成了`rapid_table`。
+
+总之,RapidTable和TableStructureRec都是表格识别的仓库。大家可以都试试,哪个好用用哪个。由于每个算法都不太相同,暂时不打算做统一处理。
+
+关于表格识别算法的比较,可参见[TableStructureRec测评](https://github.com/RapidAI/TableStructureRec#指标结果)
+
+### 更新日志
+
+<details>
+
+#### 2024.12.30 update
+
+- 支持Unitable模型的表格识别,使用pytorch框架
+
+#### 2024.11.24 update
+
+- 支持gpu推理,适配 rapidOCR 单字识别匹配,支持逻辑坐标返回及可视化
+
+#### 2024.10.13 update
+
+- 补充最新paddlex-SLANet-plus 模型(paddle2onnx原因暂不能支持onnx)
+
+#### 2023-12-29 v0.1.3 update
+
+- 优化可视化结果部分
+
+#### 2023-12-27 v0.1.2 update
+
+- 添加返回cell坐标框参数
+- 完善可视化函数
+
+#### 2023-07-17 v0.1.0 update
+
+- 将`rapidocr_onnxruntime`部分从`rapid_table`中解耦合出来,给出选项是否依赖,更加灵活。
+
+- 增加接口输入参数`ocr_result`:
+  - 如果在调用函数时,事先指定了`ocr_result`参数值,则不会再走OCR。其中`ocr_result`格式需要和`rapidocr_onnxruntime`返回值一致。
+  - 如果未指定`ocr_result`参数值,但是事先安装了`rapidocr_onnxruntime`库,则会自动调用该库,进行识别。
+  - 如果`ocr_result`未指定,且`rapidocr_onnxruntime`未安装,则会报错。必须满足两个条件中一个。
+
+#### 2023-07-10 v0.0.13 update
+
+- 更改传入表格还原中OCR的实例接口,可以传入其他OCR实例,前提要与`rapidocr_onnxruntime`接口一致
+
+#### 2023-07-06 v0.0.12 update
+
+- 去掉返回表格的html字符串中的`<thead></thead><tbody></tbody>`元素,便于后续统一。
+- 采用Black工具优化代码
+
+</details>
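The 安装 section above notes that, apart from the bundled `slanet-plus.onnx`, models are auto-downloaded by `model_type` on first use, and that `RapidTableInput(model_path=...)` can point at a local copy instead. A minimal sketch of that override, assuming a previously downloaded ONNX file (the local path below is illustrative, not something this diff prescribes):

```python
from rapid_table import RapidTable, RapidTableInput

# Skip the auto-download and load a local model file instead.
# The path must correspond to the chosen model_type; for "unitable",
# model_path would be a dict of encoder/decoder/vocab paths instead.
input_args = RapidTableInput(
    model_type="slanet_plus",
    model_path="rapid_table/models/slanet-plus.onnx",
)
table_engine = RapidTable(input_args)
```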
diff --git a/demo.py b/demo.py index fb13d21..d9e8ba3 100644 --- a/demo.py +++ b/demo.py @@ -6,12 +6,14 @@ import cv2 from rapidocr_onnxruntime import RapidOCR, VisRes -from rapid_table import RapidTable, VisTable +from rapid_table import RapidTable, RapidTableInput, VisTable # Init ocr_engine = RapidOCR() vis_ocr = VisRes() -table_engine = RapidTable() + +input_args = RapidTableInput(model_type="unitable") +table_engine = RapidTable(input_args) viser = VisTable() img_path = "tests/test_files/table.jpg" @@ -21,7 +23,8 @@ boxes, txts, scores = list(zip(*ocr_result)) # Table Rec -table_html_str, table_cell_bboxes, _ = table_engine(img_path, ocr_result) +table_results = table_engine(img_path, ocr_result) +table_html_str, table_cell_bboxes = table_results.pred_html, table_results.cell_bboxes # Save save_dir = Path("outputs") @@ -31,9 +34,7 @@ save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}" # Visualize table rec result -vis_imged = viser( - img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path -) +vis_imged = viser(img_path, table_results, save_html_path, save_drawed_path) # Visualize OCR result save_ocr_path = save_dir / f"{Path(img_path).stem}_ocr_vis{Path(img_path).suffix}" diff --git a/outputs/table.html b/outputs/table.html deleted file mode 100644 index 1d26243..0000000 --- a/outputs/table.html +++ /dev/null @@ -1,4 +0,0 @@ -
<html><body><table><tr><td>Methods</td><td>R</td><td>P</td><td>F</td><td>FPS</td></tr><tr><td>SegLink[26]</td><td>70.0</td><td>86.0</td><td>77.0</td><td>8.9</td></tr><tr><td>PixelLink[4]</td><td>73.2</td><td>83.0</td><td>77.8</td><td>-</td></tr><tr><td>TextSnake[18]</td><td>73.9</td><td>83.2</td><td>78.3</td><td>1.1</td></tr><tr><td>TextField[37]</td><td>75.9</td><td>87.4</td><td>81.3</td><td>5.2</td></tr>
-<tr><td>MSR[38]</td><td>76.7</td><td>87.4</td><td>81.7</td></tr><tr><td>FTSN[3]</td><td>77.1</td><td>87.6</td><td>82.0</td></tr><tr><td>LSE[30]</td><td>81.7</td><td>84.2</td><td>82.9</td><td>"</td></tr><tr><td>CRAFT[2]</td><td>78.2</td><td>88.2</td><td>82.9</td><td>8.6</td></tr><tr><td>MCN[16]</td><td>79</td><td>88</td><td>83</td><td>-</td></tr>
-<tr><td>ATRR[35]</td><td>82.1</td><td>85.2</td><td>83.6</td><td>"</td></tr><tr><td>PAN [34]</td><td>83.8</td><td>84.4</td><td>84.1</td><td>30.2</td></tr><tr><td>DB[12]</td><td>79.2</td><td>91.5</td><td>84.9</td><td>32.0</td></tr><tr><td>DRRG[41]</td><td>82.30</td><td>88.05</td><td>85.08</td><td>-</td></tr>
-<tr><td>Ours (SynText)</td><td>80.68</td><td>85.40</td><td>82.97</td><td>12.68</td></tr><tr><td>Ours (MLT-17)</td><td>84.54</td><td>86.62</td><td>85.57</td><td>12.31</td></tr></table></body></html>
\ No newline at end of file diff --git a/outputs/table_ocr_vis.jpg b/outputs/table_ocr_vis.jpg deleted file mode 100644 index f79d89e..0000000 Binary files a/outputs/table_ocr_vis.jpg and /dev/null differ diff --git a/outputs/table_table_vis.jpg b/outputs/table_table_vis.jpg deleted file mode 100644 index cfa68c9..0000000 Binary files a/outputs/table_table_vis.jpg and /dev/null differ diff --git a/rapid_table/__init__.py b/rapid_table/__init__.py index 93c6a5d..e152860 100644 --- a/rapid_table/__init__.py +++ b/rapid_table/__init__.py @@ -1,5 +1,5 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com -from .main import RapidTable -from .utils import VisTable +from .main import RapidTable, RapidTableInput +from .utils.utils import VisTable diff --git a/rapid_table/main.py b/rapid_table/main.py index d620fe5..426a7ce 100644 --- a/rapid_table/main.py +++ b/rapid_table/main.py @@ -5,33 +5,76 @@ import copy import importlib import time +from dataclasses import asdict, dataclass +from enum import Enum from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import cv2 import numpy as np +from rapid_table.utils.download_model import DownloadModel +from rapid_table.utils.logger import get_logger +from rapid_table.utils.utils import LoadImage, VisTable + from .table_matcher import TableMatch -from .table_structure import TableStructurer -from .utils import LoadImage, VisTable +from .table_structure import TableStructurer, TableStructureUnitable +logger = get_logger("main") root_dir = Path(__file__).resolve().parent +class ModelType(Enum): + PPSTRUCTURE_EN = "ppstructure_en" + PPSTRUCTURE_ZH = "ppstructure_zh" + SLANETPLUS = "slanet_plus" + UNITABLE = "unitable" + + +ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/" +KEY_TO_MODEL_URL = { + ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}/en_ppstructure_mobile_v2_SLANet.onnx", + ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}/ch_ppstructure_mobile_v2_SLANet.onnx", + ModelType.SLANETPLUS.value: f"{ROOT_URL}/slanet-plus.onnx", + ModelType.UNITABLE.value: { + "encoder": f"{ROOT_URL}/unitable/encoder.pth", + "decoder": f"{ROOT_URL}/unitable/decoder.pth", + "vocab": f"{ROOT_URL}/unitable/vocab.json", + }, +} + + +@dataclass +class RapidTableInput: + model_type: Optional[str] = ModelType.SLANETPLUS.value + model_path: Union[str, Path, None, Dict[str, str]] = None + use_cuda: bool = False + device: str = "cpu" + + +@dataclass +class RapidTableOutput: + pred_html: Optional[str] = None + cell_bboxes: Optional[np.ndarray] = None + logic_points: Optional[np.ndarray] = None + elapse: Optional[float] = None + + class RapidTable: - def __init__(self, model_path: Optional[str] = None, model_type: str = None, use_cuda: bool = False): - if model_path is None: - model_path = str( - root_dir / "models" / "slanet-plus.onnx" + def __init__(self, config: RapidTableInput): + self.model_type = config.model_type + if self.model_type not in KEY_TO_MODEL_URL: + model_list = ",".join(KEY_TO_MODEL_URL) + raise ValueError( + f"{self.model_type} is not supported. The currently supported models are {model_list}." 
) - model_type = "slanet-plus" - self.model_type = model_type - self.load_img = LoadImage() - config = { - "model_path": model_path, - "use_cuda": use_cuda, - } - self.table_structure = TableStructurer(config) + + config.model_path = self.get_model_path(config.model_type, config.model_path) + if self.model_type == ModelType.UNITABLE.value: + self.table_structure = TableStructureUnitable(asdict(config)) + else: + self.table_structure = TableStructurer(asdict(config)) + self.table_matcher = TableMatch() try: @@ -39,12 +82,13 @@ def __init__(self, model_path: Optional[str] = None, model_type: str = None, use except ModuleNotFoundError: self.ocr_engine = None + self.load_img = LoadImage() + def __call__( self, img_content: Union[str, np.ndarray, bytes, Path], ocr_result: List[Union[List[List[float]], str, str]] = None, - return_logic_points = False - ) -> Tuple[str, float]: + ) -> RapidTableOutput: if self.ocr_engine is None and ocr_result is None: raise ValueError( "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed." @@ -52,25 +96,28 @@ def __call__( img = self.load_img(img_content) - s = time.time() + s = time.perf_counter() h, w = img.shape[:2] if ocr_result is None: ocr_result, _ = self.ocr_engine(img) dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w) - pred_structures, pred_bboxes, _ = self.table_structure(copy.deepcopy(img)) + pred_structures, cell_bboxes, _ = self.table_structure(copy.deepcopy(img)) + # 适配slanet-plus模型输出的box缩放还原 - if self.model_type == "slanet-plus": - pred_bboxes = self.adapt_slanet_plus(img, pred_bboxes) - pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res) - # 避免低版本升级后出现问题,默认不打开 - if return_logic_points: - logic_points = self.table_matcher.decode_logic_points(pred_structures) - elapse = time.time() - s - return pred_html, pred_bboxes, logic_points, elapse - elapse = time.time() - s - return pred_html, pred_bboxes, elapse + if self.model_type == ModelType.SLANETPLUS.value: + cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes) + + pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res) + + # 过滤掉占位的bbox + mask = ~np.all(cell_bboxes == 0, axis=1) + cell_bboxes = cell_bboxes[mask] + + logic_points = self.table_matcher.decode_logic_points(pred_structures) + elapse = time.perf_counter() - s + return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse) def get_boxes_recs( self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int @@ -89,15 +136,39 @@ def get_boxes_recs( r_boxes.append(box) dt_boxes = np.array(r_boxes) return dt_boxes, rec_res - def adapt_slanet_plus(self, img: np.ndarray, pred_bboxes: np.ndarray) -> np.ndarray: + + def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray: h, w = img.shape[:2] resized = 488 ratio = min(resized / h, resized / w) w_ratio = resized / (w * ratio) h_ratio = resized / (h * ratio) - pred_bboxes[:, 0::2] *= w_ratio - pred_bboxes[:, 1::2] *= h_ratio - return pred_bboxes + cell_bboxes[:, 0::2] *= w_ratio + cell_bboxes[:, 1::2] *= h_ratio + return cell_bboxes + + @staticmethod + def get_model_path( + model_type: str, model_path: Union[str, Path, None] + ) -> Union[str, Dict[str, str]]: + if model_path is not None: + return model_path + + model_url = KEY_TO_MODEL_URL.get(model_type, None) + if isinstance(model_url, str): + model_path = DownloadModel.download(model_url) + return model_path + + if isinstance(model_url, dict): + model_paths = {} + for k, url in model_url.items(): 
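+                # unitable is the only entry whose URL value is a dict (encoder/decoder
+                # weights plus vocab.json); each file is fetched separately and saved as
+                # "<model_type>_<original file name>" under the package models directory.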
+ model_paths[k] = DownloadModel.download( + url, save_model_name=f"{model_type}_{Path(url).name}" + ) + return model_paths + + raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.") + def main(): parser = argparse.ArgumentParser() @@ -105,6 +176,7 @@ def main(): "-v", "--vis", action="store_true", + default=False, help="Wheter to visualize the layout results.", ) parser.add_argument( @@ -112,10 +184,10 @@ def main(): ) parser.add_argument( "-m", - "--model_path", + "--model_type", type=str, - default=str(root_dir / "models" / "en_ppstructure_mobile_v2_SLANet.onnx"), - help="The model path used for inference.", + default=ModelType.SLANETPLUS.value, + choices=list(KEY_TO_MODEL_URL), ) args = parser.parse_args() @@ -126,12 +198,16 @@ def main(): "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime." ) from exc - rapid_table = RapidTable(args.model_path) + table_engine = RapidTable(args.model_path) img = cv2.imread(args.img_path) ocr_result, _ = ocr_engine(img) - table_html_str, table_cell_bboxes, elapse = rapid_table(img, ocr_result) + table_results = table_engine(img, ocr_result) + table_html_str, table_cell_bboxes = ( + table_results.pred_html, + table_results.cell_bboxes, + ) print(table_html_str) viser = VisTable() diff --git a/rapid_table/table_matcher/matcher.py b/rapid_table/table_matcher/matcher.py index 0453929..f579a12 100644 --- a/rapid_table/table_matcher/matcher.py +++ b/rapid_table/table_matcher/matcher.py @@ -22,18 +22,18 @@ def __init__(self, filter_ocr_result=True, use_master=False): self.filter_ocr_result = filter_ocr_result self.use_master = use_master - def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res): + def __call__(self, pred_structures, cell_bboxes, dt_boxes, rec_res): if self.filter_ocr_result: - dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes, rec_res) - matched_index = self.match_result(dt_boxes, pred_bboxes) + dt_boxes, rec_res = self._filter_ocr_result(cell_bboxes, dt_boxes, rec_res) + matched_index = self.match_result(dt_boxes, cell_bboxes) pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res) return pred_html - def match_result(self, dt_boxes, pred_bboxes): + def match_result(self, dt_boxes, cell_bboxes, min_iou=0.1**8): matched = {} for i, gt_box in enumerate(dt_boxes): distances = [] - for j, pred_box in enumerate(pred_bboxes): + for j, pred_box in enumerate(cell_bboxes): if len(pred_box) == 8: pred_box = [ np.min(pred_box[0::2]), @@ -49,7 +49,11 @@ def match_result(self, dt_boxes, pred_bboxes): sorted_distances = sorted( sorted_distances, key=lambda item: (item[1], item[0]) ) - if distances.index(sorted_distances[0]) not in matched.keys(): + # must > min_iou + if sorted_distances[0][1] >= 1 - min_iou: + continue + + if distances.index(sorted_distances[0]) not in matched: matched[distances.index(sorted_distances[0])] = [i] else: matched[distances.index(sorted_distances[0])].append(i) @@ -111,6 +115,7 @@ def get_pred_html(self, pred_structures, matched_index, ocr_contents): filter_elements = ["", "", "", ""] end_html = [v for v in end_html if v not in filter_elements] return "".join(end_html), end_html + def decode_logic_points(self, pred_structures): logic_points = [] current_row = 0 @@ -131,22 +136,24 @@ def mark_occupied(row, col, rowspan, colspan): while i < len(pred_structures): token = pred_structures[i] - if token == '': + if token == "": current_col = 0 # 每次遇到 时,重置当前列号 - elif token == '': + elif token == "": current_row += 1 # 行结束,行号增加 - elif 
token .startswith(''): - if 'colspan=' in pred_structures[j]: - colspan = int(pred_structures[j].split('=')[1].strip('"\'')) - elif 'rowspan=' in pred_structures[j]: - rowspan = int(pred_structures[j].split('=')[1].strip('"\'')) + while j < len(pred_structures) and not pred_structures[ + j + ].startswith(">"): + if "colspan=" in pred_structures[j]: + colspan = int(pred_structures[j].split("=")[1].strip("\"'")) + elif "rowspan=" in pred_structures[j]: + rowspan = int(pred_structures[j].split("=")[1].strip("\"'")) j += 1 # 跳过已经处理过的属性 token @@ -179,8 +186,8 @@ def mark_occupied(row, col, rowspan, colspan): return logic_points - def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res): - y1 = pred_bboxes[:, 1::2].min() + def _filter_ocr_result(self, cell_bboxes, dt_boxes, rec_res): + y1 = cell_bboxes[:, 1::2].min() new_dt_boxes = [] new_rec_res = [] diff --git a/rapid_table/table_matcher/utils.py b/rapid_table/table_matcher/utils.py index 3ec8fcc..57a613c 100644 --- a/rapid_table/table_matcher/utils.py +++ b/rapid_table/table_matcher/utils.py @@ -14,6 +14,7 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com +import copy import re @@ -48,7 +49,7 @@ def deal_isolate_span(thead_part): spanStr_in_isolateItem = span_part.group() # 3. merge the span number into the span token format string. if spanStr_in_isolateItem is not None: - corrected_item = "".format(spanStr_in_isolateItem) + corrected_item = f"" corrected_list.append(corrected_item) else: corrected_list.append(None) @@ -243,6 +244,6 @@ def compute_iou(rec1, rec2): # judge if there is an intersect if left_line >= right_line or top_line >= bottom_line: return 0.0 - else: - intersect = (right_line - left_line) * (bottom_line - top_line) - return (intersect / (sum_area - intersect)) * 1.0 + + intersect = (right_line - left_line) * (bottom_line - top_line) + return (intersect / (sum_area - intersect)) * 1.0 diff --git a/rapid_table/table_structure/__init__.py b/rapid_table/table_structure/__init__.py index a548391..3e638d9 100644 --- a/rapid_table/table_structure/__init__.py +++ b/rapid_table/table_structure/__init__.py @@ -12,3 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. from .table_structure import TableStructurer +from .table_structure_unitable import TableStructureUnitable diff --git a/rapid_table/table_structure/table_structure.py b/rapid_table/table_structure/table_structure.py index c2509d8..9152603 100644 --- a/rapid_table/table_structure/table_structure.py +++ b/rapid_table/table_structure/table_structure.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import time -from typing import Dict, Any +from typing import Any, Dict import numpy as np diff --git a/rapid_table/table_structure/table_structure_unitable.py b/rapid_table/table_structure/table_structure_unitable.py new file mode 100644 index 0000000..2b98006 --- /dev/null +++ b/rapid_table/table_structure/table_structure_unitable.py @@ -0,0 +1,229 @@ +import re +import time + +import cv2 +import numpy as np +import torch +from PIL import Image +from tokenizers import Tokenizer +from torchvision import transforms + +from .unitable_modules import Encoder, GPTFastDecoder + +IMG_SIZE = 448 +EOS_TOKEN = "" +BBOX_TOKENS = [f"bbox-{i}" for i in range(IMG_SIZE + 1)] + +HTML_BBOX_HTML_TOKENS = [ + "", + "[", + "]", + "[", + ">", + "", + "", + "", + "", + "", + "", + ' rowspan="2"', + ' rowspan="3"', + ' rowspan="4"', + ' rowspan="5"', + ' rowspan="6"', + ' rowspan="7"', + ' rowspan="8"', + ' rowspan="9"', + ' rowspan="10"', + ' rowspan="11"', + ' rowspan="12"', + ' rowspan="13"', + ' rowspan="14"', + ' rowspan="15"', + ' rowspan="16"', + ' rowspan="17"', + ' rowspan="18"', + ' rowspan="19"', + ' colspan="2"', + ' colspan="3"', + ' colspan="4"', + ' colspan="5"', + ' colspan="6"', + ' colspan="7"', + ' colspan="8"', + ' colspan="9"', + ' colspan="10"', + ' colspan="11"', + ' colspan="12"', + ' colspan="13"', + ' colspan="14"', + ' colspan="15"', + ' colspan="16"', + ' colspan="17"', + ' colspan="18"', + ' colspan="19"', + ' colspan="25"', +] + +VALID_HTML_BBOX_TOKENS = [EOS_TOKEN] + HTML_BBOX_HTML_TOKENS + BBOX_TOKENS +TASK_TOKENS = [ + "[table]", + "[html]", + "[cell]", + "[bbox]", + "[cell+bbox]", + "[html+bbox]", +] + + +class TableStructureUnitable: + def __init__(self, config): + # encoder_path: str, decoder_path: str, vocab_path: str, device: str + vocab_path = config["model_path"]["vocab"] + encoder_path = config["model_path"]["encoder"] + decoder_path = config["model_path"]["decoder"] + device = config.get("device", "cuda:0") if config["use_cuda"] else "cpu" + + self.vocab = Tokenizer.from_file(vocab_path) + self.token_white_list = [ + self.vocab.token_to_id(i) for i in VALID_HTML_BBOX_TOKENS + ] + self.bbox_token_ids = set(self.vocab.token_to_id(i) for i in BBOX_TOKENS) + self.bbox_close_html_token = self.vocab.token_to_id("]") + self.prefix_token_id = self.vocab.token_to_id("[html+bbox]") + self.eos_id = self.vocab.token_to_id(EOS_TOKEN) + self.max_seq_len = 1024 + self.device = device + self.img_size = IMG_SIZE + + # init encoder + encoder_state_dict = torch.load(encoder_path, map_location=device) + self.encoder = Encoder() + self.encoder.load_state_dict(encoder_state_dict) + self.encoder.eval().to(device) + + # init decoder + decoder_state_dict = torch.load(decoder_path, map_location=device) + self.decoder = GPTFastDecoder() + self.decoder.load_state_dict(decoder_state_dict) + self.decoder.eval().to(device) + + # define img transform + self.transform = transforms.Compose( + [ + transforms.Resize((448, 448)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.86597056, 0.88463002, 0.87491087], + std=[0.20686628, 0.18201602, 0.18485524], + ), + ] + ) + + @torch.inference_mode() + def __call__(self, image: np.ndarray): + start_time = time.time() + ori_h, ori_w = image.shape[:2] + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = Image.fromarray(image) + image = self.transform(image).unsqueeze(0).to(self.device) + self.decoder.setup_caches( + max_batch_size=1, + max_seq_length=self.max_seq_len, + dtype=image.dtype, + device=self.device, + ) + context = ( + 
torch.tensor([self.prefix_token_id], dtype=torch.int32) + .repeat(1, 1) + .to(self.device) + ) + eos_id_tensor = torch.tensor(self.eos_id, dtype=torch.int32).to(self.device) + memory = self.encoder(image) + context = self.loop_decode(context, eos_id_tensor, memory) + bboxes, html_tokens = self.decode_tokens(context) + bboxes = bboxes.astype(np.float32) + + # rescale boxes + scale_h = ori_h / self.img_size + scale_w = ori_w / self.img_size + bboxes[:, 0::2] *= scale_w # 缩放 x 坐标 + bboxes[:, 1::2] *= scale_h # 缩放 y 坐标 + bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, ori_w - 1) + bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, ori_h - 1) + structure_str_list = ( + ["", "", ""] + + html_tokens + + ["
", "", ""] + ) + return structure_str_list, bboxes, time.time() - start_time + + def decode_tokens(self, context): + pred_html = context[0] + pred_html = pred_html.detach().cpu().numpy() + pred_html = self.vocab.decode(pred_html, skip_special_tokens=False) + seq = pred_html.split("")[0] + token_black_list = ["", "", *TASK_TOKENS] + for i in token_black_list: + seq = seq.replace(i, "") + + tr_pattern = re.compile(r"(.*?)", re.DOTALL) + td_pattern = re.compile(r"(.*?)", re.DOTALL) + bbox_pattern = re.compile(r"\[ bbox-(\d+) bbox-(\d+) bbox-(\d+) bbox-(\d+) \]") + + decoded_list = [] + bbox_coords = [] + + # 查找所有的 标签 + for tr_match in tr_pattern.finditer(pred_html): + tr_content = tr_match.group(1) + decoded_list.append("") + + # 查找所有的 标签 + for td_match in td_pattern.finditer(tr_content): + td_attrs = td_match.group(1).strip() + td_content = td_match.group(2).strip() + if td_attrs: + decoded_list.append("") + decoded_list.append("") + else: + decoded_list.append("") + + # 查找 bbox 坐标 + bbox_match = bbox_pattern.search(td_content) + if bbox_match: + xmin, ymin, xmax, ymax = map(int, bbox_match.groups()) + # 将坐标转换为从左上角开始顺时针到左下角的点的坐标 + coords = np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax]) + bbox_coords.append(coords) + else: + # 填充占位的bbox,保证后续流程统一 + bbox_coords.append(np.array([0, 0, 0, 0, 0, 0, 0, 0])) + decoded_list.append("") + + bbox_coords_array = np.array(bbox_coords) + return bbox_coords_array, decoded_list + + def loop_decode(self, context, eos_id_tensor, memory): + box_token_count = 0 + for _ in range(self.max_seq_len): + eos_flag = (context == eos_id_tensor).any(dim=1) + if torch.all(eos_flag): + break + + next_tokens = self.decoder(memory, context) + if next_tokens[0] in self.bbox_token_ids: + box_token_count += 1 + if box_token_count > 4: + next_tokens = torch.tensor( + [self.bbox_close_html_token], dtype=torch.int32 + ) + box_token_count = 0 + context = torch.cat([context, next_tokens], dim=1) + return context diff --git a/rapid_table/table_structure/unitable_modules.py b/rapid_table/table_structure/unitable_modules.py new file mode 100644 index 0000000..5b8dac3 --- /dev/null +++ b/rapid_table/table_structure/unitable_modules.py @@ -0,0 +1,911 @@ +from dataclasses import dataclass +from functools import partial +from typing import Optional + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import functional as F +from torch.nn.modules.transformer import _get_activation_fn + +TOKEN_WHITE_LIST = [ + 1, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 27, + 28, + 29, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + 50, + 51, + 52, + 53, + 54, + 55, + 56, + 57, + 58, + 59, + 60, + 61, + 62, + 63, + 64, + 65, + 66, + 67, + 68, + 69, + 70, + 71, + 72, + 73, + 74, + 75, + 76, + 77, + 78, + 79, + 80, + 81, + 82, + 83, + 84, + 85, + 86, + 87, + 88, + 89, + 90, + 91, + 92, + 93, + 94, + 95, + 96, + 97, + 98, + 99, + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 107, + 108, + 109, + 110, + 111, + 112, + 113, + 114, + 115, + 116, + 117, + 118, + 119, + 120, + 121, + 122, + 123, + 124, + 125, + 126, + 127, + 128, + 129, + 130, + 131, + 132, + 133, + 134, + 135, + 136, + 137, + 138, + 139, + 140, + 141, + 142, + 143, + 144, + 145, + 146, + 147, + 148, + 149, + 150, + 151, + 152, + 153, + 154, + 155, + 156, + 157, + 158, + 159, + 160, + 161, + 162, + 163, + 164, + 165, + 166, + 167, + 168, + 169, + 170, + 171, + 172, + 173, + 174, + 175, + 
176, + 177, + 178, + 179, + 180, + 181, + 182, + 183, + 184, + 185, + 186, + 187, + 188, + 189, + 190, + 191, + 192, + 193, + 194, + 195, + 196, + 197, + 198, + 199, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + 211, + 212, + 213, + 214, + 215, + 216, + 217, + 218, + 219, + 220, + 221, + 222, + 223, + 224, + 225, + 226, + 227, + 228, + 229, + 230, + 231, + 232, + 233, + 234, + 235, + 236, + 237, + 238, + 239, + 240, + 241, + 242, + 243, + 244, + 245, + 246, + 247, + 248, + 249, + 250, + 251, + 252, + 253, + 254, + 255, + 256, + 257, + 258, + 259, + 260, + 261, + 262, + 263, + 264, + 265, + 266, + 267, + 268, + 269, + 270, + 271, + 272, + 273, + 274, + 275, + 276, + 277, + 278, + 279, + 280, + 281, + 282, + 283, + 284, + 285, + 286, + 287, + 288, + 289, + 290, + 291, + 292, + 293, + 294, + 295, + 296, + 297, + 298, + 299, + 300, + 301, + 302, + 303, + 304, + 305, + 306, + 307, + 308, + 309, + 310, + 311, + 312, + 313, + 314, + 315, + 316, + 317, + 318, + 319, + 320, + 321, + 322, + 323, + 324, + 325, + 326, + 327, + 328, + 329, + 330, + 331, + 332, + 333, + 334, + 335, + 336, + 337, + 338, + 339, + 340, + 341, + 342, + 343, + 344, + 345, + 346, + 347, + 348, + 349, + 350, + 351, + 352, + 353, + 354, + 355, + 356, + 357, + 358, + 359, + 360, + 361, + 362, + 363, + 364, + 365, + 366, + 367, + 368, + 369, + 370, + 371, + 372, + 373, + 374, + 375, + 376, + 377, + 378, + 379, + 380, + 381, + 382, + 383, + 384, + 385, + 386, + 387, + 388, + 389, + 390, + 391, + 392, + 393, + 394, + 395, + 396, + 397, + 398, + 399, + 400, + 401, + 402, + 403, + 404, + 405, + 406, + 407, + 408, + 409, + 410, + 411, + 412, + 413, + 414, + 415, + 416, + 417, + 418, + 419, + 420, + 421, + 422, + 423, + 424, + 425, + 426, + 427, + 428, + 429, + 430, + 431, + 432, + 433, + 434, + 435, + 436, + 437, + 438, + 439, + 440, + 441, + 442, + 443, + 444, + 445, + 446, + 447, + 448, + 449, + 450, + 451, + 452, + 453, + 454, + 455, + 456, + 457, + 458, + 459, + 460, + 461, + 462, + 463, + 464, + 465, + 466, + 467, + 468, + 469, + 470, + 471, + 472, + 473, + 474, + 475, + 476, + 477, + 478, + 479, + 480, + 481, + 482, + 483, + 484, + 485, + 486, + 487, + 488, + 489, + 490, + 491, + 492, + 493, + 494, + 495, + 496, + 497, + 498, + 499, + 500, + 501, + 502, + 503, + 504, + 505, + 506, + 507, + 508, + 509, +] + + +class ImgLinearBackbone(nn.Module): + def __init__( + self, + d_model: int, + patch_size: int, + in_chan: int = 3, + ) -> None: + super().__init__() + + self.conv_proj = nn.Conv2d( + in_chan, + out_channels=d_model, + kernel_size=patch_size, + stride=patch_size, + ) + self.d_model = d_model + + def forward(self, x: Tensor) -> Tensor: + x = self.conv_proj(x) + x = x.flatten(start_dim=-2).transpose(1, 2) + return x + + +class Encoder(nn.Module): + def __init__(self) -> None: + super().__init__() + + self.patch_size = 16 + self.d_model = 768 + self.dropout = 0 + self.activation = "gelu" + self.norm_first = True + self.ff_ratio = 4 + self.nhead = 12 + self.max_seq_len = 1024 + self.n_encoder_layer = 12 + encoder_layer = nn.TransformerEncoderLayer( + self.d_model, + nhead=self.nhead, + dim_feedforward=self.ff_ratio * self.d_model, + dropout=self.dropout, + activation=self.activation, + batch_first=True, + norm_first=self.norm_first, + ) + norm_layer = partial(nn.LayerNorm, eps=1e-6) + self.norm = norm_layer(self.d_model) + self.backbone = ImgLinearBackbone( + d_model=self.d_model, patch_size=self.patch_size + ) + self.pos_embed = PositionEmbedding( + max_seq_len=self.max_seq_len, 
d_model=self.d_model, dropout=self.dropout + ) + self.encoder = nn.TransformerEncoder( + encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False + ) + + def forward(self, x: Tensor) -> Tensor: + src_feature = self.backbone(x) + src_feature = self.pos_embed(src_feature) + memory = self.encoder(src_feature) + memory = self.norm(memory) + return memory + + +class PositionEmbedding(nn.Module): + def __init__(self, max_seq_len: int, d_model: int, dropout: float) -> None: + super().__init__() + self.embedding = nn.Embedding(max_seq_len, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + # assume x is batch first + if input_pos is None: + _pos = torch.arange(x.shape[1], device=x.device) + else: + _pos = input_pos + out = self.embedding(_pos) + return self.dropout(out + x) + + +class TokenEmbedding(nn.Module): + def __init__( + self, + vocab_size: int, + d_model: int, + padding_idx: int, + ) -> None: + super().__init__() + assert vocab_size > 0 + self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx) + + def forward(self, x: Tensor) -> Tensor: + return self.embedding(x) + + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + + +@dataclass +class ModelArgs: + n_layer: int = 4 + n_head: int = 12 + dim: int = 768 + intermediate_size: int = None + head_dim: int = 64 + activation: str = "gelu" + norm_first: bool = True + + def __post_init__(self): + if self.intermediate_size is None: + hidden_dim = 4 * self.dim + n_hidden = int(2 * hidden_dim / 3) + self.intermediate_size = find_multiple(n_hidden, 256) + self.head_dim = self.dim // self.n_head + + +class KVCache(nn.Module): + def __init__( + self, + max_batch_size, + max_seq_length, + n_heads, + head_dim, + dtype=torch.bfloat16, + device="cpu", + ): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) + self.register_buffer( + "k_cache", + torch.zeros(cache_shape, dtype=dtype, device=device), + persistent=False, + ) + self.register_buffer( + "v_cache", + torch.zeros(cache_shape, dtype=dtype, device=device), + persistent=False, + ) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + # assert input_pos.shape[0] == k_val.shape[2] + + bs = k_val.shape[0] + k_out = self.k_cache + v_out = self.v_cache + k_out[:bs, :, input_pos] = k_val + v_out[:bs, :, input_pos] = v_val + + return k_out[:bs], v_out[:bs] + + +class GPTFastDecoder(nn.Module): + def __init__(self) -> None: + super().__init__() + + self.vocab_size = 960 + self.padding_idx = 2 + self.prefix_token_id = 11 + self.eos_id = 1 + self.max_seq_len = 1024 + self.dropout = 0 + self.d_model = 768 + self.nhead = 12 + self.activation = "gelu" + self.norm_first = True + self.n_decoder_layer = 4 + config = ModelArgs( + n_layer=self.n_decoder_layer, + n_head=self.nhead, + dim=self.d_model, + intermediate_size=self.d_model * 4, + activation=self.activation, + norm_first=self.norm_first, + ) + self.config = config + self.layers = nn.ModuleList( + TransformerBlock(config) for _ in range(config.n_layer) + ) + self.token_embed = TokenEmbedding( + vocab_size=self.vocab_size, + d_model=self.d_model, + padding_idx=self.padding_idx, + ) + self.pos_embed = PositionEmbedding( + max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout + ) + self.generator = nn.Linear(self.d_model, self.vocab_size) + self.token_white_list = TOKEN_WHITE_LIST + self.mask_cache: Optional[Tensor] 
= None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length, dtype, device): + for b in self.layers: + b.multihead_attn.k_cache = None + b.multihead_attn.v_cache = None + + if ( + self.max_seq_length >= max_seq_length + and self.max_batch_size >= max_batch_size + ): + return + head_dim = self.config.dim // self.config.n_head + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + + for b in self.layers: + b.self_attn.kv_cache = KVCache( + max_batch_size, + max_seq_length, + self.config.n_head, + head_dim, + dtype, + device, + ) + b.multihead_attn.k_cache = None + b.multihead_attn.v_cache = None + + self.causal_mask = torch.tril( + torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool) + ).to(device) + + def forward(self, memory: Tensor, tgt: Tensor) -> Tensor: + input_pos = torch.tensor([tgt.shape[1] - 1], device=tgt.device, dtype=torch.int) + tgt = tgt[:, -1:] + tgt_feature = self.pos_embed(self.token_embed(tgt), input_pos=input_pos) + # tgt = self.decoder(tgt_feature, memory, input_pos) + with torch.backends.cuda.sdp_kernel( + enable_flash=False, enable_mem_efficient=False, enable_math=True + ): + logits = tgt_feature + tgt_mask = self.causal_mask[None, None, input_pos] + for i, layer in enumerate(self.layers): + logits = layer(logits, memory, input_pos=input_pos, tgt_mask=tgt_mask) + # return output + logits = self.generator(logits)[:, -1, :] + total = set([i for i in range(logits.shape[-1])]) + black_list = list(total.difference(set(self.token_white_list))) + logits[..., black_list] = -1e9 + probs = F.softmax(logits, dim=-1) + _, next_tokens = probs.topk(1) + return next_tokens + + +class TransformerBlock(nn.Module): + def __init__(self, config: ModelArgs) -> None: + super().__init__() + self.self_attn = Attention(config) + self.multihead_attn = CrossAttention(config) + + layer_norm_eps = 1e-5 + + d_model = config.dim + dim_feedforward = config.intermediate_size + + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm_first = config.norm_first + self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) + self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps) + + self.activation = _get_activation_fn(config.activation) + + def forward( + self, + tgt: Tensor, + memory: Tensor, + tgt_mask: Tensor, + input_pos: Tensor, + ) -> Tensor: + if self.norm_first: + x = tgt + x = x + self.self_attn(self.norm1(x), tgt_mask, input_pos) + x = x + self.multihead_attn(self.norm2(x), memory) + x = x + self._ff_block(self.norm3(x)) + else: + x = tgt + x = self.norm1(x + self.self_attn(x, tgt_mask, input_pos)) + x = self.norm2(x + self.multihead_attn(x, memory)) + x = self.norm3(x + self._ff_block(x)) + return x + + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.activation(self.linear1(x))) + return x + + +class Attention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + # key, query, value projections for all heads, but in a batch + self.wqkv = nn.Linear(config.dim, 3 * config.dim) + self.wo = nn.Linear(config.dim, config.dim) + + self.kv_cache: Optional[KVCache] = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + self.dim = config.dim + + def forward( + self, + x: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] = None, + ) -> 
Tensor: + bsz, seqlen, _ = x.shape + + kv_size = self.n_head * self.head_dim + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, self.n_head, self.head_dim) + k = k.view(bsz, seqlen, self.n_head, self.head_dim) + v = v.view(bsz, seqlen, self.n_head, self.head_dim) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + if self.kv_cache is not None: + k, v = self.kv_cache.update(input_pos, k, v) + + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) + + y = self.wo(y) + return y + + +class CrossAttention(nn.Module): + def __init__(self, config: ModelArgs): + super().__init__() + assert config.dim % config.n_head == 0 + + self.query = nn.Linear(config.dim, config.dim) + self.key = nn.Linear(config.dim, config.dim) + self.value = nn.Linear(config.dim, config.dim) + self.out = nn.Linear(config.dim, config.dim) + + self.k_cache = None + self.v_cache = None + + self.n_head = config.n_head + self.head_dim = config.head_dim + + def get_kv(self, xa: torch.Tensor): + if self.k_cache is not None and self.v_cache is not None: + return self.k_cache, self.v_cache + + k = self.key(xa) + v = self.value(xa) + + # Reshape for correct format + batch_size, source_seq_len, _ = k.shape + k = k.view(batch_size, source_seq_len, self.n_head, self.head_dim) + v = v.view(batch_size, source_seq_len, self.n_head, self.head_dim) + + if self.k_cache is None: + self.k_cache = k + if self.v_cache is None: + self.v_cache = v + + return k, v + + def forward( + self, + x: Tensor, + xa: Tensor, + ): + q = self.query(x) + batch_size, target_seq_len, _ = q.shape + q = q.view(batch_size, target_seq_len, self.n_head, self.head_dim) + k, v = self.get_kv(xa) + + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) + + wv = F.scaled_dot_product_attention( + query=q, + key=k, + value=v, + is_causal=False, + ) + wv = wv.transpose(1, 2).reshape( + batch_size, + target_seq_len, + self.n_head * self.head_dim, + ) + + return self.out(wv) diff --git a/rapid_table/table_structure/utils.py b/rapid_table/table_structure/utils.py index c88cf3c..78f41d2 100644 --- a/rapid_table/table_structure/utils.py +++ b/rapid_table/table_structure/utils.py @@ -17,12 +17,11 @@ import os import platform import traceback - -import cv2 from enum import Enum from pathlib import Path from typing import Any, Dict, List, Tuple, Union +import cv2 import numpy as np from onnxruntime import ( GraphOptimizationLevel, @@ -32,7 +31,7 @@ get_device, ) -from rapid_table.table_structure.logger import get_logger +from rapid_table.utils.logger import get_logger class EP(Enum): @@ -79,10 +78,12 @@ def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions: sess_opt.inter_op_num_threads = inter_op_num_threads return sess_opt + def get_metadata(self, key: str = "character") -> list: meta_dict = self.session.get_modelmeta().custom_metadata_map content_list = meta_dict[key].splitlines() return content_list + def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]: cpu_provider_opts = { "arena_extend_strategy": "kSameAsRequested", diff --git a/rapid_table/utils/__init__.py b/rapid_table/utils/__init__.py new file mode 100644 index 0000000..0ecdd4f --- /dev/null +++ b/rapid_table/utils/__init__.py @@ -0,0 +1,3 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com diff --git a/rapid_table/utils/download_model.py b/rapid_table/utils/download_model.py new file mode 100644 index 0000000..adedb5d --- /dev/null +++ 
b/rapid_table/utils/download_model.py @@ -0,0 +1,67 @@ +import io +from pathlib import Path +from typing import Optional, Union + +import requests +from tqdm import tqdm + +from .logger import get_logger + +logger = get_logger("DownloadModel") + +PROJECT_DIR = Path(__file__).resolve().parent.parent +DEFAULT_MODEL_DIR = PROJECT_DIR / "models" + + +class DownloadModel: + @classmethod + def download( + cls, + model_full_url: Union[str, Path], + save_dir: Union[str, Path, None] = None, + save_model_name: Optional[str] = None, + ) -> str: + if save_dir is None: + save_dir = DEFAULT_MODEL_DIR + + save_dir.mkdir(parents=True, exist_ok=True) + + if save_model_name is None: + save_model_name = Path(model_full_url).name + + save_file_path = save_dir / save_model_name + if save_file_path.exists(): + logger.debug("%s already exists", save_file_path) + return str(save_file_path) + + try: + logger.info("Download %s to %s", model_full_url, save_dir) + file = cls.download_as_bytes_with_progress(model_full_url, save_model_name) + cls.save_file(save_file_path, file) + except Exception as exc: + raise DownloadModelError from exc + return str(save_file_path) + + @staticmethod + def download_as_bytes_with_progress( + url: Union[str, Path], name: Optional[str] = None + ) -> bytes: + resp = requests.get(str(url), stream=True, allow_redirects=True, timeout=180) + total = int(resp.headers.get("content-length", 0)) + bio = io.BytesIO() + with tqdm( + desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024 + ) as pbar: + for chunk in resp.iter_content(chunk_size=65536): + pbar.update(len(chunk)) + bio.write(chunk) + return bio.getvalue() + + @staticmethod + def save_file(save_path: Union[str, Path], file: bytes): + with open(save_path, "wb") as f: + f.write(file) + + +class DownloadModelError(Exception): + pass diff --git a/rapid_table/table_structure/logger.py b/rapid_table/utils/logger.py similarity index 100% rename from rapid_table/table_structure/logger.py rename to rapid_table/utils/logger.py diff --git a/rapid_table/utils.py b/rapid_table/utils/utils.py similarity index 85% rename from rapid_table/utils.py rename to rapid_table/utils/utils.py index 350f2e5..ad35a19 100644 --- a/rapid_table/utils.py +++ b/rapid_table/utils/utils.py @@ -4,7 +4,7 @@ import os from io import BytesIO from pathlib import Path -from typing import Optional, Union, List +from typing import Optional, Union import cv2 import numpy as np @@ -14,9 +14,7 @@ class LoadImage: - def __init__( - self, - ): + def __init__(self): pass def __call__(self, img: InputType) -> np.ndarray: @@ -79,25 +77,22 @@ class LoadImageError(Exception): class VisTable: - def __init__( - self, - ): + def __init__(self): self.load_img = LoadImage() def __call__( - self, - img_path: Union[str, Path], - table_html_str: str, - save_html_path: Optional[str] = None, - table_cell_bboxes: Optional[np.ndarray] = None, - save_drawed_path: Optional[str] = None, - logic_points: List[List[float]] = None, - save_logic_path: Optional[str] = None, - ) -> None: + self, + img_path: Union[str, Path], + table_results, + save_html_path: Optional[str] = None, + save_drawed_path: Optional[str] = None, + save_logic_path: Optional[str] = None, + ): if save_html_path: - html_with_border = self.insert_border_style(table_html_str) + html_with_border = self.insert_border_style(table_results.pred_html) self.save_html(save_html_path, html_with_border) + table_cell_bboxes = table_results.cell_bboxes if table_cell_bboxes is None: return None @@ -113,31 +108,37 @@ def __call__( if 
save_drawed_path: self.save_img(save_drawed_path, drawed_img) + if save_logic_path and logic_points: - polygons = [[box[0],box[1], box[4], box[5]] for box in table_cell_bboxes] - self.plot_rec_box_with_logic_info(img_path, save_logic_path, logic_points, polygons) + polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes] + self.plot_rec_box_with_logic_info( + img_path, save_logic_path, logic_points, polygons + ) return drawed_img def insert_border_style(self, table_html_str: str): - style_res = f"""""" + prefix_table, suffix_table = table_html_str.split("") html_with_border = f"{prefix_table}{style_res}{suffix_table}" return html_with_border - def plot_rec_box_with_logic_info(self, img_path, output_path, logic_points, sorted_polygons): + def plot_rec_box_with_logic_info( + self, img_path, output_path, logic_points, sorted_polygons + ): """ :param img_path :param output_path @@ -182,7 +183,7 @@ def plot_rec_box_with_logic_info(self, img_path, output_path, logic_points, sort ) os.makedirs(os.path.dirname(output_path), exist_ok=True) # 保存绘制后的图像 - cv2.imwrite(output_path, img) + self.save_img(output_path, img) @staticmethod def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray: diff --git a/setup.py b/setup.py index 0539473..caf2114 100644 --- a/setup.py +++ b/setup.py @@ -3,11 +3,18 @@ # @Contact: liekkaskono@163.com import sys from pathlib import Path +from typing import List, Union import setuptools from get_pypi_latest_version import GetPyPiLatestVersion +def read_txt(txt_path: Union[Path, str]) -> List[str]: + with open(txt_path, "r", encoding="utf-8") as f: + data = [v.rstrip("\n") for v in f] + return data + + def get_readme(): root_dir = Path(__file__).resolve().parent readme_path = str(root_dir / "docs" / "doc_whl_rapid_table.md") @@ -19,7 +26,7 @@ def get_readme(): MODULE_NAME = "rapid_table" obtainer = GetPyPiLatestVersion() latest_version = obtainer(MODULE_NAME) -VERSION_NUM = obtainer.version_add_one(latest_version) +VERSION_NUM = obtainer.version_add_one(latest_version, add_patch=True) if len(sys.argv) > 2: match_str = " ".join(sys.argv[2:]) @@ -34,19 +41,13 @@ def get_readme(): platforms="Any", long_description=get_readme(), long_description_content_type="text/markdown", - description="Tools for parsing table structures based ONNXRuntime.", + description="Table Recognition", author="SWHL", author_email="liekkaskono@163.com", - url="https://github.com/RapidAI/RapidStructure", + url="https://github.com/RapidAI/RapidTable", license="Apache-2.0", include_package_data=True, - install_requires=[ - "onnxruntime>=1.7.0", - "PyYAML>=6.0", - "opencv_python>=4.5.1.48", - "numpy>=1.21.6", - "Pillow", - ], + install_requires=read_txt("requirements.txt"), packages=[ MODULE_NAME, f"{MODULE_NAME}.models", @@ -66,4 +67,5 @@ def get_readme(): ], python_requires=">=3.6,<3.13", entry_points={"console_scripts": [f"{MODULE_NAME}={MODULE_NAME}.main:main"]}, + extras_require={"torch": ["torch", "torchvision", "tokenizers"]}, ) diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..1138ef1 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,43 @@ +# -*- encoding: utf-8 -*- +# @Author: SWHL +# @Contact: liekkaskono@163.com +import sys +from pathlib import Path + +import pytest +from rapidocr_onnxruntime import RapidOCR + +cur_dir = Path(__file__).resolve().parent +root_dir = cur_dir.parent + +sys.path.append(str(root_dir)) + +from rapid_table import RapidTable, RapidTableInput + +ocr_engine = RapidOCR() + +input_args = RapidTableInput() 
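+# RapidTableInput() defaults to model_type="slanet_plus"; the parametrized tests
+# below construct their own engine for each model_type they exercise.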
+table_engine = RapidTable(input_args) + +test_file_dir = cur_dir / "test_files" +img_path = str(test_file_dir / "table.jpg") + + +@pytest.mark.parametrize("model_type", ["slanet_plus", "unitable"]) +def test_ocr_input(model_type): + ocr_res, _ = ocr_engine(img_path) + + input_args = RapidTableInput(model_type=model_type) + table_engine = RapidTable(input_args) + + table_results = table_engine(img_path, ocr_res) + assert table_results.pred_html.count("") == 16 + + +@pytest.mark.parametrize("model_type", ["slanet_plus", "unitable"]) +def test_input_ocr_none(model_type): + input_args = RapidTableInput(model_type=model_type) + table_engine = RapidTable(input_args) + table_results = table_engine(img_path) + assert table_results.pred_html.count("") == 16 + assert len(table_results.cell_bboxes) == len(table_results.logic_points) diff --git a/tests/test_table.py b/tests/test_table.py deleted file mode 100644 index 41a2060..0000000 --- a/tests/test_table.py +++ /dev/null @@ -1,35 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: SWHL -# @Contact: liekkaskono@163.com -import sys -from pathlib import Path - -from rapidocr_onnxruntime import RapidOCR - -cur_dir = Path(__file__).resolve().parent -root_dir = cur_dir.parent - -sys.path.append(str(root_dir)) - -from rapid_table import RapidTable - -ocr_engine = RapidOCR() -table_engine = RapidTable() - -test_file_dir = cur_dir / "test_files" -img_path = str(test_file_dir / "table.jpg") - - -def test_ocr_input(): - ocr_res, _ = ocr_engine(img_path) - table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_res) - assert table_html_str.count("") == 16 - - -def test_input_ocr_none(): - table_html_str, table_cell_bboxes, elapse = table_engine(img_path) - assert table_html_str.count("") == 16 - -def test_logic_points_out(): - table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, return_logic_points=True) - assert len(table_cell_bboxes) == len(logic_points)
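As a closing illustration of the new return type: the tests above assert that `cell_bboxes` and `logic_points` stay aligned one-to-one. A short sketch of consuming a `RapidTableOutput`, assuming each `logic_points` row follows the `[row_start, row_end, col_start, col_end]` layout produced by `decode_logic_points` (an assumption inferred from the matcher code, not stated explicitly in this diff):

```python
from pathlib import Path

from rapid_table import RapidTable, RapidTableInput

table_engine = RapidTable(RapidTableInput(model_type="slanet_plus"))
result = table_engine("tests/test_files/table.jpg")

# pred_html holds the reconstructed table; persist it next to the other outputs.
save_dir = Path("outputs")
save_dir.mkdir(parents=True, exist_ok=True)
(save_dir / "table.html").write_text(result.pred_html, encoding="utf-8")

# cell_bboxes and logic_points are aligned: one bbox and one
# [row_start, row_end, col_start, col_end] entry per predicted cell.
for bbox, logic in zip(result.cell_bboxes, result.logic_points):
    print(logic, bbox)

print(f"elapse: {result.elapse:.3f}s")
```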