diff --git a/.github/workflows/publish_whl.yml b/.github/workflows/publish_whl.yml
index 99e137c..c2f6b44 100644
--- a/.github/workflows/publish_whl.yml
+++ b/.github/workflows/publish_whl.yml
@@ -6,7 +6,7 @@ on:
- v*
env:
- RESOURCES_URL: https://github.com/RapidAI/RapidTable/releases/download/assets/rapid_table_models.zip
+ DEFAULT_MODEL: https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/slanet-plus.onnx
jobs:
UnitTesting:
@@ -26,16 +26,15 @@ jobs:
- name: Unit testings
run: |
- wget $RESOURCES_URL
- ZIP_NAME=${RESOURCES_URL##*/}
- DIR_NAME=${ZIP_NAME%.*}
- unzip $DIR_NAME
- cp $DIR_NAME/*.onnx rapid_table/models/
+ wget $DEFAULT_MODEL -P rapid_table/models
pip install -r requirements.txt
pip install rapidocr_onnxruntime
+ pip install torch
+ pip install torchvision
+ pip install tokenizers
pip install pytest
- pytest tests/test_table.py
+ pytest tests/test_main.py
GenerateWHL_PushPyPi:
needs: UnitTesting
@@ -55,11 +54,8 @@ jobs:
pip install -r requirements.txt
python -m pip install --upgrade pip
pip install wheel get_pypi_latest_version
- wget $RESOURCES_URL
- ZIP_NAME=${RESOURCES_URL##*/}
- DIR_NAME=${ZIP_NAME%.*}
- unzip $ZIP_NAME
- mv $DIR_NAME/slanet-plus.onnx rapid_table/models/
+
+ wget $DEFAULT_MODEL -P rapid_table/models
python setup.py bdist_wheel ${{ github.ref_name }}
- name: Publish distribution 📦 to PyPI
diff --git a/.gitignore b/.gitignore
index edaddfb..eee50b9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,6 @@
+outputs/
+*.json
+
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
@@ -156,3 +159,6 @@ long1.jpg
*.pdmodel
.DS_Store
+*.pth
+/rapid_table_torch/models/*.pth
+/rapid_table_torch/models/*.json
diff --git a/LICENSE b/LICENSE
index a13bf2b..16c91c5 100644
--- a/LICENSE
+++ b/LICENSE
@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
- Copyright 2024 RapidAI
+ Copyright 2025 RapidAI
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/README.md b/README.md
index 6a62216..cafcd4a 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,8 @@
📊 Rapid Table
-
+
+
@@ -15,100 +16,90 @@
### Introduction
-The RapidTable library restores the table structure of document-style images and, combined with RapidOCR, converts the tables in a given image into the corresponding HTML.
-
-Two categories of table recognition models are currently supported, Chinese and English; see the table below for details:
+The RapidTable library restores the table structure of document-style images. All of its table-structure models are sequence-prediction methods; combined with RapidOCR, it converts the tables in a given image into the corresponding HTML.
slanet_plus is an upgraded version of the SLANet model built into PaddleX, with a substantial accuracy improvement.
- | Model type | Model name | Model size |
-  |:--------------:|:--------------------------------------:| :------: |
-  | English | `en_ppstructure_mobile_v2_SLANet.onnx` | 7.3M |
-  | Chinese | `ch_ppstructure_mobile_v2_SLANet.onnx` | 7.4M |
-  | slanet_plus Chinese | `slanet-plus.onnx` | 6.8M |
-
+unitable is a transformer model from the UniTable project. It has the highest accuracy, currently supports only PyTorch inference (GPU acceleration is supported), and its training weights come from the [OhMyTable project](https://github.com/Sanster/OhMyTable).
-Model sources: [PaddleOCR table recognition](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md)
-[PaddleX-SlaNetPlus table recognition](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md)
+### Recent updates
-Model download: [link](https://github.com/RapidAI/RapidTable/releases/tag/assets)
+2025-01-09 update: released v1.0.2 with a fully reworked interface. \
+2024.12.30 update: added table recognition with the Unitable model, using the PyTorch framework. \
+2024.11.24 update: added GPU inference, adapted RapidOCR single-character matching, and added logical-coordinate output with visualization. \
+2024.10.13 update: added the latest PaddleX SLANet-plus model (ONNX is not yet supported because of paddle2onnx limitations).
### Demo
-

+
-### Changelog
-
-
-
-#### 2024.11.24 update
-- Added GPU inference, adapted RapidOCR single-character matching, and added logical-coordinate output with visualization
-
-#### 2024.10.13 update
-- Added the latest PaddleX SLANet-plus model (ONNX not yet supported because of paddle2onnx limitations)
-
-#### 2023-12-29 v0.1.3 update
-
-- Improved the visualization of results
-
-#### 2023-12-27 v0.1.2 update
-
-- Added a parameter to return cell bounding boxes
-- Improved the visualization function
-
-#### 2023-07-17 v0.1.0 update
-
-- Decoupled the `rapidocr_onnxruntime` part from `rapid_table`, making the dependency optional and more flexible.
+### Model list
-- Added the `ocr_result` input parameter:
-  - If `ocr_result` is provided when calling the function, OCR is not run again. The `ocr_result` format must match the return value of `rapidocr_onnxruntime`.
-  - If `ocr_result` is not provided but `rapidocr_onnxruntime` is installed, that library is called automatically for recognition.
-  - If `ocr_result` is not provided and `rapidocr_onnxruntime` is not installed, an error is raised. One of the two conditions must hold.
+| `model_type` | Model name | Inference framework | Model size | Inference time (single 60 KB image) |
+|:--------------|:--------------------------------------|:------:|:------|:------|
+| `ppstructure_en` | `en_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime | 7.3M | 0.15s |
+| `ppstructure_zh` | `ch_ppstructure_mobile_v2_SLANet.onnx` | onnxruntime | 7.4M | 0.15s |
+| `slanet_plus` | `slanet-plus.onnx` | onnxruntime | 6.8M | 0.15s |
+| `unitable` | `unitable(encoder.pth,decoder.pth)` | pytorch | 500M | cpu(6s) gpu-4090(1.5s) |
-#### 2023-07-10 v0.0.13 update
-
-- Changed the interface for passing an OCR instance into table restoration: any OCR instance can be passed in, provided its interface matches `rapidocr_onnxruntime`
+Model sources\
+[PaddleOCR table recognition](https://github.com/PaddlePaddle/PaddleOCR/blob/133d67f27dc8a241d6b2e30a9f047a0fb75bebbe/ppstructure/table/README_ch.md)\
+[PaddleX-SlaNetPlus table recognition](https://github.com/PaddlePaddle/PaddleX/blob/release/3.0-beta1/docs/module_usage/tutorials/ocr_modules/table_structure_recognition.md)\
+[Unitable](https://github.com/poloclub/unitable?tab=readme-ov-file)
-#### 2023-07-06 v0.0.12 update
-
-- Removed the wrapper element from the returned table HTML string, for easier downstream unification.
-- Reformatted the code with Black
-
-
-
-
-### Relationship with [TableStructureRec](https://github.com/RapidAI/TableStructureRec)
-
-The TableStructureRec library is a collection of table recognition algorithms; it currently ships inference packages for the `wired_table_rec` (wired table) and `lineless_table_rec` (lineless table) algorithms.
-
-RapidTable was extracted from the table recognition part of PP-Structure. Because PP-Structure came first, this library ended up being named `rapid_table`.
-
-In short, RapidTable and TableStructureRec are both table recognition repositories. Try them both and use whichever works better for you; since the algorithms differ, there is no plan to unify them for now.
-
-For a comparison of table recognition algorithms, see the [TableStructureRec benchmark](https://github.com/RapidAI/TableStructureRec#指标结果)
+Model download: [link](https://www.modelscope.cn/models/RapidAI/RapidTable/files)
### Installation
-Because the model is small, the slanet-plus table recognition model (`slanet-plus.onnx`) is packaged into the whl by default.
+Because the model is small, the slanet-plus table recognition model (`slanet-plus.onnx`) is packaged into the whl by default. The other models are downloaded automatically, according to `model_type`, into the `models` directory of the installed package when the `RapidTable` class is initialized. You can also point to your own model file via `RapidTableInput(model_path='')` (see the sketch after the install commands below); note that this only works for the `model_type` values we currently support.
> ⚠️ Note: since `rapid_table>=v0.1.0`, the `rapidocr_onnxruntime` dependency is no longer bundled into `rapid_table`. Install the `rapidocr_onnxruntime` package yourself before use.
```bash
pip install rapidocr_onnxruntime
pip install rapid_table
-#pip install onnxruntime-gpu # for gpu inference
+
+# torch-based inference for the unitable model
+pip install rapid_table[torch] # for unitable inference
+
+# onnxruntime-gpu inference
+pip uninstall onnxruntime
+pip install onnxruntime-gpu # for onnx gpu inference
```
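
If you already have a model file on disk, the `model_path` override described above skips the auto-download. A minimal sketch (the local path below is only an illustration and must match the chosen `model_type`):

```python
from rapid_table import RapidTable, RapidTableInput

# Hypothetical local file; point this at wherever your model actually lives.
input_args = RapidTableInput(
    model_type="slanet_plus",
    model_path="models/slanet-plus.onnx",
)
table_engine = RapidTable(input_args)
```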
### Usage
#### Running from a Python script
-The RapidTable class provides a model_path parameter so you can specify either of the two models above; the default is `slanet-plus.onnx`. For example:
+> ⚠️ Note: since `rapid_table>=1.0.0`, model input is wrapped in dataclasses to simplify and keep parameter passing compatible. The input and output are defined as follows:
```python
-table_engine = RapidTable()
+# Input
+@dataclass
+class RapidTableInput:
+ model_type: Optional[str] = ModelType.SLANETPLUS.value
+ model_path: Union[str, Path, None, Dict[str, str]] = None
+ use_cuda: bool = False
+ device: str = "cpu"
+
+# Output
+@dataclass
+class RapidTableOutput:
+ pred_html: Optional[str] = None
+ cell_bboxes: Optional[np.ndarray] = None
+ logic_points: Optional[np.ndarray] = None
+ elapse: Optional[float] = None
+
+# Usage example
+input_args = RapidTableInput(model_type="unitable")
+table_engine = RapidTable(input_args)
+
+img_path = 'test_images/table.jpg'
+table_results = table_engine(img_path)
+
+print(table_results.pred_html)
```
A complete example:
@@ -120,20 +111,28 @@ from rapid_table import RapidTable, VisTable
from rapidocr_onnxruntime import RapidOCR
from rapid_table.table_structure.utils import trans_char_ocr_res
-
+# The default is the slanet_plus model
table_engine = RapidTable()
+
# Enable ONNX GPU inference
-# table_engine = RapidTable(use_cuda=True)
+# input_args = RapidTableInput(use_cuda=True)
+# table_engine = RapidTable(input_args)
+
+# Use the torch-based unitable model
+# input_args = RapidTableInput(model_type="unitable", use_cuda=True, device="cuda:0")
+# table_engine = RapidTable(input_args)
+
ocr_engine = RapidOCR()
viser = VisTable()
img_path = 'test_images/table.jpg'
-
ocr_result, _ = ocr_engine(img_path)
+
# Single-character matching
# ocr_result, _ = ocr_engine(img_path, return_word_box=True)
# ocr_result = trans_char_ocr_res(ocr_result)
-table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_result)
+
+table_results = table_engine(img_path, ocr_result)
save_dir = Path("./inference_results/")
save_dir.mkdir(parents=True, exist_ok=True)
@@ -141,50 +140,217 @@ save_dir.mkdir(parents=True, exist_ok=True)
save_html_path = save_dir / f"{Path(img_path).stem}.html"
save_drawed_path = save_dir / f"vis_{Path(img_path).name}"
+save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}"
-viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path)
-
-# Return logical coordinates
-# table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, ocr_result, return_logic_points=True)
-# save_logic_path = save_dir / f"vis_logic_{Path(img_path).name}"
-# viser(img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path,logic_points, save_logic_path)
-
+viser(
+    img_path,
+    table_results,
+    save_html_path,
+    save_drawed_path,
+    save_logic_path,
+)
print(table_results.pred_html)
```
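
Besides `pred_html`, the `RapidTableOutput` returned above also carries the cell boxes and their logical (row/column) coordinates, which stay aligned one-to-one. A small sketch of reading them, using only the fields defined earlier:

```python
# cell_bboxes and logic_points are parallel arrays: one entry per recognized cell.
for box, logic in zip(table_results.cell_bboxes, table_results.logic_points):
    # box: four corner points (8 values); logic: row/column span indices
    print(box, logic)
print(f"elapse: {table_results.elapse:.3f}s")
```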
#### Running from the command line
-- Usage:
-
- ```bash
- $ rapid_table -h
- usage: rapid_table [-h] [-v] -img IMG_PATH [-m MODEL_PATH]
-
- optional arguments:
- -h, --help show this help message and exit
- -v, --vis Whether to visualize the layout results.
- -img IMG_PATH, --img_path IMG_PATH
- Path to image for layout.
- -m MODEL_PATH, --model_path MODEL_PATH
- The model path used for inference.
- ```
-
-- Example:
-
- ```bash
- rapid_table -v -img test_images/table.jpg
- ```
+```bash
+rapid_table -v -img test_images/table.jpg
+```
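
With the renamed `-m/--model_type` flag from this diff (choices come from `KEY_TO_MODEL_URL`), selecting a specific model from the command line would look roughly like:

```bash
rapid_table -v -m unitable -img test_images/table.jpg
```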
### Results
#### Returned result
+
+
```html
-(previous single-line HTML output: a 16-row table of Methods / R / P / F / FPS metrics for SegLink, PixelLink, TextSnake, TextField, MSR, FTSN, LSE, CRAFT, MCN, ATRR, PAN, DB, DRRG and Ours; the markup was lost during extraction)
+(pretty-printed HTML output of the same 16-row table; the markup was lost during extraction and is omitted here)
```
+
+
#### Visualization
(visual rendering of the recognized table — same content as the HTML above; markup lost during extraction)
+
+### Relationship with [TableStructureRec](https://github.com/RapidAI/TableStructureRec)
+
+The TableStructureRec library is a collection of table recognition algorithms; it currently ships inference packages for the `wired_table_rec` (wired table) and `lineless_table_rec` (lineless table) algorithms.
+
+RapidTable was extracted from the table recognition part of PP-Structure. Because PP-Structure came first, this library ended up being named `rapid_table`.
+
+In short, RapidTable and TableStructureRec are both table recognition repositories. Try them both and use whichever works better for you; since the algorithms differ, there is no plan to unify them for now.
+
+For a comparison of table recognition algorithms, see the [TableStructureRec benchmark](https://github.com/RapidAI/TableStructureRec#指标结果)
+
+### Changelog
+
+
+
+#### 2024.12.30 update
+
+- Added table recognition with the Unitable model, using the PyTorch framework
+
+#### 2024.11.24 update
+
+- Added GPU inference, adapted RapidOCR single-character matching, and added logical-coordinate output with visualization
+
+#### 2024.10.13 update
+
+- Added the latest PaddleX SLANet-plus model (ONNX not yet supported because of paddle2onnx limitations)
+
+#### 2023-12-29 v0.1.3 update
+
+- Improved the visualization of results
+
+#### 2023-12-27 v0.1.2 update
+
+- Added a parameter to return cell bounding boxes
+- Improved the visualization function
+
+#### 2023-07-17 v0.1.0 update
+
+- Decoupled the `rapidocr_onnxruntime` part from `rapid_table`, making the dependency optional and more flexible.
+
+- Added the `ocr_result` input parameter:
+  - If `ocr_result` is provided when calling the function, OCR is not run again. The `ocr_result` format must match the return value of `rapidocr_onnxruntime`.
+  - If `ocr_result` is not provided but `rapidocr_onnxruntime` is installed, that library is called automatically for recognition.
+  - If `ocr_result` is not provided and `rapidocr_onnxruntime` is not installed, an error is raised. One of the two conditions must hold.
+
+#### 2023-07-10 v0.0.13 update
+
+- Changed the interface for passing an OCR instance into table restoration: any OCR instance can be passed in, provided its interface matches `rapidocr_onnxruntime`
+
+#### 2023-07-06 v0.0.12 update
+
+- Removed the wrapper element from the returned table HTML string, for easier downstream unification.
+- Reformatted the code with Black
+
+
+
diff --git a/demo.py b/demo.py
index fb13d21..d9e8ba3 100644
--- a/demo.py
+++ b/demo.py
@@ -6,12 +6,14 @@
import cv2
from rapidocr_onnxruntime import RapidOCR, VisRes
-from rapid_table import RapidTable, VisTable
+from rapid_table import RapidTable, RapidTableInput, VisTable
# Init
ocr_engine = RapidOCR()
vis_ocr = VisRes()
-table_engine = RapidTable()
+
+input_args = RapidTableInput(model_type="unitable")
+table_engine = RapidTable(input_args)
viser = VisTable()
img_path = "tests/test_files/table.jpg"
@@ -21,7 +23,8 @@
boxes, txts, scores = list(zip(*ocr_result))
# Table Rec
-table_html_str, table_cell_bboxes, _ = table_engine(img_path, ocr_result)
+table_results = table_engine(img_path, ocr_result)
+table_html_str, table_cell_bboxes = table_results.pred_html, table_results.cell_bboxes
# Save
save_dir = Path("outputs")
@@ -31,9 +34,7 @@
save_drawed_path = save_dir / f"{Path(img_path).stem}_table_vis{Path(img_path).suffix}"
# Visualize table rec result
-vis_imged = viser(
- img_path, table_html_str, save_html_path, table_cell_bboxes, save_drawed_path
-)
+vis_imged = viser(img_path, table_results, save_html_path, save_drawed_path)
# Visualize OCR result
save_ocr_path = save_dir / f"{Path(img_path).stem}_ocr_vis{Path(img_path).suffix}"
diff --git a/outputs/table.html b/outputs/table.html
deleted file mode 100644
index 1d26243..0000000
--- a/outputs/table.html
+++ /dev/null
@@ -1,4 +0,0 @@
-(deleted file: a single-line HTML table of Methods / R / P / F / FPS metrics for SegLink, PixelLink, TextSnake, TextField, MSR, FTSN, LSE, CRAFT, MCN, ATRR, PAN, DB, DRRG and Ours; markup lost during extraction)
\ No newline at end of file
diff --git a/outputs/table_ocr_vis.jpg b/outputs/table_ocr_vis.jpg
deleted file mode 100644
index f79d89e..0000000
Binary files a/outputs/table_ocr_vis.jpg and /dev/null differ
diff --git a/outputs/table_table_vis.jpg b/outputs/table_table_vis.jpg
deleted file mode 100644
index cfa68c9..0000000
Binary files a/outputs/table_table_vis.jpg and /dev/null differ
diff --git a/rapid_table/__init__.py b/rapid_table/__init__.py
index 93c6a5d..e152860 100644
--- a/rapid_table/__init__.py
+++ b/rapid_table/__init__.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
-from .main import RapidTable
-from .utils import VisTable
+from .main import RapidTable, RapidTableInput
+from .utils.utils import VisTable
diff --git a/rapid_table/main.py b/rapid_table/main.py
index d620fe5..426a7ce 100644
--- a/rapid_table/main.py
+++ b/rapid_table/main.py
@@ -5,33 +5,76 @@
import copy
import importlib
import time
+from dataclasses import asdict, dataclass
+from enum import Enum
from pathlib import Path
-from typing import List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
+from rapid_table.utils.download_model import DownloadModel
+from rapid_table.utils.logger import get_logger
+from rapid_table.utils.utils import LoadImage, VisTable
+
from .table_matcher import TableMatch
-from .table_structure import TableStructurer
-from .utils import LoadImage, VisTable
+from .table_structure import TableStructurer, TableStructureUnitable
+logger = get_logger("main")
root_dir = Path(__file__).resolve().parent
+class ModelType(Enum):
+ PPSTRUCTURE_EN = "ppstructure_en"
+ PPSTRUCTURE_ZH = "ppstructure_zh"
+ SLANETPLUS = "slanet_plus"
+ UNITABLE = "unitable"
+
+
+ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
+KEY_TO_MODEL_URL = {
+ ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}/en_ppstructure_mobile_v2_SLANet.onnx",
+ ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}/ch_ppstructure_mobile_v2_SLANet.onnx",
+ ModelType.SLANETPLUS.value: f"{ROOT_URL}/slanet-plus.onnx",
+ ModelType.UNITABLE.value: {
+ "encoder": f"{ROOT_URL}/unitable/encoder.pth",
+ "decoder": f"{ROOT_URL}/unitable/decoder.pth",
+ "vocab": f"{ROOT_URL}/unitable/vocab.json",
+ },
+}
+
+
+@dataclass
+class RapidTableInput:
+ model_type: Optional[str] = ModelType.SLANETPLUS.value
+ model_path: Union[str, Path, None, Dict[str, str]] = None
+ use_cuda: bool = False
+ device: str = "cpu"
+
+
+@dataclass
+class RapidTableOutput:
+ pred_html: Optional[str] = None
+ cell_bboxes: Optional[np.ndarray] = None
+ logic_points: Optional[np.ndarray] = None
+ elapse: Optional[float] = None
+
+
class RapidTable:
- def __init__(self, model_path: Optional[str] = None, model_type: str = None, use_cuda: bool = False):
- if model_path is None:
- model_path = str(
- root_dir / "models" / "slanet-plus.onnx"
+ def __init__(self, config: RapidTableInput):
+ self.model_type = config.model_type
+ if self.model_type not in KEY_TO_MODEL_URL:
+ model_list = ",".join(KEY_TO_MODEL_URL)
+ raise ValueError(
+ f"{self.model_type} is not supported. The currently supported models are {model_list}."
)
- model_type = "slanet-plus"
- self.model_type = model_type
- self.load_img = LoadImage()
- config = {
- "model_path": model_path,
- "use_cuda": use_cuda,
- }
- self.table_structure = TableStructurer(config)
+
+ config.model_path = self.get_model_path(config.model_type, config.model_path)
+ if self.model_type == ModelType.UNITABLE.value:
+ self.table_structure = TableStructureUnitable(asdict(config))
+ else:
+ self.table_structure = TableStructurer(asdict(config))
+
self.table_matcher = TableMatch()
try:
@@ -39,12 +82,13 @@ def __init__(self, model_path: Optional[str] = None, model_type: str = None, use
except ModuleNotFoundError:
self.ocr_engine = None
+ self.load_img = LoadImage()
+
def __call__(
self,
img_content: Union[str, np.ndarray, bytes, Path],
ocr_result: List[Union[List[List[float]], str, str]] = None,
- return_logic_points = False
- ) -> Tuple[str, float]:
+ ) -> RapidTableOutput:
if self.ocr_engine is None and ocr_result is None:
raise ValueError(
"One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
@@ -52,25 +96,28 @@ def __call__(
img = self.load_img(img_content)
- s = time.time()
+ s = time.perf_counter()
h, w = img.shape[:2]
if ocr_result is None:
ocr_result, _ = self.ocr_engine(img)
dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
- pred_structures, pred_bboxes, _ = self.table_structure(copy.deepcopy(img))
+ pred_structures, cell_bboxes, _ = self.table_structure(copy.deepcopy(img))
+
        # Rescale boxes output by the slanet-plus model back to the original image size
- if self.model_type == "slanet-plus":
- pred_bboxes = self.adapt_slanet_plus(img, pred_bboxes)
- pred_html = self.table_matcher(pred_structures, pred_bboxes, dt_boxes, rec_res)
-        # Off by default, to avoid breaking code written for older versions after an upgrade
- if return_logic_points:
- logic_points = self.table_matcher.decode_logic_points(pred_structures)
- elapse = time.time() - s
- return pred_html, pred_bboxes, logic_points, elapse
- elapse = time.time() - s
- return pred_html, pred_bboxes, elapse
+ if self.model_type == ModelType.SLANETPLUS.value:
+ cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
+
+ pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)
+
+        # Filter out placeholder bboxes
+ mask = ~np.all(cell_bboxes == 0, axis=1)
+ cell_bboxes = cell_bboxes[mask]
+
+ logic_points = self.table_matcher.decode_logic_points(pred_structures)
+ elapse = time.perf_counter() - s
+ return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)
def get_boxes_recs(
self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
@@ -89,15 +136,39 @@ def get_boxes_recs(
r_boxes.append(box)
dt_boxes = np.array(r_boxes)
return dt_boxes, rec_res
- def adapt_slanet_plus(self, img: np.ndarray, pred_bboxes: np.ndarray) -> np.ndarray:
+
+ def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
h, w = img.shape[:2]
resized = 488
ratio = min(resized / h, resized / w)
w_ratio = resized / (w * ratio)
h_ratio = resized / (h * ratio)
- pred_bboxes[:, 0::2] *= w_ratio
- pred_bboxes[:, 1::2] *= h_ratio
- return pred_bboxes
+ cell_bboxes[:, 0::2] *= w_ratio
+ cell_bboxes[:, 1::2] *= h_ratio
+ return cell_bboxes
+
+ @staticmethod
+ def get_model_path(
+ model_type: str, model_path: Union[str, Path, None]
+ ) -> Union[str, Dict[str, str]]:
+ if model_path is not None:
+ return model_path
+
+ model_url = KEY_TO_MODEL_URL.get(model_type, None)
+ if isinstance(model_url, str):
+ model_path = DownloadModel.download(model_url)
+ return model_path
+
+ if isinstance(model_url, dict):
+ model_paths = {}
+ for k, url in model_url.items():
+ model_paths[k] = DownloadModel.download(
+ url, save_model_name=f"{model_type}_{Path(url).name}"
+ )
+ return model_paths
+
+ raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.")
+
def main():
parser = argparse.ArgumentParser()
@@ -105,6 +176,7 @@ def main():
"-v",
"--vis",
action="store_true",
+ default=False,
help="Whether to visualize the layout results.",
)
parser.add_argument(
@@ -112,10 +184,10 @@ def main():
)
parser.add_argument(
"-m",
- "--model_path",
+ "--model_type",
type=str,
- default=str(root_dir / "models" / "en_ppstructure_mobile_v2_SLANet.onnx"),
- help="The model path used for inference.",
+ default=ModelType.SLANETPLUS.value,
+ choices=list(KEY_TO_MODEL_URL),
)
args = parser.parse_args()
@@ -126,12 +198,16 @@ def main():
"Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime."
) from exc
-    rapid_table = RapidTable(args.model_path)
+    input_args = RapidTableInput(model_type=args.model_type)
+    table_engine = RapidTable(input_args)
img = cv2.imread(args.img_path)
ocr_result, _ = ocr_engine(img)
- table_html_str, table_cell_bboxes, elapse = rapid_table(img, ocr_result)
+ table_results = table_engine(img, ocr_result)
+ table_html_str, table_cell_bboxes = (
+ table_results.pred_html,
+ table_results.cell_bboxes,
+ )
print(table_html_str)
viser = VisTable()
diff --git a/rapid_table/table_matcher/matcher.py b/rapid_table/table_matcher/matcher.py
index 0453929..f579a12 100644
--- a/rapid_table/table_matcher/matcher.py
+++ b/rapid_table/table_matcher/matcher.py
@@ -22,18 +22,18 @@ def __init__(self, filter_ocr_result=True, use_master=False):
self.filter_ocr_result = filter_ocr_result
self.use_master = use_master
- def __call__(self, pred_structures, pred_bboxes, dt_boxes, rec_res):
+ def __call__(self, pred_structures, cell_bboxes, dt_boxes, rec_res):
if self.filter_ocr_result:
- dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes, rec_res)
- matched_index = self.match_result(dt_boxes, pred_bboxes)
+ dt_boxes, rec_res = self._filter_ocr_result(cell_bboxes, dt_boxes, rec_res)
+ matched_index = self.match_result(dt_boxes, cell_bboxes)
pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
return pred_html
- def match_result(self, dt_boxes, pred_bboxes):
+ def match_result(self, dt_boxes, cell_bboxes, min_iou=0.1**8):
matched = {}
for i, gt_box in enumerate(dt_boxes):
distances = []
- for j, pred_box in enumerate(pred_bboxes):
+ for j, pred_box in enumerate(cell_bboxes):
if len(pred_box) == 8:
pred_box = [
np.min(pred_box[0::2]),
@@ -49,7 +49,11 @@ def match_result(self, dt_boxes, pred_bboxes):
sorted_distances = sorted(
sorted_distances, key=lambda item: (item[1], item[0])
)
- if distances.index(sorted_distances[0]) not in matched.keys():
+            # require IoU > min_iou: skip OCR boxes whose best-matching cell barely overlaps
+ if sorted_distances[0][1] >= 1 - min_iou:
+ continue
+
+ if distances.index(sorted_distances[0]) not in matched:
matched[distances.index(sorted_distances[0])] = [i]
else:
matched[distances.index(sorted_distances[0])].append(i)
@@ -111,6 +115,7 @@ def get_pred_html(self, pred_structures, matched_index, ocr_contents):
        filter_elements = ["<thead>", "</thead>", "<tbody>", "</tbody>"]
end_html = [v for v in end_html if v not in filter_elements]
return "".join(end_html), end_html
+
def decode_logic_points(self, pred_structures):
logic_points = []
current_row = 0
@@ -131,22 +136,24 @@ def mark_occupied(row, col, rowspan, colspan):
while i < len(pred_structures):
token = pred_structures[i]
-            if token == '<tr>':
+            if token == "<tr>":
                current_col = 0  # reset the current column index at every <tr>
-            elif token == '</tr>':
+            elif token == "</tr>":
                current_row += 1  # end of the row: increment the row index
-            elif token .startswith('<td'):
-                if token != '<td></td>':
+            elif token.startswith("<td"):
+                if token != "<td></td>":
j += 1
                    # extract the colspan and rowspan attributes
- while j < len(pred_structures) and not pred_structures[j].startswith('>'):
- if 'colspan=' in pred_structures[j]:
- colspan = int(pred_structures[j].split('=')[1].strip('"\''))
- elif 'rowspan=' in pred_structures[j]:
- rowspan = int(pred_structures[j].split('=')[1].strip('"\''))
+ while j < len(pred_structures) and not pred_structures[
+ j
+ ].startswith(">"):
+ if "colspan=" in pred_structures[j]:
+ colspan = int(pred_structures[j].split("=")[1].strip("\"'"))
+ elif "rowspan=" in pred_structures[j]:
+ rowspan = int(pred_structures[j].split("=")[1].strip("\"'"))
j += 1
                # skip the attribute tokens that were just processed
@@ -179,8 +186,8 @@ def mark_occupied(row, col, rowspan, colspan):
return logic_points
- def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
- y1 = pred_bboxes[:, 1::2].min()
+ def _filter_ocr_result(self, cell_bboxes, dt_boxes, rec_res):
+ y1 = cell_bboxes[:, 1::2].min()
new_dt_boxes = []
new_rec_res = []
diff --git a/rapid_table/table_matcher/utils.py b/rapid_table/table_matcher/utils.py
index 3ec8fcc..57a613c 100644
--- a/rapid_table/table_matcher/utils.py
+++ b/rapid_table/table_matcher/utils.py
@@ -14,6 +14,7 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
+import copy
import re
@@ -48,7 +49,7 @@ def deal_isolate_span(thead_part):
spanStr_in_isolateItem = span_part.group()
# 3. merge the span number into the span token format string.
if spanStr_in_isolateItem is not None:
-            corrected_item = "<td {}></td>".format(spanStr_in_isolateItem)
+            corrected_item = f"<td {spanStr_in_isolateItem}></td>"
corrected_list.append(corrected_item)
else:
corrected_list.append(None)
@@ -243,6 +244,6 @@ def compute_iou(rec1, rec2):
# judge if there is an intersect
if left_line >= right_line or top_line >= bottom_line:
return 0.0
- else:
- intersect = (right_line - left_line) * (bottom_line - top_line)
- return (intersect / (sum_area - intersect)) * 1.0
+
+ intersect = (right_line - left_line) * (bottom_line - top_line)
+ return (intersect / (sum_area - intersect)) * 1.0
diff --git a/rapid_table/table_structure/__init__.py b/rapid_table/table_structure/__init__.py
index a548391..3e638d9 100644
--- a/rapid_table/table_structure/__init__.py
+++ b/rapid_table/table_structure/__init__.py
@@ -12,3 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .table_structure import TableStructurer
+from .table_structure_unitable import TableStructureUnitable
diff --git a/rapid_table/table_structure/table_structure.py b/rapid_table/table_structure/table_structure.py
index c2509d8..9152603 100644
--- a/rapid_table/table_structure/table_structure.py
+++ b/rapid_table/table_structure/table_structure.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import time
-from typing import Dict, Any
+from typing import Any, Dict
import numpy as np
diff --git a/rapid_table/table_structure/table_structure_unitable.py b/rapid_table/table_structure/table_structure_unitable.py
new file mode 100644
index 0000000..2b98006
--- /dev/null
+++ b/rapid_table/table_structure/table_structure_unitable.py
@@ -0,0 +1,229 @@
+import re
+import time
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from tokenizers import Tokenizer
+from torchvision import transforms
+
+from .unitable_modules import Encoder, GPTFastDecoder
+
+IMG_SIZE = 448
+EOS_TOKEN = "<eos>"
+BBOX_TOKENS = [f"bbox-{i}" for i in range(IMG_SIZE + 1)]
+
+HTML_BBOX_HTML_TOKENS = [
+ " | ",
+ "[",
+ "] | ",
+ "[",
+ "> | ",
+ "",
+ "
",
+ "",
+ "",
+ "",
+ "",
+ ' rowspan="2"',
+ ' rowspan="3"',
+ ' rowspan="4"',
+ ' rowspan="5"',
+ ' rowspan="6"',
+ ' rowspan="7"',
+ ' rowspan="8"',
+ ' rowspan="9"',
+ ' rowspan="10"',
+ ' rowspan="11"',
+ ' rowspan="12"',
+ ' rowspan="13"',
+ ' rowspan="14"',
+ ' rowspan="15"',
+ ' rowspan="16"',
+ ' rowspan="17"',
+ ' rowspan="18"',
+ ' rowspan="19"',
+ ' colspan="2"',
+ ' colspan="3"',
+ ' colspan="4"',
+ ' colspan="5"',
+ ' colspan="6"',
+ ' colspan="7"',
+ ' colspan="8"',
+ ' colspan="9"',
+ ' colspan="10"',
+ ' colspan="11"',
+ ' colspan="12"',
+ ' colspan="13"',
+ ' colspan="14"',
+ ' colspan="15"',
+ ' colspan="16"',
+ ' colspan="17"',
+ ' colspan="18"',
+ ' colspan="19"',
+ ' colspan="25"',
+]
+
+VALID_HTML_BBOX_TOKENS = [EOS_TOKEN] + HTML_BBOX_HTML_TOKENS + BBOX_TOKENS
+TASK_TOKENS = [
+ "[table]",
+ "[html]",
+ "[cell]",
+ "[bbox]",
+ "[cell+bbox]",
+ "[html+bbox]",
+]
+
+
+class TableStructureUnitable:
+ def __init__(self, config):
+ # encoder_path: str, decoder_path: str, vocab_path: str, device: str
+ vocab_path = config["model_path"]["vocab"]
+ encoder_path = config["model_path"]["encoder"]
+ decoder_path = config["model_path"]["decoder"]
+ device = config.get("device", "cuda:0") if config["use_cuda"] else "cpu"
+
+ self.vocab = Tokenizer.from_file(vocab_path)
+ self.token_white_list = [
+ self.vocab.token_to_id(i) for i in VALID_HTML_BBOX_TOKENS
+ ]
+ self.bbox_token_ids = set(self.vocab.token_to_id(i) for i in BBOX_TOKENS)
+        self.bbox_close_html_token = self.vocab.token_to_id("]</td>")
+ self.prefix_token_id = self.vocab.token_to_id("[html+bbox]")
+ self.eos_id = self.vocab.token_to_id(EOS_TOKEN)
+ self.max_seq_len = 1024
+ self.device = device
+ self.img_size = IMG_SIZE
+
+ # init encoder
+ encoder_state_dict = torch.load(encoder_path, map_location=device)
+ self.encoder = Encoder()
+ self.encoder.load_state_dict(encoder_state_dict)
+ self.encoder.eval().to(device)
+
+ # init decoder
+ decoder_state_dict = torch.load(decoder_path, map_location=device)
+ self.decoder = GPTFastDecoder()
+ self.decoder.load_state_dict(decoder_state_dict)
+ self.decoder.eval().to(device)
+
+ # define img transform
+ self.transform = transforms.Compose(
+ [
+ transforms.Resize((448, 448)),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.86597056, 0.88463002, 0.87491087],
+ std=[0.20686628, 0.18201602, 0.18485524],
+ ),
+ ]
+ )
+
+ @torch.inference_mode()
+ def __call__(self, image: np.ndarray):
+ start_time = time.time()
+ ori_h, ori_w = image.shape[:2]
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image = Image.fromarray(image)
+ image = self.transform(image).unsqueeze(0).to(self.device)
+ self.decoder.setup_caches(
+ max_batch_size=1,
+ max_seq_length=self.max_seq_len,
+ dtype=image.dtype,
+ device=self.device,
+ )
+ context = (
+ torch.tensor([self.prefix_token_id], dtype=torch.int32)
+ .repeat(1, 1)
+ .to(self.device)
+ )
+ eos_id_tensor = torch.tensor(self.eos_id, dtype=torch.int32).to(self.device)
+ memory = self.encoder(image)
+ context = self.loop_decode(context, eos_id_tensor, memory)
+ bboxes, html_tokens = self.decode_tokens(context)
+ bboxes = bboxes.astype(np.float32)
+
+ # rescale boxes
+ scale_h = ori_h / self.img_size
+ scale_w = ori_w / self.img_size
+        bboxes[:, 0::2] *= scale_w  # scale x coordinates
+        bboxes[:, 1::2] *= scale_h  # scale y coordinates
+ bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, ori_w - 1)
+ bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, ori_h - 1)
+        structure_str_list = (
+            ["<html>", "<body>", "<table>"]
+            + html_tokens
+            + ["</table>", "</body>", "</html>"]
+        )
+ return structure_str_list, bboxes, time.time() - start_time
+
+ def decode_tokens(self, context):
+ pred_html = context[0]
+ pred_html = pred_html.detach().cpu().numpy()
+ pred_html = self.vocab.decode(pred_html, skip_special_tokens=False)
+        seq = pred_html.split("<eos>")[0]
+        token_black_list = ["<eos>", "<pad>", *TASK_TOKENS]
+ for i in token_black_list:
+ seq = seq.replace(i, "")
+
+        tr_pattern = re.compile(r"<tr>(.*?)</tr>", re.DOTALL)
+        td_pattern = re.compile(r"<td(.*?)>(.*?)</td>", re.DOTALL)
+ bbox_pattern = re.compile(r"\[ bbox-(\d+) bbox-(\d+) bbox-(\d+) bbox-(\d+) \]")
+
+ decoded_list = []
+ bbox_coords = []
+
+        # find all <tr> rows
+ for tr_match in tr_pattern.finditer(pred_html):
+ tr_content = tr_match.group(1)
+            decoded_list.append("<tr>")
+
+            # find all <td> cells
+ for td_match in td_pattern.finditer(tr_content):
+ td_attrs = td_match.group(1).strip()
+ td_content = td_match.group(2).strip()
+ if td_attrs:
+ decoded_list.append(" | ")
+ decoded_list.append(" | ")
+ else:
+ decoded_list.append(" | ")
+
+                # look for bbox coordinates
+ bbox_match = bbox_pattern.search(td_content)
+ if bbox_match:
+ xmin, ymin, xmax, ymax = map(int, bbox_match.groups())
+                    # convert to four corner points, clockwise from the top-left
+ coords = np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
+ bbox_coords.append(coords)
+ else:
+                    # pad with a placeholder bbox so the downstream pipeline stays uniform
+ bbox_coords.append(np.array([0, 0, 0, 0, 0, 0, 0, 0]))
+            decoded_list.append("</tr>")
+
+ bbox_coords_array = np.array(bbox_coords)
+ return bbox_coords_array, decoded_list
+
+ def loop_decode(self, context, eos_id_tensor, memory):
+ box_token_count = 0
+ for _ in range(self.max_seq_len):
+ eos_flag = (context == eos_id_tensor).any(dim=1)
+ if torch.all(eos_flag):
+ break
+
+ next_tokens = self.decoder(memory, context)
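+            # A cell bbox is exactly four coordinate tokens; once a fifth bbox token
+            # would be emitted, substitute the bbox-closing token so the decoded
+            # sequence stays parseable.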
+ if next_tokens[0] in self.bbox_token_ids:
+ box_token_count += 1
+ if box_token_count > 4:
+ next_tokens = torch.tensor(
+ [self.bbox_close_html_token], dtype=torch.int32
+ )
+ box_token_count = 0
+ context = torch.cat([context, next_tokens], dim=1)
+ return context
diff --git a/rapid_table/table_structure/unitable_modules.py b/rapid_table/table_structure/unitable_modules.py
new file mode 100644
index 0000000..5b8dac3
--- /dev/null
+++ b/rapid_table/table_structure/unitable_modules.py
@@ -0,0 +1,911 @@
+from dataclasses import dataclass
+from functools import partial
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+from torch.nn.modules.transformer import _get_activation_fn
+
+TOKEN_WHITE_LIST = [
+ 1,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ 23,
+ 24,
+ 25,
+ 26,
+ 27,
+ 28,
+ 29,
+ 30,
+ 31,
+ 32,
+ 33,
+ 34,
+ 35,
+ 36,
+ 37,
+ 38,
+ 39,
+ 40,
+ 41,
+ 42,
+ 43,
+ 44,
+ 45,
+ 46,
+ 47,
+ 48,
+ 49,
+ 50,
+ 51,
+ 52,
+ 53,
+ 54,
+ 55,
+ 56,
+ 57,
+ 58,
+ 59,
+ 60,
+ 61,
+ 62,
+ 63,
+ 64,
+ 65,
+ 66,
+ 67,
+ 68,
+ 69,
+ 70,
+ 71,
+ 72,
+ 73,
+ 74,
+ 75,
+ 76,
+ 77,
+ 78,
+ 79,
+ 80,
+ 81,
+ 82,
+ 83,
+ 84,
+ 85,
+ 86,
+ 87,
+ 88,
+ 89,
+ 90,
+ 91,
+ 92,
+ 93,
+ 94,
+ 95,
+ 96,
+ 97,
+ 98,
+ 99,
+ 100,
+ 101,
+ 102,
+ 103,
+ 104,
+ 105,
+ 106,
+ 107,
+ 108,
+ 109,
+ 110,
+ 111,
+ 112,
+ 113,
+ 114,
+ 115,
+ 116,
+ 117,
+ 118,
+ 119,
+ 120,
+ 121,
+ 122,
+ 123,
+ 124,
+ 125,
+ 126,
+ 127,
+ 128,
+ 129,
+ 130,
+ 131,
+ 132,
+ 133,
+ 134,
+ 135,
+ 136,
+ 137,
+ 138,
+ 139,
+ 140,
+ 141,
+ 142,
+ 143,
+ 144,
+ 145,
+ 146,
+ 147,
+ 148,
+ 149,
+ 150,
+ 151,
+ 152,
+ 153,
+ 154,
+ 155,
+ 156,
+ 157,
+ 158,
+ 159,
+ 160,
+ 161,
+ 162,
+ 163,
+ 164,
+ 165,
+ 166,
+ 167,
+ 168,
+ 169,
+ 170,
+ 171,
+ 172,
+ 173,
+ 174,
+ 175,
+ 176,
+ 177,
+ 178,
+ 179,
+ 180,
+ 181,
+ 182,
+ 183,
+ 184,
+ 185,
+ 186,
+ 187,
+ 188,
+ 189,
+ 190,
+ 191,
+ 192,
+ 193,
+ 194,
+ 195,
+ 196,
+ 197,
+ 198,
+ 199,
+ 200,
+ 201,
+ 202,
+ 203,
+ 204,
+ 205,
+ 206,
+ 207,
+ 208,
+ 209,
+ 210,
+ 211,
+ 212,
+ 213,
+ 214,
+ 215,
+ 216,
+ 217,
+ 218,
+ 219,
+ 220,
+ 221,
+ 222,
+ 223,
+ 224,
+ 225,
+ 226,
+ 227,
+ 228,
+ 229,
+ 230,
+ 231,
+ 232,
+ 233,
+ 234,
+ 235,
+ 236,
+ 237,
+ 238,
+ 239,
+ 240,
+ 241,
+ 242,
+ 243,
+ 244,
+ 245,
+ 246,
+ 247,
+ 248,
+ 249,
+ 250,
+ 251,
+ 252,
+ 253,
+ 254,
+ 255,
+ 256,
+ 257,
+ 258,
+ 259,
+ 260,
+ 261,
+ 262,
+ 263,
+ 264,
+ 265,
+ 266,
+ 267,
+ 268,
+ 269,
+ 270,
+ 271,
+ 272,
+ 273,
+ 274,
+ 275,
+ 276,
+ 277,
+ 278,
+ 279,
+ 280,
+ 281,
+ 282,
+ 283,
+ 284,
+ 285,
+ 286,
+ 287,
+ 288,
+ 289,
+ 290,
+ 291,
+ 292,
+ 293,
+ 294,
+ 295,
+ 296,
+ 297,
+ 298,
+ 299,
+ 300,
+ 301,
+ 302,
+ 303,
+ 304,
+ 305,
+ 306,
+ 307,
+ 308,
+ 309,
+ 310,
+ 311,
+ 312,
+ 313,
+ 314,
+ 315,
+ 316,
+ 317,
+ 318,
+ 319,
+ 320,
+ 321,
+ 322,
+ 323,
+ 324,
+ 325,
+ 326,
+ 327,
+ 328,
+ 329,
+ 330,
+ 331,
+ 332,
+ 333,
+ 334,
+ 335,
+ 336,
+ 337,
+ 338,
+ 339,
+ 340,
+ 341,
+ 342,
+ 343,
+ 344,
+ 345,
+ 346,
+ 347,
+ 348,
+ 349,
+ 350,
+ 351,
+ 352,
+ 353,
+ 354,
+ 355,
+ 356,
+ 357,
+ 358,
+ 359,
+ 360,
+ 361,
+ 362,
+ 363,
+ 364,
+ 365,
+ 366,
+ 367,
+ 368,
+ 369,
+ 370,
+ 371,
+ 372,
+ 373,
+ 374,
+ 375,
+ 376,
+ 377,
+ 378,
+ 379,
+ 380,
+ 381,
+ 382,
+ 383,
+ 384,
+ 385,
+ 386,
+ 387,
+ 388,
+ 389,
+ 390,
+ 391,
+ 392,
+ 393,
+ 394,
+ 395,
+ 396,
+ 397,
+ 398,
+ 399,
+ 400,
+ 401,
+ 402,
+ 403,
+ 404,
+ 405,
+ 406,
+ 407,
+ 408,
+ 409,
+ 410,
+ 411,
+ 412,
+ 413,
+ 414,
+ 415,
+ 416,
+ 417,
+ 418,
+ 419,
+ 420,
+ 421,
+ 422,
+ 423,
+ 424,
+ 425,
+ 426,
+ 427,
+ 428,
+ 429,
+ 430,
+ 431,
+ 432,
+ 433,
+ 434,
+ 435,
+ 436,
+ 437,
+ 438,
+ 439,
+ 440,
+ 441,
+ 442,
+ 443,
+ 444,
+ 445,
+ 446,
+ 447,
+ 448,
+ 449,
+ 450,
+ 451,
+ 452,
+ 453,
+ 454,
+ 455,
+ 456,
+ 457,
+ 458,
+ 459,
+ 460,
+ 461,
+ 462,
+ 463,
+ 464,
+ 465,
+ 466,
+ 467,
+ 468,
+ 469,
+ 470,
+ 471,
+ 472,
+ 473,
+ 474,
+ 475,
+ 476,
+ 477,
+ 478,
+ 479,
+ 480,
+ 481,
+ 482,
+ 483,
+ 484,
+ 485,
+ 486,
+ 487,
+ 488,
+ 489,
+ 490,
+ 491,
+ 492,
+ 493,
+ 494,
+ 495,
+ 496,
+ 497,
+ 498,
+ 499,
+ 500,
+ 501,
+ 502,
+ 503,
+ 504,
+ 505,
+ 506,
+ 507,
+ 508,
+ 509,
+]
+
+
+class ImgLinearBackbone(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ patch_size: int,
+ in_chan: int = 3,
+ ) -> None:
+ super().__init__()
+
+ self.conv_proj = nn.Conv2d(
+ in_chan,
+ out_channels=d_model,
+ kernel_size=patch_size,
+ stride=patch_size,
+ )
+ self.d_model = d_model
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.conv_proj(x)
+ x = x.flatten(start_dim=-2).transpose(1, 2)
+ return x
+
+
+class Encoder(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+ self.patch_size = 16
+ self.d_model = 768
+ self.dropout = 0
+ self.activation = "gelu"
+ self.norm_first = True
+ self.ff_ratio = 4
+ self.nhead = 12
+ self.max_seq_len = 1024
+ self.n_encoder_layer = 12
+ encoder_layer = nn.TransformerEncoderLayer(
+ self.d_model,
+ nhead=self.nhead,
+ dim_feedforward=self.ff_ratio * self.d_model,
+ dropout=self.dropout,
+ activation=self.activation,
+ batch_first=True,
+ norm_first=self.norm_first,
+ )
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ self.norm = norm_layer(self.d_model)
+ self.backbone = ImgLinearBackbone(
+ d_model=self.d_model, patch_size=self.patch_size
+ )
+ self.pos_embed = PositionEmbedding(
+ max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
+ )
+ self.encoder = nn.TransformerEncoder(
+ encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ src_feature = self.backbone(x)
+ src_feature = self.pos_embed(src_feature)
+ memory = self.encoder(src_feature)
+ memory = self.norm(memory)
+ return memory
+
+
+class PositionEmbedding(nn.Module):
+ def __init__(self, max_seq_len: int, d_model: int, dropout: float) -> None:
+ super().__init__()
+ self.embedding = nn.Embedding(max_seq_len, d_model)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+ # assume x is batch first
+ if input_pos is None:
+ _pos = torch.arange(x.shape[1], device=x.device)
+ else:
+ _pos = input_pos
+ out = self.embedding(_pos)
+ return self.dropout(out + x)
+
+
+class TokenEmbedding(nn.Module):
+ def __init__(
+ self,
+ vocab_size: int,
+ d_model: int,
+ padding_idx: int,
+ ) -> None:
+ super().__init__()
+ assert vocab_size > 0
+ self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.embedding(x)
+
+
+def find_multiple(n: int, k: int) -> int:
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
+
+@dataclass
+class ModelArgs:
+ n_layer: int = 4
+ n_head: int = 12
+ dim: int = 768
+ intermediate_size: int = None
+ head_dim: int = 64
+ activation: str = "gelu"
+ norm_first: bool = True
+
+ def __post_init__(self):
+ if self.intermediate_size is None:
+ hidden_dim = 4 * self.dim
+ n_hidden = int(2 * hidden_dim / 3)
+ self.intermediate_size = find_multiple(n_hidden, 256)
+ self.head_dim = self.dim // self.n_head
+
+
+class KVCache(nn.Module):
+ def __init__(
+ self,
+ max_batch_size,
+ max_seq_length,
+ n_heads,
+ head_dim,
+ dtype=torch.bfloat16,
+ device="cpu",
+ ):
+ super().__init__()
+ cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+ self.register_buffer(
+ "k_cache",
+ torch.zeros(cache_shape, dtype=dtype, device=device),
+ persistent=False,
+ )
+ self.register_buffer(
+ "v_cache",
+ torch.zeros(cache_shape, dtype=dtype, device=device),
+ persistent=False,
+ )
+
+ def update(self, input_pos, k_val, v_val):
+ # input_pos: [S], k_val: [B, H, S, D]
+ # assert input_pos.shape[0] == k_val.shape[2]
+
+ bs = k_val.shape[0]
+ k_out = self.k_cache
+ v_out = self.v_cache
+ k_out[:bs, :, input_pos] = k_val
+ v_out[:bs, :, input_pos] = v_val
+
+ return k_out[:bs], v_out[:bs]
+
+
+class GPTFastDecoder(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+ self.vocab_size = 960
+ self.padding_idx = 2
+ self.prefix_token_id = 11
+ self.eos_id = 1
+ self.max_seq_len = 1024
+ self.dropout = 0
+ self.d_model = 768
+ self.nhead = 12
+ self.activation = "gelu"
+ self.norm_first = True
+ self.n_decoder_layer = 4
+ config = ModelArgs(
+ n_layer=self.n_decoder_layer,
+ n_head=self.nhead,
+ dim=self.d_model,
+ intermediate_size=self.d_model * 4,
+ activation=self.activation,
+ norm_first=self.norm_first,
+ )
+ self.config = config
+ self.layers = nn.ModuleList(
+ TransformerBlock(config) for _ in range(config.n_layer)
+ )
+ self.token_embed = TokenEmbedding(
+ vocab_size=self.vocab_size,
+ d_model=self.d_model,
+ padding_idx=self.padding_idx,
+ )
+ self.pos_embed = PositionEmbedding(
+ max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
+ )
+ self.generator = nn.Linear(self.d_model, self.vocab_size)
+ self.token_white_list = TOKEN_WHITE_LIST
+ self.mask_cache: Optional[Tensor] = None
+ self.max_batch_size = -1
+ self.max_seq_length = -1
+
+ def setup_caches(self, max_batch_size, max_seq_length, dtype, device):
+ for b in self.layers:
+ b.multihead_attn.k_cache = None
+ b.multihead_attn.v_cache = None
+
+ if (
+ self.max_seq_length >= max_seq_length
+ and self.max_batch_size >= max_batch_size
+ ):
+ return
+ head_dim = self.config.dim // self.config.n_head
+ max_seq_length = find_multiple(max_seq_length, 8)
+ self.max_seq_length = max_seq_length
+ self.max_batch_size = max_batch_size
+
+ for b in self.layers:
+ b.self_attn.kv_cache = KVCache(
+ max_batch_size,
+ max_seq_length,
+ self.config.n_head,
+ head_dim,
+ dtype,
+ device,
+ )
+ b.multihead_attn.k_cache = None
+ b.multihead_attn.v_cache = None
+
+ self.causal_mask = torch.tril(
+ torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
+ ).to(device)
+
+ def forward(self, memory: Tensor, tgt: Tensor) -> Tensor:
+ input_pos = torch.tensor([tgt.shape[1] - 1], device=tgt.device, dtype=torch.int)
+ tgt = tgt[:, -1:]
+ tgt_feature = self.pos_embed(self.token_embed(tgt), input_pos=input_pos)
+ # tgt = self.decoder(tgt_feature, memory, input_pos)
+ with torch.backends.cuda.sdp_kernel(
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
+ ):
+ logits = tgt_feature
+ tgt_mask = self.causal_mask[None, None, input_pos]
+ for i, layer in enumerate(self.layers):
+ logits = layer(logits, memory, input_pos=input_pos, tgt_mask=tgt_mask)
+ # return output
+ logits = self.generator(logits)[:, -1, :]
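+        # Restrict decoding to the whitelisted vocab ids (HTML-structure and bbox tokens)
+        # by pushing every other logit to -1e9 before the softmax.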
+ total = set([i for i in range(logits.shape[-1])])
+ black_list = list(total.difference(set(self.token_white_list)))
+ logits[..., black_list] = -1e9
+ probs = F.softmax(logits, dim=-1)
+ _, next_tokens = probs.topk(1)
+ return next_tokens
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.self_attn = Attention(config)
+ self.multihead_attn = CrossAttention(config)
+
+ layer_norm_eps = 1e-5
+
+ d_model = config.dim
+ dim_feedforward = config.intermediate_size
+
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm_first = config.norm_first
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+
+ self.activation = _get_activation_fn(config.activation)
+
+ def forward(
+ self,
+ tgt: Tensor,
+ memory: Tensor,
+ tgt_mask: Tensor,
+ input_pos: Tensor,
+ ) -> Tensor:
+ if self.norm_first:
+ x = tgt
+ x = x + self.self_attn(self.norm1(x), tgt_mask, input_pos)
+ x = x + self.multihead_attn(self.norm2(x), memory)
+ x = x + self._ff_block(self.norm3(x))
+ else:
+ x = tgt
+ x = self.norm1(x + self.self_attn(x, tgt_mask, input_pos))
+ x = self.norm2(x + self.multihead_attn(x, memory))
+ x = self.norm3(x + self._ff_block(x))
+ return x
+
+ def _ff_block(self, x: Tensor) -> Tensor:
+ x = self.linear2(self.activation(self.linear1(x)))
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ # key, query, value projections for all heads, but in a batch
+ self.wqkv = nn.Linear(config.dim, 3 * config.dim)
+ self.wo = nn.Linear(config.dim, config.dim)
+
+ self.kv_cache: Optional[KVCache] = None
+
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+ self.dim = config.dim
+
+ def forward(
+ self,
+ x: Tensor,
+ mask: Tensor,
+ input_pos: Optional[Tensor] = None,
+ ) -> Tensor:
+ bsz, seqlen, _ = x.shape
+
+ kv_size = self.n_head * self.head_dim
+ q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+ k = k.view(bsz, seqlen, self.n_head, self.head_dim)
+ v = v.view(bsz, seqlen, self.n_head, self.head_dim)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ if self.kv_cache is not None:
+ k, v = self.kv_cache.update(input_pos, k, v)
+
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+ y = self.wo(y)
+ return y
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ self.query = nn.Linear(config.dim, config.dim)
+ self.key = nn.Linear(config.dim, config.dim)
+ self.value = nn.Linear(config.dim, config.dim)
+ self.out = nn.Linear(config.dim, config.dim)
+
+ self.k_cache = None
+ self.v_cache = None
+
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+
+ def get_kv(self, xa: torch.Tensor):
+ if self.k_cache is not None and self.v_cache is not None:
+ return self.k_cache, self.v_cache
+
+ k = self.key(xa)
+ v = self.value(xa)
+
+ # Reshape for correct format
+ batch_size, source_seq_len, _ = k.shape
+ k = k.view(batch_size, source_seq_len, self.n_head, self.head_dim)
+ v = v.view(batch_size, source_seq_len, self.n_head, self.head_dim)
+
+ if self.k_cache is None:
+ self.k_cache = k
+ if self.v_cache is None:
+ self.v_cache = v
+
+ return k, v
+
+ def forward(
+ self,
+ x: Tensor,
+ xa: Tensor,
+ ):
+ q = self.query(x)
+ batch_size, target_seq_len, _ = q.shape
+ q = q.view(batch_size, target_seq_len, self.n_head, self.head_dim)
+ k, v = self.get_kv(xa)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ wv = F.scaled_dot_product_attention(
+ query=q,
+ key=k,
+ value=v,
+ is_causal=False,
+ )
+ wv = wv.transpose(1, 2).reshape(
+ batch_size,
+ target_seq_len,
+ self.n_head * self.head_dim,
+ )
+
+ return self.out(wv)
diff --git a/rapid_table/table_structure/utils.py b/rapid_table/table_structure/utils.py
index c88cf3c..78f41d2 100644
--- a/rapid_table/table_structure/utils.py
+++ b/rapid_table/table_structure/utils.py
@@ -17,12 +17,11 @@
import os
import platform
import traceback
-
-import cv2
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
+import cv2
import numpy as np
from onnxruntime import (
GraphOptimizationLevel,
@@ -32,7 +31,7 @@
get_device,
)
-from rapid_table.table_structure.logger import get_logger
+from rapid_table.utils.logger import get_logger
class EP(Enum):
@@ -79,10 +78,12 @@ def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
sess_opt.inter_op_num_threads = inter_op_num_threads
return sess_opt
+
def get_metadata(self, key: str = "character") -> list:
meta_dict = self.session.get_modelmeta().custom_metadata_map
content_list = meta_dict[key].splitlines()
return content_list
+
def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
cpu_provider_opts = {
"arena_extend_strategy": "kSameAsRequested",
diff --git a/rapid_table/utils/__init__.py b/rapid_table/utils/__init__.py
new file mode 100644
index 0000000..0ecdd4f
--- /dev/null
+++ b/rapid_table/utils/__init__.py
@@ -0,0 +1,3 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
diff --git a/rapid_table/utils/download_model.py b/rapid_table/utils/download_model.py
new file mode 100644
index 0000000..adedb5d
--- /dev/null
+++ b/rapid_table/utils/download_model.py
@@ -0,0 +1,67 @@
+import io
+from pathlib import Path
+from typing import Optional, Union
+
+import requests
+from tqdm import tqdm
+
+from .logger import get_logger
+
+logger = get_logger("DownloadModel")
+
+PROJECT_DIR = Path(__file__).resolve().parent.parent
+DEFAULT_MODEL_DIR = PROJECT_DIR / "models"
+
+
+class DownloadModel:
+ @classmethod
+ def download(
+ cls,
+ model_full_url: Union[str, Path],
+ save_dir: Union[str, Path, None] = None,
+ save_model_name: Optional[str] = None,
+ ) -> str:
+ if save_dir is None:
+ save_dir = DEFAULT_MODEL_DIR
+
+ save_dir.mkdir(parents=True, exist_ok=True)
+
+ if save_model_name is None:
+ save_model_name = Path(model_full_url).name
+
+ save_file_path = save_dir / save_model_name
+ if save_file_path.exists():
+ logger.debug("%s already exists", save_file_path)
+ return str(save_file_path)
+
+ try:
+ logger.info("Download %s to %s", model_full_url, save_dir)
+ file = cls.download_as_bytes_with_progress(model_full_url, save_model_name)
+ cls.save_file(save_file_path, file)
+ except Exception as exc:
+ raise DownloadModelError from exc
+ return str(save_file_path)
+
+ @staticmethod
+ def download_as_bytes_with_progress(
+ url: Union[str, Path], name: Optional[str] = None
+ ) -> bytes:
+ resp = requests.get(str(url), stream=True, allow_redirects=True, timeout=180)
+ total = int(resp.headers.get("content-length", 0))
+ bio = io.BytesIO()
+ with tqdm(
+ desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024
+ ) as pbar:
+ for chunk in resp.iter_content(chunk_size=65536):
+ pbar.update(len(chunk))
+ bio.write(chunk)
+ return bio.getvalue()
+
+ @staticmethod
+ def save_file(save_path: Union[str, Path], file: bytes):
+ with open(save_path, "wb") as f:
+ f.write(file)
+
+
+class DownloadModelError(Exception):
+ pass
diff --git a/rapid_table/table_structure/logger.py b/rapid_table/utils/logger.py
similarity index 100%
rename from rapid_table/table_structure/logger.py
rename to rapid_table/utils/logger.py
diff --git a/rapid_table/utils.py b/rapid_table/utils/utils.py
similarity index 85%
rename from rapid_table/utils.py
rename to rapid_table/utils/utils.py
index 350f2e5..ad35a19 100644
--- a/rapid_table/utils.py
+++ b/rapid_table/utils/utils.py
@@ -4,7 +4,7 @@
import os
from io import BytesIO
from pathlib import Path
-from typing import Optional, Union, List
+from typing import Optional, Union
import cv2
import numpy as np
@@ -14,9 +14,7 @@
class LoadImage:
- def __init__(
- self,
- ):
+ def __init__(self):
pass
def __call__(self, img: InputType) -> np.ndarray:
@@ -79,25 +77,22 @@ class LoadImageError(Exception):
class VisTable:
- def __init__(
- self,
- ):
+ def __init__(self):
self.load_img = LoadImage()
def __call__(
- self,
- img_path: Union[str, Path],
- table_html_str: str,
- save_html_path: Optional[str] = None,
- table_cell_bboxes: Optional[np.ndarray] = None,
- save_drawed_path: Optional[str] = None,
- logic_points: List[List[float]] = None,
- save_logic_path: Optional[str] = None,
- ) -> None:
+ self,
+ img_path: Union[str, Path],
+ table_results,
+ save_html_path: Optional[str] = None,
+ save_drawed_path: Optional[str] = None,
+ save_logic_path: Optional[str] = None,
+ ):
if save_html_path:
- html_with_border = self.insert_border_style(table_html_str)
+ html_with_border = self.insert_border_style(table_results.pred_html)
self.save_html(save_html_path, html_with_border)
+        table_cell_bboxes = table_results.cell_bboxes
+        logic_points = table_results.logic_points
if table_cell_bboxes is None:
return None
@@ -113,31 +108,37 @@ def __call__(
if save_drawed_path:
self.save_img(save_drawed_path, drawed_img)
+
if save_logic_path and logic_points:
- polygons = [[box[0],box[1], box[4], box[5]] for box in table_cell_bboxes]
- self.plot_rec_box_with_logic_info(img_path, save_logic_path, logic_points, polygons)
+ polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes]
+ self.plot_rec_box_with_logic_info(
+ img_path, save_logic_path, logic_points, polygons
+ )
return drawed_img
def insert_border_style(self, table_html_str: str):
- style_res = f""""""
+
prefix_table, suffix_table = table_html_str.split("")
html_with_border = f"{prefix_table}{style_res}{suffix_table}"
return html_with_border
- def plot_rec_box_with_logic_info(self, img_path, output_path, logic_points, sorted_polygons):
+ def plot_rec_box_with_logic_info(
+ self, img_path, output_path, logic_points, sorted_polygons
+ ):
"""
:param img_path
:param output_path
@@ -182,7 +183,7 @@ def plot_rec_box_with_logic_info(self, img_path, output_path, logic_points, sort
)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
        # save the annotated image
- cv2.imwrite(output_path, img)
+ self.save_img(output_path, img)
@staticmethod
def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
diff --git a/setup.py b/setup.py
index 0539473..caf2114 100644
--- a/setup.py
+++ b/setup.py
@@ -3,11 +3,18 @@
# @Contact: liekkaskono@163.com
import sys
from pathlib import Path
+from typing import List, Union
import setuptools
from get_pypi_latest_version import GetPyPiLatestVersion
+def read_txt(txt_path: Union[Path, str]) -> List[str]:
+ with open(txt_path, "r", encoding="utf-8") as f:
+ data = [v.rstrip("\n") for v in f]
+ return data
+
+
def get_readme():
root_dir = Path(__file__).resolve().parent
readme_path = str(root_dir / "docs" / "doc_whl_rapid_table.md")
@@ -19,7 +26,7 @@ def get_readme():
MODULE_NAME = "rapid_table"
obtainer = GetPyPiLatestVersion()
latest_version = obtainer(MODULE_NAME)
-VERSION_NUM = obtainer.version_add_one(latest_version)
+VERSION_NUM = obtainer.version_add_one(latest_version, add_patch=True)
if len(sys.argv) > 2:
match_str = " ".join(sys.argv[2:])
@@ -34,19 +41,13 @@ def get_readme():
platforms="Any",
long_description=get_readme(),
long_description_content_type="text/markdown",
- description="Tools for parsing table structures based ONNXRuntime.",
+ description="Table Recognition",
author="SWHL",
author_email="liekkaskono@163.com",
- url="https://github.com/RapidAI/RapidStructure",
+ url="https://github.com/RapidAI/RapidTable",
license="Apache-2.0",
include_package_data=True,
- install_requires=[
- "onnxruntime>=1.7.0",
- "PyYAML>=6.0",
- "opencv_python>=4.5.1.48",
- "numpy>=1.21.6",
- "Pillow",
- ],
+ install_requires=read_txt("requirements.txt"),
packages=[
MODULE_NAME,
f"{MODULE_NAME}.models",
@@ -66,4 +67,5 @@ def get_readme():
],
python_requires=">=3.6,<3.13",
entry_points={"console_scripts": [f"{MODULE_NAME}={MODULE_NAME}.main:main"]},
+ extras_require={"torch": ["torch", "torchvision", "tokenizers"]},
)
diff --git a/tests/test_main.py b/tests/test_main.py
new file mode 100644
index 0000000..1138ef1
--- /dev/null
+++ b/tests/test_main.py
@@ -0,0 +1,43 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import sys
+from pathlib import Path
+
+import pytest
+from rapidocr_onnxruntime import RapidOCR
+
+cur_dir = Path(__file__).resolve().parent
+root_dir = cur_dir.parent
+
+sys.path.append(str(root_dir))
+
+from rapid_table import RapidTable, RapidTableInput
+
+ocr_engine = RapidOCR()
+
+input_args = RapidTableInput()
+table_engine = RapidTable(input_args)
+
+test_file_dir = cur_dir / "test_files"
+img_path = str(test_file_dir / "table.jpg")
+
+
+@pytest.mark.parametrize("model_type", ["slanet_plus", "unitable"])
+def test_ocr_input(model_type):
+ ocr_res, _ = ocr_engine(img_path)
+
+ input_args = RapidTableInput(model_type=model_type)
+ table_engine = RapidTable(input_args)
+
+ table_results = table_engine(img_path, ocr_res)
+    assert table_results.pred_html.count("<tr>") == 16
+
+
+@pytest.mark.parametrize("model_type", ["slanet_plus", "unitable"])
+def test_input_ocr_none(model_type):
+ input_args = RapidTableInput(model_type=model_type)
+ table_engine = RapidTable(input_args)
+ table_results = table_engine(img_path)
+    assert table_results.pred_html.count("<tr>") == 16
+ assert len(table_results.cell_bboxes) == len(table_results.logic_points)
diff --git a/tests/test_table.py b/tests/test_table.py
deleted file mode 100644
index 41a2060..0000000
--- a/tests/test_table.py
+++ /dev/null
@@ -1,35 +0,0 @@
-# -*- encoding: utf-8 -*-
-# @Author: SWHL
-# @Contact: liekkaskono@163.com
-import sys
-from pathlib import Path
-
-from rapidocr_onnxruntime import RapidOCR
-
-cur_dir = Path(__file__).resolve().parent
-root_dir = cur_dir.parent
-
-sys.path.append(str(root_dir))
-
-from rapid_table import RapidTable
-
-ocr_engine = RapidOCR()
-table_engine = RapidTable()
-
-test_file_dir = cur_dir / "test_files"
-img_path = str(test_file_dir / "table.jpg")
-
-
-def test_ocr_input():
- ocr_res, _ = ocr_engine(img_path)
- table_html_str, table_cell_bboxes, elapse = table_engine(img_path, ocr_res)
-    assert table_html_str.count("<tr>") == 16
-
-
-def test_input_ocr_none():
- table_html_str, table_cell_bboxes, elapse = table_engine(img_path)
-    assert table_html_str.count("<tr>") == 16
-
-def test_logic_points_out():
- table_html_str, table_cell_bboxes, logic_points, elapse = table_engine(img_path, return_logic_points=True)
- assert len(table_cell_bboxes) == len(logic_points)