diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml deleted file mode 100644 index 5758947..0000000 --- a/.github/ISSUE_TEMPLATE/bug_report.yaml +++ /dev/null @@ -1,17 +0,0 @@ ---- -name: 🐞 Bug -about: Bug -title: 'Bug' -labels: 'Bug' -assignees: '' - ---- - -请提供下述完整信息以便快速定位问题 -(Please provide the following information to quickly locate the problem) -- **系统环境/System Environment**: -- **使用的是哪门语言的程序/Which programing language**: -- **使用当前库的版本/Use version**: -- **可复现问题的demo和文件/Demo of reproducible problems**: -- **完整报错/Complete Error Message**: -- **可能的解决方案/Possible solutions**: \ No newline at end of file diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml deleted file mode 100644 index 98e4f5e..0000000 --- a/.github/ISSUE_TEMPLATE/config.yml +++ /dev/null @@ -1,11 +0,0 @@ -blank_issues_enabled: false -contact_links: - - name: ❓ Questions - url: https://github.com/RapidAI/TableStructureRec/discussions/categories/q-a - about: Please use the community forum for help and questions regarding ProcessLaTeXFormulaTools Docs - - name: 💡 Feature requests and ideas - url: https://github.com/RapidAI/TableStructureRec/discussions/new?category=feature-requests - about: Please vote for and post new feature ideas in the community forum - - name: 📖 Documentation - url: https://rapidai.github.io/TableStructureRec/docs/ - about: A great place to find instructions and answers about RapidOCR. \ No newline at end of file diff --git a/.github/workflows/rapid_table_det.yml b/.github/workflows/rapid_table_det.yml deleted file mode 100644 index 3caa7a0..0000000 --- a/.github/workflows/rapid_table_det.yml +++ /dev/null @@ -1,54 +0,0 @@ -name: Push rapid_table_det_v to pypi - -on: - push: - tags: - - rapid_table_det_v* - -jobs: - UnitTesting: - runs-on: ubuntu-latest - steps: - - name: Pull latest code - uses: actions/checkout@v3 - - - name: Set up Python 3.10 - uses: actions/setup-python@v4 - with: - python-version: '3.10' - architecture: 'x64' - - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - - name: Unit testings - run: | - pip install -r requirements.txt - pip install pytest - pytest tests/test_table_det.py - - GenerateWHL_PushPyPi: - needs: UnitTesting - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - - name: Set up Python 3.10 - uses: actions/setup-python@v4 - with: - python-version: '3.10' - architecture: 'x64' - - - name: Run setup.py - run: | - pip install -r requirements.txt - python -m pip install --upgrade pip - pip install wheel get_pypi_latest_version - python setup_rapid_table_det.py bdist_wheel "${{ github.ref_name }}" - - - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@v1.5.0 - with: - password: ${{ secrets.PYPI_API_TOKEN }} - packages_dir: dist/ diff --git a/.github/workflows/rapid_table_det_paddle.yml b/.github/workflows/rapid_table_det_paddle.yml deleted file mode 100644 index 2081b62..0000000 --- a/.github/workflows/rapid_table_det_paddle.yml +++ /dev/null @@ -1,74 +0,0 @@ -name: Push rapid_table_det_v to pypi - -on: - push: - tags: - - rapid_table_det_paddle_v* - -jobs: - UnitTesting: - runs-on: ubuntu-latest - steps: - - name: Pull latest code - uses: actions/checkout@v3 - - - name: Set up Python 3.10 - uses: actions/setup-python@v4 - with: - python-version: '3.10' - architecture: 'x64' - - - name: Display Python version - run: python -c "import sys; print(sys.version)" - - - name: Unit testings - run: | - pip install -r requirements.txt - 
pip install paddlepaddle-gpu - pip install pytest - - wget https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/cls_det_paddle.zip - wget https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/obj_det_paddle.zip - wget https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/edge_det_paddle.zip - unzip cls_det_paddle.zip - unzip obj_det_paddle.zip - unzip edge_det_paddle.zip - mv *.pd* rapid_table_det_paddle/models/ - - pytest tests/test_table_det_paddle.py - - GenerateWHL_PushPyPi: - needs: UnitTesting - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - - - name: Set up Python 3.10 - uses: actions/setup-python@v4 - with: - python-version: '3.10' - architecture: 'x64' - - - name: Run setup.py - run: | - pip install -r requirements.txt - pip install paddlepaddle-gpu - python -m pip install --upgrade pip - pip install wheel get_pypi_latest_version - - wget https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/cls_det_paddle.zip - wget https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/obj_det_paddle.zip - wget https://github.com/Joker1212/RapidTableDetection/releases/download/v0.0.0/edge_det_paddle.zip - unzip cls_det_paddle.zip - unzip obj_det_paddle.zip - unzip edge_det_paddle.zip - mv *.pd* rapid_table_det_paddle/models/ - - python setup_rapid_table_det_paddle.py bdist_wheel "${{ github.ref_name }}" - - - name: Publish distribution 📦 to PyPI - uses: pypa/gh-action-pypi-publish@v1.5.0 - with: - password: ${{ secrets.PYPI_API_TOKEN }} - packages_dir: dist/ diff --git a/.gitignore b/.gitignore index c106952..6f62b84 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,3 @@ /models/ /images/ /outputs/ -/rapid_table_det_paddle/models/*.pd* -/rapid_table_det_paddle/outputs/ -/rapid_table_det/outputs/ -/rapid_table_det/models/*onnx -/tools/.ipynb_checkpoints/ -/.ipynb_checkpoints/ diff --git a/.ipynb_checkpoints/onnx_transform-checkpoint.ipynb b/.ipynb_checkpoints/onnx_transform-checkpoint.ipynb deleted file mode 100644 index 141fdc6..0000000 --- a/.ipynb_checkpoints/onnx_transform-checkpoint.ipynb +++ /dev/null @@ -1,128 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "source": [ - "!pip install paddle2onnx onnxruntime onnxslim onnxruntime-tools onnx -i https://pypi.tuna.tsinghua.edu.cn/simple" - ], - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "collapsed": false, - "jupyter": { - "is_executing": true, - "outputs_hidden": false - }, - "pycharm": { - "name": "#%%\n" - } - }, - "source": [ - "!paddle2onnx --model_dir rapid_table_det_paddle/models/obj_det --model_filename model.pdmodel --params_filename model.pdiparams --save_file rapid_table_det/models/obj_det.onnx --opset_version 16 --enable_onnx_checker True\n", - "!paddle2onnx --model_dir rapid_table_det_paddle/models/db_net --model_filename model.pdmodel --params_filename model.pdiparams --save_file rapid_table_det/models/edge_det.onnx --opset_version 16 --enable_onnx_checker True\n", - "!paddle2onnx --model_dir rapid_table_det_paddle/models/pplcnet --model_filename model.pdmodel --params_filename model.pdiparams --save_file rapid_table_det/models/cls_det.onnx --opset_version 16 --enable_onnx_checker True\n", - "\n", - "!onnxslim rapid_table_det/models/obj_det.onnx rapid_table_det/models/obj_det.onnx\n", - "!onnxslim rapid_table_det/models/edge_det.onnx rapid_table_det/models/edge_det.onnx\n", - "!onnxslim 
rapid_table_det/models/cls_det.onnx rapid_table_det/models/cls_det.onnx" - ], - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "ExecuteTime": { - "end_time": "2024-10-15T12:12:35.265576Z", - "start_time": "2024-10-15T12:12:34.281134Z" - }, - "collapsed": false, - "jupyter": { - "outputs_hidden": false - }, - "pycharm": { - "name": "#%%\n" - } - }, - "source": [ - "from pathlib import Path\n", - "import onnx\n", - "from onnxruntime.quantization import quantize_dynamic, QuantType, quant_pre_process\n", - "def quantize_model(root_dir_str, model_dir, pre_fix):\n", - "\n", - " original_model_path = f\"{pre_fix}.onnx\"\n", - " # quantized_model_path = f\"{pre_fix}_quantized.onnx\"\n", - " quantized_model_path = original_model_path\n", - " original_model_path = f\"{root_dir_str}/{model_dir}/{original_model_path}\"\n", - " quantized_model_path = f\"{root_dir_str}/{model_dir}/{quantized_model_path}\"\n", - " quant_pre_process(original_model_path, quantized_model_path, auto_merge=True)\n", - " # 进行动态量化\n", - " quantize_dynamic(\n", - " model_input=quantized_model_path,\n", - " model_output=quantized_model_path,\n", - " weight_type=QuantType.QUInt8\n", - " )\n", - "\n", - " # 验证量化后的模型\n", - " quantized_model = onnx.load(quantized_model_path)\n", - " onnx.checker.check_model(quantized_model)\n", - " print(\"Quantized model is valid.\")" - ], - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "collapsed": false, - "jupyter": { - "outputs_hidden": false - }, - "pycharm": { - "name": "#%%\n" - } - }, - "source": [ - "root_dir_str = \".\"\n", - "model_dir = f\"rapid_table_det/models\"\n", - "quantize_model(root_dir_str, model_dir, \"obj_det\")\n", - "quantize_model(root_dir_str, model_dir, \"edge_det\")\n", - "# quantize_model(root_dir_str, model_dir, \"cls_det\")" - ], - "outputs": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "source": [], - "outputs": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.14" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 5c227d6..0000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,19 +0,0 @@ -repos: -- repo: https://gitee.com/SWHL/autoflake - rev: v2.1.1 - hooks: - - id: autoflake - args: - [ - "--recursive", - "--in-place", - "--remove-all-unused-imports", - "--remove-unused-variable", - "--ignore-init-module-imports", - ] - files: \.py$ -- repo: https://gitee.com/SWHL/black - rev: 23.1.0 - hooks: - - id: black - files: \.py$ \ No newline at end of file diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 261eeb9..0000000 --- a/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. 
- - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. 
This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
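The paddle2onnx conversion and dynamic-quantization steps live in the deleted `onnx_transform` notebook above; for reference, a minimal standalone sketch of the quantization step, assuming the notebook's `rapid_table_det/models` layout and using the same `onnxruntime.quantization` calls the notebook uses:

```python
from pathlib import Path

import onnx
from onnxruntime.quantization import QuantType, quant_pre_process, quantize_dynamic


def quantize_model(model_dir: str, prefix: str) -> None:
    """Shape-infer, dynamically quantize (uint8 weights), and validate in place."""
    model_path = str(Path(model_dir) / f"{prefix}.onnx")
    quant_pre_process(model_path, model_path, auto_merge=True)
    quantize_dynamic(
        model_input=model_path,
        model_output=model_path,
        weight_type=QuantType.QUInt8,
    )
    onnx.checker.check_model(onnx.load(model_path))
    print(f"{prefix}: quantized model is valid.")


if __name__ == "__main__":
    # The notebook quantizes obj_det and edge_det; cls_det is left unquantized.
    for name in ("obj_det", "edge_det"):
        quantize_model("rapid_table_det/models", name)
```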
diff --git a/README.md b/README.md index f528371..0538af2 100644 --- a/README.md +++ b/README.md @@ -1,198 +1,2 @@ -
-📊RapidTableDetection
-
-SemVer2.0 · GitHub
-
-[English](README_en.md) | Simplified Chinese
-
-### Recent Updates
-
-- **2024.10.15**
-    - Completed the initial version of the code, including three modules: object detection, semantic segmentation, and corner direction recognition
-- **2024.11.2**
-    - Added newly trained YOLOv11 object detection and edge detection models
-    - Added automatic model downloading and reduced the package size
-    - Added onnx-gpu inference support, with benchmark results
-    - Added an online demo
-
-### Introduction
-
-💡✨ Powerful and efficient table detection for papers, journals, magazines, invoices, receipts, sign-in sheets, and other table types.
-
-🚀 Supports both Paddle- and YOLO-based models: the default combination needs only 1.2 s for single-image CPU inference, the smallest onnx-GPU (V100) combination 0.4 s, and the paddle-gpu version 0.2 s.
-🛠️ The three modules can be freely combined and independently trained and tuned; ONNX conversion scripts and fine-tuning schemes are provided.
-
-🌟 The whl package is easy to integrate and provides strong support for downstream OCR, table recognition, and data collection.
-
-📚 Based on the approach of the [2nd-place solution in the Baidu Table Detection Competition](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL),
-retrained with a large amount of additional real-world data
-![img.png](readme_resource/structure.png)
-👇🏻 The training datasets are credited in the Acknowledgments. The author maintains this open-source project in spare time; please support it with a ⭐️
-
-### Usage Recommendations
-
-📚 Document scenarios: no perspective or rotation, use object detection only\
-📷 Photo scenarios with small-angle rotation (-90~90): top-left corner by default, skip corner direction recognition\
-🔍 Use the online demo to find the model combination that fits your scenario
-
-### Online Demo
-[modelscope](https://www.modelscope.cn/studios/jockerK/RapidTableDetDemo) [huggingface](https://huggingface.co/spaces/Joker1212/RapidTableDetection)
-### Results
-
-![res_show.jpg](readme_resource/res_show.jpg)![res_show2.jpg](readme_resource/res_show2.jpg)
-
-### Installation
-
-🪜 Models are downloaded automatically; you can also fetch them yourself from the [modelscope model repository](https://www.modelscope.cn/models/jockerK/TableExtractor)
-
-``` python {linenos=table}
-# Installing from the Tsinghua mirror is recommended: https://pypi.tuna.tsinghua.edu.cn/simple
-pip install rapid-table-det
-```
-
-#### Parameter Explanation
-
-Default values:
-use_cuda: False : enable GPU-accelerated inference \
-obj_model_type="yolo_obj_det", \
-edge_model_type= "yolo_edge_det", \
-cls_model_type= "paddle_cls_det"
-
-Since GPU acceleration through ONNX is limited, it is still recommended to use YOLOX directly or install Paddle to run the models, which is much faster (the full workflow can be documented on request).
-Due to quantization, the Paddle "s" models are actually slower and less accurate, but much smaller.
-
-| `model_type` | Task Type | Training Source | Size | Single-table inference time (v100-16G, cuda12, cudnn9, ubuntu) |
-|:---------------------|:-------|:-------------------------------------|:-------|:-------------------------------------|
-| **yolo_obj_det** | Table object detection | `yolo11-l` | `100m` | `cpu:570ms, gpu:400ms` |
-| `paddle_obj_det` | Table object detection | `paddle yoloe-plus-x` | `380m` | `cpu:1000ms, gpu:300ms` |
-| `paddle_obj_det_s` | Table object detection | `paddle yoloe-plus-x + quantization` | `95m` | `cpu:1200ms, gpu:1000ms` |
-| **yolo_edge_det** | Semantic segmentation | `yolo11-l-segment` | `108m` | `cpu:570ms, gpu:200ms` |
-| `yolo_edge_det_s` | Semantic segmentation | `yolo11-s-segment` | `11m` | `cpu:260ms, gpu:200ms` |
-| `paddle_edge_det` | Semantic segmentation | `paddle-dbnet` | `99m` | `cpu:1200ms, gpu:120ms` |
-| `paddle_edge_det_s` | Semantic segmentation | `paddle-dbnet + quantization` | `25m` | `cpu:860ms, gpu:760ms` |
-| **paddle_cls_det** | Direction classification | `paddle pplcnet` | `6.5m` | `cpu:70ms, gpu:60ms` |
-
-
-Execution parameters:
-det_accuracy=0.7,
-use_obj_det=True,
-use_edge_det=True,
-use_cls_det=True,
-
-### Quick Start
-
-``` python {linenos=table}
-from rapid_table_det.inference import TableDetector
-
-img_path = f"tests/test_files/chip.jpg"
-table_det = TableDetector()
-
-result, elapse = table_det(img_path)
-obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
-print(
-    f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"
-)
-# Visualize the output
-# import os
-# import cv2
-# from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img
-#
-# img = img_loader(img_path)
-# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-# file_name_with_ext = os.path.basename(img_path)
-# file_name, file_ext = os.path.splitext(file_name_with_ext)
-# out_dir = "rapid_table_det/outputs"
-# if not os.path.exists(out_dir):
-#     os.makedirs(out_dir)
-# extract_img = img.copy()
-# for i, res in enumerate(result):
-#     box = res["box"]
-#     lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
-#     # Draw the detection box and mark the top-left corner
-#     img = visuallize(img, box, lt, rt, rb, lb)
-#     # Perspective transform to extract the table image
-#     wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
-#     cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img)
-# cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
-
-```
-### Using the Paddle Version
-You must download the models and specify their locations!
-``` python {linenos=table}
-# Installing from the Tsinghua mirror is recommended: https://pypi.tuna.tsinghua.edu.cn/simple
-pip install rapid-table-det-paddle (installs the GPU version by default; you can install the CPU version of paddlepaddle over it)
-```
-```python
-from rapid_table_det_paddle.inference import TableDetector
-
-img_path = f"tests/test_files/chip.jpg"
-
-table_det = TableDetector(
-    obj_model_path="models/obj_det_paddle",
-    edge_model_path="models/edge_det_paddle",
-    cls_model_path="models/cls_det_paddle",
-    use_obj_det=True,
-    use_edge_det=True,
-    use_cls_det=True,
-)
-result, elapse = table_det(img_path)
-obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
-print(
-    f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"
-)
-# There may be more than one table in an image
-# img = img_loader(img_path)
-# file_name_with_ext = os.path.basename(img_path)
-# file_name, file_ext = os.path.splitext(file_name_with_ext)
-# out_dir = "rapid_table_det_paddle/outputs"
-# if not os.path.exists(out_dir):
-#     os.makedirs(out_dir)
-# extract_img = img.copy()
-# for i, res in enumerate(result):
-#     box = res["box"]
-#     lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
-#     # Draw the detection box and mark the top-left corner
-#     img = visuallize(img, box, lt, rt, rb, lb)
-#     # Perspective transform to extract the table image
-#     wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
-#     cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img)
-# cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
-
-```
-
-## FAQ (Frequently Asked Questions)
-
-1. **Q: How do I fine-tune the model for a specific scenario?**
-   - A: Refer directly to the [Baidu Table Detection Competition](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL) project, which has very detailed, illustrated steps and includes the datasets; it yields the Paddle inference models.
-     Training YOLOv11 with the official scripts is simple enough: convert the data to COCO format and train as the official guide describes
-2. **Q: How do I export ONNX?**
-   - A: For Paddle models, use the onnx_transform.ipynb file under this project's tools directory;
-     for YOLOv11, follow the official method, which is a one-liner
-3. **Q: Can distorted images be corrected?**
-   - A: This project only handles table extraction under rotation and perspective; distorted images need dewarping first
-
-### Acknowledgments
-
-[2nd-place solution in the Baidu Table Detection Competition](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL) \
-[WTW natural-scene table dataset](https://tianchi.aliyun.com/dataset/108587) \
-[FinTabNet PDF document table dataset](https://developer.ibm.com/exchanges/data/all/fintabnet/) \
-[TableBank table dataset](https://doc-analysis.github.io/tablebank-page/) \
-[TableGeneration table synthesis tool](https://github.com/WenmuZhou/TableGeneration)
-
-### Contribution Guidelines
-
-Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.
-
-The author will also gladly support other good suggestions and integration scenarios
-
-### Open Source License
-
-This project is licensed under the [Apache 2.0](https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE)
-open source license.
-
+# RapidTableDet
+Detect and extract table regions from images of all kinds of scenes, and correct perspective and rotation
diff --git a/README_en.md b/README_en.md
deleted file mode 100644
index 1bf0be6..0000000
--- a/README_en.md
+++ /dev/null
@@ -1,195 +0,0 @@
-📊RapidTableDetection
-
-SemVer2.0 · GitHub
-
-### Recent Updates
-
-- **2024.10.15**
-    - Completed the initial version of the code, including three modules: object detection, semantic segmentation, and corner direction recognition.
-- **2024.11.2**
-    - Added newly trained YOLOv11 object detection and edge detection models.
-    - Added automatic model downloading and reduced the package size.
-    - Added ONNX-GPU inference support and published benchmark results.
-    - Added an online demo.
-
-### Introduction
-
-💡✨ RapidTableDetection is a powerful and efficient table detection system that supports various types of tables, including those in papers, journals, magazines, invoices, receipts, and sign-in sheets.
-
-🚀 It supports models derived from PaddlePaddle and YOLO: the default model combination needs only 1.2 seconds for single-image CPU inference, the smallest ONNX-GPU (V100) combination 0.4 seconds, and the PaddlePaddle-GPU version 0.2 seconds.
-
-🛠️ The three modules can be freely combined and independently trained and tuned, with ONNX conversion scripts and fine-tuning schemes provided.
-
-🌟 The whl package is easy to integrate and use, providing strong support for downstream OCR, table recognition, and data collection.
-
-It follows the approach of the [2nd-place solution in the Baidu Table Detection Competition](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL), retrained with a large amount of real-world scenario data.
-![img.png](readme_resource/structure.png) \
-The training datasets are credited in the Acknowledgments. The author maintains this open-source project in spare time; please support it by giving a star.
-
-
-### Usage Recommendations
-
-- Document scenarios: no perspective or rotation, use object detection only.
-- Photo scenarios with small-angle rotation (-90~90): top-left corner by default, skip corner direction recognition.
-- Use the online demo to find the model combination that fits your scenario.
-
-### Online Demo
-[modelscope](https://www.modelscope.cn/studios/jockerK/RapidTableDetDemo) [huggingface](https://huggingface.co/spaces/Joker1212/RapidTableDetection)
-
-### Results
-
-![res_show.jpg](readme_resource/res_show.jpg)![res_show2.jpg](readme_resource/res_show2.jpg)
-
-### Installation
-
-Models are downloaded automatically, or you can download them yourself from the [modelscope model repository](https://www.modelscope.cn/models/jockerK/TableExtractor).
-
-``` python {linenos=table}
-pip install rapid-table-det
-```
-
-#### Parameter Explanation
-
-Default values:
-- `use_cuda: False`: Enable GPU acceleration for inference.
-- `obj_model_type="yolo_obj_det"`: Object detection model type.
-- `edge_model_type="yolo_edge_det"`: Edge detection model type.
-- `cls_model_type="paddle_cls_det"`: Corner direction classification model type.
-
-
-Since GPU acceleration through ONNX is limited, it is still recommended to use YOLOX directly or install PaddlePaddle to run the models, which is much faster (the full workflow can be provided if needed).
-Due to quantization, the PaddlePaddle "s" models are actually slower and less accurate, but much smaller.
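As a sketch of overriding these defaults, here is a lighter combination than the default one; the keyword names follow the defaults documented above and the `yolo_edge_det_s` variant comes from the model table below, but the exact constructor signature is an assumption taken from this README:

```python
from rapid_table_det.inference import TableDetector

table_det = TableDetector(
    use_cuda=False,                     # ONNX CPU inference
    obj_model_type="yolo_obj_det",      # table object detection
    edge_model_type="yolo_edge_det_s",  # small segmentation variant (see table below)
    cls_model_type="paddle_cls_det",    # corner direction classification
)
# Per-call execution parameters, as listed after the table below
result, elapse = table_det("tests/test_files/chip.jpg", det_accuracy=0.7)
```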
- - -| `model_type` | Task Type | Training Source | Size | Single Table Inference Time (V100-16G, cuda12, cudnn9, ubuntu) | -|:---------------------|:---------|:-------------------------------------|:-------|:-------------------------------------| -| **yolo_obj_det** | Table Object Detection | `yolo11-l` | `100m` | `cpu:570ms, gpu:400ms` | -| `paddle_obj_det` | Table Object Detection | `paddle yoloe-plus-x` | `380m` | `cpu:1000ms, gpu:300ms` | -| `paddle_obj_det_s` | Table Object Detection | `paddle yoloe-plus-x + quantization` | `95m` | `cpu:1200ms, gpu:1000ms` | -| **yolo_edge_det** | Semantic Segmentation | `yolo11-l-segment` | `108m` | `cpu:570ms, gpu:200ms` | -| `yolo_edge_det_s` | Semantic Segmentation | `yolo11-s-segment` | `11m` | `cpu:260ms, gpu:200ms` | -| `paddle_edge_det` | Semantic Segmentation | `paddle-dbnet` | `99m` | `cpu:1200ms, gpu:120ms` | -| `paddle_edge_det_s` | Semantic Segmentation | `paddle-dbnet + quantization` | `25m` | `cpu:860ms, gpu:760ms` | -| **paddle_cls_det** | Direction Classification | `paddle pplcnet` | `6.5m` | `cpu:70ms, gpu:60ms` | - -Execution parameters: -- `det_accuracy=0.7` -- `use_obj_det=True` -- `use_edge_det=True` -- `use_cls_det=True` - -### Quick Start - -``` python {linenos=table} -from rapid_table_det.inference import TableDetector - -img_path = f"tests/test_files/chip.jpg" -table_det = TableDetector() - -result, elapse = table_det(img_path) -obj_det_elapse, edge_elapse, rotate_det_elapse = elapse -print( - f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}" -) -# Output visualization -# import os -# import cv2 -# from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img -# -# img = img_loader(img_path) -# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) -# file_name_with_ext = os.path.basename(img_path) -# file_name, file_ext = os.path.splitext(file_name_with_ext) -# out_dir = "rapid_table_det/outputs" -# if not os.path.exists(out_dir): -# os.makedirs(out_dir) -# extract_img = img.copy() -# for i, res in enumerate(result): -# box = res["box"] -# lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"] -# # With detection box and top-left corner position -# img = visuallize(img, box, lt, rt, rb, lb) -# # Perspective transformation to extract table image -# wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb) -# cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img) -# cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img) - -``` -### Using PaddlePaddle Version -You must download the models and specify their locations! 
-``` python {linenos=table} -#(default installation is GPU version, you can override with CPU version paddlepaddle) -pip install rapid-table-det-paddle -``` -```python -from rapid_table_det_paddle.inference import TableDetector - -img_path = f"tests/test_files/chip.jpg" - -table_det = TableDetector( - obj_model_path="models/obj_det_paddle", - edge_model_path="models/edge_det_paddle", - cls_model_path="models/cls_det_paddle", - use_obj_det=True, - use_edge_det=True, - use_cls_det=True, -) -result, elapse = table_det(img_path) -obj_det_elapse, edge_elapse, rotate_det_elapse = elapse -print( - f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}" -) -# more than one table in one image -# img = img_loader(img_path) -# file_name_with_ext = os.path.basename(img_path) -# file_name, file_ext = os.path.splitext(file_name_with_ext) -# out_dir = "rapid_table_det_paddle/outputs" -# if not os.path.exists(out_dir): -# os.makedirs(out_dir) -# extract_img = img.copy() -# for i, res in enumerate(result): -# box = res["box"] -# lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"] -# # With detection box and top-left corner position -# img = visuallize(img, box, lt, rt, rb, lb) -# # Perspective transformation to extract table image -# wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb) -# cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img) -# cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img) - -``` - -## FAQ (Frequently Asked Questions) - -1. **Q: How to fine-tune the model for specific scenarios?** - - A: Refer to this project, which provides detailed visualization steps and datasets. You can get the PaddlePaddle inference model from [Baidu Table Detection Competition](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL). For YOLOv11, use the official script, which is simple enough, and convert the data to COCO format for training as per the official guidelines. -2. **Q: How to export ONNX?** - - A: For PaddlePaddle models, use the `onnx_transform.ipynb` file in the `tools` directory of this project. For YOLOv11, follow the official method, which can be done in one line. -3. **Q: Can distorted images be corrected?** - - A: This project only handles rotation and perspective scenarios for table extraction. For distorted images, you need to correct the distortion first. - -### Acknowledgments - -- [2nd Place Solution in Baidu Table Detection Competition](https://aistudio.baidu.com/projectdetail/5398861?searchKeyword=%E8%A1%A8%E6%A0%BC%E6%A3%80%E6%B5%8B%E5%A4%A7%E8%B5%9B&searchTab=ALL) -- [WTW Natural Scene Table Dataset](https://tianchi.aliyun.com/dataset/108587) -- [FinTabNet PDF Document Table Dataset](https://developer.ibm.com/exchanges/data/all/fintabnet/) -- [TableBank Table Dataset](https://doc-analysis.github.io/tablebank-page/) -- [TableGeneration Table Auto-Generation Tool](https://github.com/WenmuZhou/TableGeneration) - -### Contribution Guidelines - -Pull requests are welcome. For major changes, please open an issue to discuss what you would like to change. - -If you have other good suggestions and integration scenarios, the author will actively respond and support them. - -### Open Source License - -This project is licensed under the [Apache 2.0](https://github.com/RapidAI/TableStructureRec/blob/c41bbd23898cb27a957ed962b0ffee3c74dfeff1/LICENSE) open source license. 
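The deleted demo scripts below show this flow for a single image; as a batch variant, a minimal sketch built from the Quick Start pieces above (module and function names as shown in this README; exact signatures are assumptions):

```python
import os
from glob import glob

import cv2

from rapid_table_det.inference import TableDetector
from rapid_table_det.utils.visuallize import extract_table_img, img_loader

table_det = TableDetector()
out_dir = "outputs"
os.makedirs(out_dir, exist_ok=True)

for img_path in glob("images/*.jpg"):
    result, _elapse = table_det(img_path)
    img = img_loader(img_path)
    name, _ext = os.path.splitext(os.path.basename(img_path))
    for i, res in enumerate(result):
        # Four corners, already reordered so "lt" is the true top-left
        lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
        # Perspective-rectify each detected table and save it separately
        warped = extract_table_img(img.copy(), lt, rt, rb, lb)
        cv2.imwrite(os.path.join(out_dir, f"{name}-extract-{i}.jpg"), warped)
```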
-
diff --git a/demo_onnx.py b/demo_onnx.py
deleted file mode 100644
index b5539d5..0000000
--- a/demo_onnx.py
+++ /dev/null
@@ -1,33 +0,0 @@
-from rapid_table_det.inference import TableDetector
-
-img_path = f"images/0c35d6430193babb29c6a94711742531-1_rot2_noise.jpg"
-table_det = TableDetector(
-    edge_model_type="yolo_edge_det", obj_model_type="yolo_obj_det"
-)
-
-result, elapse = table_det(img_path)
-obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
-print(
-    f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"
-)
-# Visualize the output
-import os
-import cv2
-from rapid_table_det.utils.visuallize import img_loader, visuallize, extract_table_img
-
-img = img_loader(img_path)
-file_name_with_ext = os.path.basename(img_path)
-file_name, file_ext = os.path.splitext(file_name_with_ext)
-out_dir = "rapid_table_det/outputs"
-if not os.path.exists(out_dir):
-    os.makedirs(out_dir)
-extract_img = img.copy()
-for i, res in enumerate(result):
-    box = res["box"]
-    lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
-    # Draw the detection box and mark the top-left corner
-    img = visuallize(img, box, lt, rt, rb, lb)
-    # Perspective transform to extract the table image
-    wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
-    cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img)
-cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
diff --git a/demo_paddle.py b/demo_paddle.py
deleted file mode 100644
index 8be2a48..0000000
--- a/demo_paddle.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from rapid_table_det_paddle.inference import TableDetector
-
-img_path = f"tests/test_files/chip.jpg"
-
-table_det = TableDetector(
-    obj_model_path="rapid_table_det_paddle/models/obj_det_paddle",
-    edge_model_path="rapid_table_det_paddle/models/edge_det_paddle",
-    cls_model_path="rapid_table_det_paddle/models/cls_det_paddle",
-    use_obj_det=True,
-    use_edge_det=True,
-    use_cls_det=True,
-)
-result, elapse = table_det(img_path)
-obj_det_elapse, edge_elapse, rotate_det_elapse = elapse
-print(
-    f"obj_det_elapse:{obj_det_elapse}, edge_elapse={edge_elapse}, rotate_det_elapse={rotate_det_elapse}"
-)
-# There may be more than one table in an image
-# img = img_loader(img_path)
-# file_name_with_ext = os.path.basename(img_path)
-# file_name, file_ext = os.path.splitext(file_name_with_ext)
-# out_dir = "rapid_table_det_paddle/outputs"
-# if not os.path.exists(out_dir):
-#     os.makedirs(out_dir)
-# extract_img = img.copy()
-# for i, res in enumerate(result):
-#     box = res["box"]
-#     lt, rt, rb, lb = res["lt"], res["rt"], res["rb"], res["lb"]
-#     # Draw the detection box and mark the top-left corner
-#     img = visuallize(img, box, lt, rt, rb, lb)
-#     # Perspective transform to extract the table image
-#     wrapped_img = extract_table_img(extract_img.copy(), lt, rt, rb, lb)
-#     cv2.imwrite(f"{out_dir}/{file_name}-extract-{i}.jpg", wrapped_img)
-#     cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
diff --git a/inference.py b/inference.py
new file mode 100644
index 0000000..11053ed
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,100 @@
+import cv2
+import numpy as np
+from paddle_predictor import DbNet, ObjectDetector, PPLCNet
+from utils import LoadImage
+
+
+class TableDetector:
+    def __init__(self, dbnet_model_path, obj_model_path, pplcnet_model_path, **kwargs):
+        self.use_obj_det = kwargs.get("use_obj_det")
+        self.use_edge_det = kwargs.get("use_edge_det")
+        self.use_rotate_det = kwargs.get("use_rotate_det")
+        self.img_loader = LoadImage()
+        if self.use_obj_det:
+            self.obj_detector = ObjectDetector(obj_model_path)
+        if self.use_edge_det:
+            self.dbnet = DbNet(dbnet_model_path)
+        if self.use_rotate_det:
+            self.pplcnet = PPLCNet(pplcnet_model_path)
+
+    def __call__(self, img, det_accuracy=0.4):
+        img = self.img_loader(img)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img_mask = img.copy()
+        h, w = img.shape[:-1]
+        img_box = np.array([1, 1, w - 1, h - 1])
+        x1, y1, x2, y2 = img_box
+        lt = np.array([x1, y1])  # top-left corner
+        lb = np.array([x1, y2])  # bottom-left corner
+        rt = np.array([x2, y1])  # top-right corner
+        rb = np.array([x2, y2])  # bottom-right corner
+        obj_det_res = [[1.0, img_box]]
+        edge_box = img_box
+        pred_label = 0
+        result = []
+        if self.use_obj_det:
+            obj_det_res = self.obj_detector(img, score=det_accuracy)
+        for i in range(len(obj_det_res)):
+            det_res = obj_det_res[i]
+            score, box = det_res
+            xmin, ymin, xmax, ymax = box
+            if self.use_edge_det:
+                xmin_edge, ymin_edge, xmax_edge, ymax_edge = self.pad_box_points(h, w, xmax, xmin, ymax, ymin, 10)
+                crop_img = img_mask[ymin_edge:ymax_edge, xmin_edge:xmax_edge, :]
+                edge_box, lt, lb, rt, rb = self.dbnet(crop_img)
+                edge_box[:, 0] += xmin_edge
+                edge_box[:, 1] += ymin_edge
+                lt, lb, rt, rb = lt + [xmin_edge, ymin_edge], lb + [xmin_edge, ymin_edge], rt + [xmin_edge, ymin_edge], rb + [xmin_edge, ymin_edge]
+            if self.use_rotate_det:
+                xmin_cls, ymin_cls, xmax_cls, ymax_cls = self.pad_box_points(h, w, xmax, xmin, ymax, ymin, 10)
+                cls_box = edge_box.copy()
+                cls_img = img_mask[ymin_cls:ymax_cls, xmin_cls:xmax_cls, :]
+                cls_box[:, 0] = cls_box[:, 0] - xmin_cls
+                cls_box[:, 1] = cls_box[:, 1] - ymin_cls
+                # Draw the box onto the crop as a prior to help direction-label recognition
+                cv2.polylines(cls_img, [np.array(cls_box).astype(np.int32).reshape((-1, 1, 2))], True,
+                              color=(255, 0, 255), thickness=5)
+                pred_label = self.pplcnet(cls_img)
+            lb1, lt1, rb1, rt1 = self.get_real_rotated_points(lb, lt, pred_label, rb, rt)
+            result.append({
+                "box": [int(xmin), int(ymin), int(xmax), int(ymax)],
+                "lb": [int(lb1[0]), int(lb1[1])],
+                "lt": [int(lt1[0]), int(lt1[1])],
+                "rt": [int(rt1[0]), int(rt1[1])],
+                "rb": [int(rb1[0]), int(rb1[1])],
+            })
+        return result
+
+    def get_real_rotated_points(self, lb, lt, pred_label, rb, rt):
+        # Reorder the corners according to the predicted direction label
+        # (0 means upright; 1-3 are successive 90-degree rotations)
+        if pred_label == 0:
+            lt1 = lt
+            rt1 = rt
+            rb1 = rb
+            lb1 = lb
+        elif pred_label == 1:
+            lt1 = rt
+            rt1 = rb
+            rb1 = lb
+            lb1 = lt
+        elif pred_label == 2:
+            lt1 = rb
+            rt1 = lb
+            rb1 = lt
+            lb1 = rt
+        elif pred_label == 3:
+            lt1 = lb
+            rt1 = lt
+            rb1 = rt
+            lb1 = rb
+        else:
+            lt1 = lt
+            rt1 = rt
+            rb1 = rb
+            lb1 = lb
+        return lb1, lt1, rb1, rt1
+
+    def pad_box_points(self, h, w, xmax, xmin, ymax, ymin, pad):
+        # Expand the box by `pad` pixels on each side, clamped to the image bounds
+        ymin_edge = max(ymin - pad, 0)
+        xmin_edge = max(xmin - pad, 0)
+        ymax_edge = min(ymax + pad, h)
+        xmax_edge = min(xmax + pad, w)
+        return xmin_edge, ymin_edge, xmax_edge, ymax_edge
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..ac5deb6
--- /dev/null
+++ b/main.py
@@ -0,0 +1,707 @@
+# Example usage:
+# python predict.py [src_image_dir] [results]
+
+import os
+import sys
+import glob
+import json
+import cv2
+import paddle
+import math
+import itertools
+import numpy as np
+from PIL import Image
+from utils import *
+MODEL_STAGES_PATTERN = {
+    "PPLCNet": ["blocks2", "blocks3", "blocks4", "blocks5", "blocks6"]
+}
+
+
+
+
+
+# def generate_scale(im, resize_shape, keep_ratio):
+#     """
+#     Args:
+#         im (np.ndarray): image (np.ndarray)
+#     Returns:
+#         im_scale_x: the resize ratio of X
+#         im_scale_y: the resize ratio of Y
+#     """
+#     target_size = (resize_shape[0], resize_shape[1])
+#     # target_size = (800, 1333)
+#     origin_shape = im.shape[:2]
+#
+#     if keep_ratio:
+#         im_size_min = np.min(origin_shape)
+#         im_size_max = np.max(origin_shape)
+#         target_size_min = np.min(target_size)
+#         target_size_max = np.max(target_size)
+#         im_scale =
float(target_size_min) / float(im_size_min) +# if np.round(im_scale * im_size_max) > target_size_max: +# im_scale = float(target_size_max) / float(im_size_max) +# im_scale_x = im_scale +# im_scale_y = im_scale +# else: +# resize_h, resize_w = target_size +# im_scale_y = resize_h / float(origin_shape[0]) +# im_scale_x = resize_w / float(origin_shape[1]) +# return im_scale_y, im_scale_x + + +# def normalize_img(im): +# is_scale = True +# im = im.astype(np.float32, copy=False) +# mean = [0.485, 0.456, 0.406] +# std = [0.229, 0.224, 0.225] +# norm_type = 'mean_std' +# if is_scale: +# scale = 1.0 / 255.0 +# im *= scale +# if norm_type == 'mean_std': +# mean = np.array(mean)[np.newaxis, np.newaxis, :] +# std = np.array(std)[np.newaxis, np.newaxis, :] +# im -= mean +# im /= std +# return im + + +# def normalize_img1(im): +# is_scale = True +# im = im.astype(np.float32, copy=False) +# mean = [0.485, 0.456, 0.406] +# std = [0.229, 0.224, 0.225] +# norm_type = 'none' +# if is_scale: +# scale = 1.0 / 255.0 +# if norm_type == 'mean_std': +# mean = np.array(mean)[np.newaxis, np.newaxis, :] +# std = np.array(std)[np.newaxis, np.newaxis, :] +# im -= mean +# im /= std +# return im + + +# def resize(im, im_info, resize_shape, keep_ratio, interp=2): +# im_scale_y, im_scale_x = generate_scale(im, resize_shape, keep_ratio) +# im = cv2.resize( +# im, +# None, +# None, +# fx=im_scale_x, +# fy=im_scale_y, +# interpolation=interp) +# im_info['im_shape'] = np.array(im.shape[:2]).astype('float32') +# im_info['scale_factor'] = np.array( +# [im_scale_y, im_scale_x]).astype('float32') +# +# return im, im_info +# +# +# def pad(im, im_info, resize_shape): +# im_h, im_w = im.shape[:2] +# fill_value = [114.0, 114.0, 114.0] +# h, w = resize_shape[0], resize_shape[1] +# if h == im_h and w == im_w: +# im = im.astype(np.float32) +# return im, im_info +# +# canvas = np.ones((h, w, 3), dtype=np.float32) +# canvas *= np.array(fill_value, dtype=np.float32) +# canvas[0:im_h, 0:im_w, :] = im.astype(np.float32) +# im = canvas +# return im, im_info + + +def nchoosek(startnum, endnum, step=1, n=1): + c = [] + for i in itertools.combinations(range(startnum, endnum + 1, step), n): + c.append(list(i)) + return c + + +def get_inv(concat): + a = concat[0][0] + b = concat[0][1] + c = concat[1][0] + d = concat[1][1] + det_concat = a * d - b * c + inv_result = np.array([[d / det_concat, -b / det_concat], + [-c / det_concat, a / det_concat]]) + return inv_result + + +def minboundquad(hull): + len_hull = len(hull) + xy = np.array(hull).reshape([-1, 2]) + idx = np.arange(0, len_hull) + idx_roll = np.roll(idx, -1, axis=0) + edges = np.array([idx, idx_roll]).reshape([2, -1]) + edges = np.transpose(edges, [1, 0]) + edgeangles1 = [] + for i in range(len_hull): + y = xy[edges[i, 1], 1] - xy[edges[i, 0], 1] + x = xy[edges[i, 1], 0] - xy[edges[i, 0], 0] + angle = math.atan2(y, x) + if angle < 0: + angle = angle + 2 * math.pi + edgeangles1.append([angle, i]) + edgeangles1_idx = sorted(list(edgeangles1), key=lambda x: x[0]) + edges1 = [] + edgeangle1 = [] + for item in edgeangles1_idx: + idx = item[1] + edges1.append(edges[idx, :]) + edgeangle1.append(item[0]) + edgeangles = np.array(edgeangle1) + edges = np.array(edges1) + eps = 2.2204e-16 + angletol = eps * 100 + + k = np.diff(edgeangles) < angletol + idx = np.where(k == 1) + edges = np.delete(edges, idx, 0) + edgeangles = np.delete(edgeangles, idx, 0) + nedges = edges.shape[0] + edgelist = np.array(nchoosek(0, nedges - 1, 1, 4)) + k = edgeangles[edgelist[:, 3]] - edgeangles[edgelist[:, 0]] <= math.pi 
+ k_idx = np.where(k == 1) + edgelist = np.delete(edgelist, k_idx, 0) + + nquads = edgelist.shape[0] + quadareas = math.inf + qxi = np.zeros([5]) + qyi = np.zeros([5]) + cnt = np.zeros([4, 1, 2]) + edgelist = list(edgelist) + edges = list(edges) + xy = list(xy) + + for i in range(nquads): + edgeind = list(edgelist[i]) + edgeind.append(edgelist[i][0]) + edgesi = [] + edgeang = [] + for idx in edgeind: + edgesi.append(edges[idx]) + edgeang.append(edgeangles[idx]) + is_continue = False + for idx in range(len(edgeang) - 1): + diff = edgeang[idx + 1] - edgeang[idx] + if diff > math.pi: + is_continue = True + if is_continue: + continue + for j in range(4): + jplus1 = j + 1 + shared = np.intersect1d(edgesi[j], edgesi[jplus1]) + if shared.size != 0: + qxi[j] = xy[shared[0]][0] + qyi[j] = xy[shared[0]][1] + else: + A = xy[edgesi[j][0]] + B = xy[edgesi[j][1]] + C = xy[edgesi[jplus1][0]] + D = xy[edgesi[jplus1][1]] + concat = np.hstack(((A - B).reshape([2, -1]), (D - C).reshape([2, -1]))) + div = (A - C).reshape([2, -1]) + inv_result = get_inv(concat) + a = inv_result[0, 0] + b = inv_result[0, 1] + c = inv_result[1, 0] + d = inv_result[1, 1] + e = div[0, 0] + f = div[1, 0] + ts1 = [a * e + b * f, c * e + d * f] + Q = A + (B - A) * ts1[0] + qxi[j] = Q[0] + qyi[j] = Q[1] + + contour = np.array([qxi[:4], qyi[:4]]).astype(np.int32) + contour = np.transpose(contour, [1, 0]) + contour = contour[:, np.newaxis, :] + A_i = cv2.contourArea(contour) + # break + + if A_i < quadareas: + quadareas = A_i + cnt = contour + return cnt + + +def process_yoloe(im, im_info, resize_shape): + im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32) + im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32) + # print(im) + + im, im_info = resize(im, im_info, resize_shape, False) + h_n, w_n = im.shape[:-1] + im, im_info = pad(im, im_info, resize_shape) + + # im = normalize_img(im) + im = im / 255.0 + im = im.transpose((2, 0, 1)).copy() + + im = paddle.to_tensor(im, dtype='float32') + im = im.unsqueeze(0) + factor = paddle.to_tensor(im_info['scale_factor']).reshape((1, 2)).astype('float32') + im_shape = paddle.to_tensor(im_info['im_shape'].reshape((1, 2)), dtype='float32') + return im, im_shape, factor + + +def ResizePad1(img, target_size): + h, w = img.shape[:2] + m = max(h, w) + ratio = target_size / m + new_w, new_h = int(ratio * w), int(ratio * h) + img = cv2.resize(img, (new_w, new_h), cv2.INTER_LINEAR) + top = (target_size - new_h) // 2 + bottom = (target_size - new_h) - top + left = (target_size - new_w) // 2 + right = (target_size - new_w) - left + img1 = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(255, 255, 255)) + return img1, new_w, new_h, left, top + + +def process_db(im, im_info, resize_shape): + im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32) + im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32) + + # im, im_info = resize(im, im_info, resize_shape,False) + + resize_h, resize_w = im.shape[:-1] + h_n, w_n = im.shape[:-1] + im, new_w, new_h, left, top = ResizePad1(im, 800) + + # im = transforms.ToTensor()(im) + im = im / 255.0 + # im = normalize_img(im) + im = im.transpose((2, 0, 1)).copy() + + im = paddle.to_tensor(im, dtype='float32') + + # im = transforms.Normalize(mean=[0., 0., 0.], std=[1., 1., 1.])(im) + im = im.unsqueeze(0) + return im, new_h, new_w, left, top + + +def crop_image(img, target_size, center): + width, height = img.shape[1:] + size = target_size + if center == True: + w_start = (width - size) / 2 + h_start = (height - size) 
/ 2 + else: + w_start = np.random.randint(0, width - size + 1) + h_start = np.random.randint(0, height - size + 1) + w_end = w_start + size + h_end = h_start + size + img = img[:, int(h_start):int(h_end), int(w_start):int(w_end)] + return img + + +def ResizePad(img, target_size): + img = np.array(img) + h, w = img.shape[:2] + m = max(h, w) + ratio = target_size / m + new_w, new_h = int(ratio * w), int(ratio * h) + img = cv2.resize(img, (new_w, new_h), cv2.INTER_LINEAR) + top = (target_size - new_h) // 2 + bottom = (target_size - new_h) - top + left = (target_size - new_w) // 2 + right = (target_size - new_w) - left + img1 = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(255, 255, 255)) + # img1 = cv2.copyMakeBorder(img, top, bottom, left, right,cv2.BORDER_REPLICATE) + return img1 + + +class NormalizeImage(object): + """ normalize image such as substract mean, divide std + """ + + def __init__(self, scale=None, mean=None, std=None, order='chw', **kwargs): + if isinstance(scale, str): + scale = eval(scale) + self.scale = np.float32(scale if scale is not None else 1.0 / 255.0) + mean = mean if mean is not None else [0.485, 0.456, 0.406] + std = std if std is not None else [0.229, 0.224, 0.225] + + shape = (3, 1, 1) if order == 'chw' else (1, 1, 3) + self.mean = np.array(mean).reshape(shape).astype('float32') + self.std = np.array(std).reshape(shape).astype('float32') + + def __call__(self, img): + from PIL import Image + if isinstance(img, Image.Image): + img = np.array(img) + assert isinstance(img, + np.ndarray), "invalid input 'img' in NormalizeImage" + # print(self.std,self.scale,self.mean) + img = ( + img.astype('float32') * self.scale - self.mean) / self.std + return img + + +def process_image1(img, mode, rotate): + resize_width = 624 + img = ResizePad(img, resize_width) + img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255.0 + # img = normalize_img1(img) + # # img = crop_image(img, 600,True) + # norm= NormalizeImage() + # img = norm(img) + return img + + +def pad_stride(im): + coarsest_stride = 32 + if coarsest_stride <= 0: + return im + im_c, im_h, im_w = im.shape + pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride) + pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride) + padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32) + padding_im[:, :im_h, :im_w] = im + return padding_im + + +def box_score_fast(bitmap, _box): + ''' + box_score_fast: use bbox mean score as the mean score + ''' + h, w = bitmap.shape[:2] + box = _box.copy() + xmin = np.clip(np.floor(box[:, 0].min()).astype(np.int), 0, w - 1) + xmax = np.clip(np.ceil(box[:, 0].max()).astype(np.int), 0, w - 1) + ymin = np.clip(np.floor(box[:, 1].min()).astype(np.int), 0, h - 1) + ymax = np.clip(np.ceil(box[:, 1].max()).astype(np.int), 0, h - 1) + + mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8) + box[:, 0] = box[:, 0] - xmin + box[:, 1] = box[:, 1] - ymin + cv2.fillPoly(mask, box.reshape(1, -1, 2).astype(np.int32), 1) + return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0] + + +def get_mini_boxes(contour): + bounding_box = cv2.minAreaRect(contour) + points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) + + index_1, index_2, index_3, index_4 = 0, 1, 2, 3 + if points[1][1] > points[0][1]: + index_1 = 0 + index_4 = 1 + else: + index_1 = 1 + index_4 = 0 + if points[3][1] > points[2][1]: + index_2 = 2 + index_3 = 3 + else: + index_2 = 3 + index_3 = 2 + + box = [ + points[index_1], points[index_2], 
points[index_3], points[index_4] + ] + return box, min(bounding_box[1]) + + +def nms(box_List): + np_boxes = np.array(box_List) + x1 = np_boxes[:, 2] + y1 = np_boxes[:, 3] + x2 = np_boxes[:, 4] + y2 = np_boxes[:, 5] + + areas = (y2 - y1 + 1) * (x2 - x1 + 1) + scores = np_boxes[:, 1] + keep = [] + index = scores.argsort()[::-1] + thresh = 0.01 + while index.size > 0: + i = index[0] # every time the first is the biggst, and add it directly + keep.append(i) + + x11 = np.maximum(x1[i], x1[index[1:]]) # calculate the points of overlap + y11 = np.maximum(y1[i], y1[index[1:]]) + x22 = np.minimum(x2[i], x2[index[1:]]) + y22 = np.minimum(y2[i], y2[index[1:]]) + + w = np.maximum(0, x22 - x11 + 1) # the weights of overlap + h = np.maximum(0, y22 - y11 + 1) # the height of overlap + + overlaps = w * h + + ious = overlaps / (areas[i] + areas[index[1:]] - overlaps) + idx = np.where(ious <= thresh)[0] + + index = index[idx + 1] # because index start from 1 + # break + np_boxes = np_boxes[keep, :] + return np_boxes + + +def expand_poly(data, sec_dis): + """多边形等距缩放 + Args: + data: 多边形按照逆时针顺序排列的的点集 + sec_dis: 缩放距离 + + Returns: + 缩放后的多边形点集 + """ + num = len(data) + scal_data = [] + for i in range(num): + x1 = data[(i) % num][0] - data[(i - 1) % num][0] + y1 = data[(i) % num][1] - data[(i - 1) % num][1] + x2 = data[(i + 1) % num][0] - data[(i) % num][0] + y2 = data[(i + 1) % num][1] - data[(i) % num][1] + + d_A = (x1 ** 2 + y1 ** 2) ** 0.5 + d_B = (x2 ** 2 + y2 ** 2) ** 0.5 + + Vec_Cross = (x1 * y2) - (x2 * y1) + if (d_A * d_B == 0): + continue + sin_theta = Vec_Cross / (d_A * d_B) + if (sin_theta == 0): + continue + dv = sec_dis / sin_theta + + v1_x = (dv / d_A) * x1 + v1_y = (dv / d_A) * y1 + + v2_x = (dv / d_B) * x2 + v2_y = (dv / d_B) * y2 + + PQ_x = v1_x - v2_x + PQ_y = v1_y - v2_y + + Q_x = data[(i) % num][0] + PQ_x + Q_y = data[(i) % num][1] + PQ_y + scal_data.append([Q_x, Q_y]) + return scal_data + + +def step_function(x, y): + return paddle.reciprocal(1 + paddle.exp(-50 * (x - y))) + + +def process(src_image_dir, save_dir): + model = paddle.jit.load('models/obj_det/model') + model.eval() + + + easy_model = paddle.jit.load('models/db_net/model') + easy_model.eval() + + + dir_model = paddle.jit.load('models/pplcnet/model') + dir_model.eval() + + image_paths = glob.glob(os.path.join(src_image_dir, "*.jpg")) + image_paths.extend(glob.glob(os.path.join(src_image_dir, "*.jpeg"))) + image_paths.extend(glob.glob(os.path.join(src_image_dir, "*.png"))) + result = {} + for image_path in image_paths: + + im_info = { + 'scale_factor': np.array( + [1., 1.], dtype=np.float32), + 'im_shape': None, + } + + filename = os.path.split(image_path)[1] + im = cv2.imread(image_path) + im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) + + im = np.asarray(im) + temp = im.copy() + h1, w1 = im.shape[:-1] + im_mask = im.copy() + im, im_shape, factor = process_yoloe(im, im_info, [928, 928]) + print(image_path) + pre = model(im, factor) + if filename not in result: + result[filename] = [] + for item in pre[0].numpy(): + + cls, value, xmin, ymin, xmax, ymax = list(item) + cls, xmin, ymin, xmax, ymax = [int(x) for x in [cls, xmin, ymin, xmax, ymax]] + # xmin, ymin, xmax, ymax = xmin -10, ymin -10, xmax +10, ymax +10 + if xmin < 0: + xmin = 0 + if ymin < 0: + ymin = 0 + if xmax > w1: + xmax = w1 + if ymax > h1: + ymax = h1 + + if value > 0.5: + + im_info = {} + ymin_a = max(ymin - 10, 0) + xmin_a = max(xmin - 10, 0) + ymax_a = min(ymax + 10, h1) + xmax_a = min(xmax + 10, w1) + + crop_img = im_mask[ymin_a:ymax_a, xmin_a:xmax_a, :] + # crop_img 
= im_mask + ymin_c = max(ymin - 5, 0) + xmin_c = max(xmin - 5, 0) + ymax_c = min(ymax + 5, h1) + xmax_c = min(xmax + 5, w1) + pred_label = 0 + + cls_img = im_mask[ymin_c:ymax_c, xmin_c:xmax_c, :] + + destHeight, destWidth = crop_img.shape[:-1] + crop_img, resize_h, resize_w, left, top = process_db(crop_img, im_info, [800, 800]) + + with paddle.no_grad(): + predicts = easy_model(crop_img) + predict_maps = predicts.cpu() + pred = predict_maps[0, 0].numpy() + segmentation = pred > 0.7 + # print(segmentation.shape) + # dilation_kernel = np.array([[1, 1], [1, 1]]) + mask = np.array(segmentation).astype(np.uint8) + + contours, _ = cv2.findContours((mask * 255).astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + + lt = np.array([int(xmin), int(ymin)]) + rt = np.array([int(xmax), int(ymin)]) + rb = np.array([int(xmax), int(ymax)]) + lb = np.array([int(xmin), int(ymax)]) + + # lt = np.array([int(xmin_a), int(ymin_a)]) + # rt = np.array([int(xmax_a), int(ymin_a)]) + # rb = np.array([int(xmax_a), int(ymax_a)]) + # lb = np.array([int(xmin_a), int(ymax_a)]) + + max_size = 0 + cnt_save = None + for cont in contours: + points, sside = get_mini_boxes(cont) + if sside > max_size: + max_size = sside + cnt_save = cont + # cnt_save = None + if cnt_save is not None: + epsilon = 0.01 * cv2.arcLength(cnt_save, True) + box = cv2.approxPolyDP(cnt_save, epsilon, True) + hull = cv2.convexHull(box) + points, sside = get_mini_boxes(cnt_save) + len_hull = len(hull) + + if len_hull == 4: + target_box = np.array(hull) + elif len_hull > 4: + target_box = minboundquad(hull) + else: + target_box = np.array(points) + + box = np.array(target_box).reshape([-1, 2]) + + # print(box.shape) + box[:, 0] = np.clip( + (np.round(box[:, 0] - left) / resize_w * destWidth), 0, destWidth) + xmin_a + box[:, 1] = np.clip( + (np.round(box[:, 1] - top) / resize_h * destHeight), 0, destHeight) + ymin_a + x = box[:, 0] + l_idx = x.argsort() + l_box = np.array([box[l_idx[0]], box[l_idx[1]]]) + r_box = np.array([box[l_idx[2]], box[l_idx[3]]]) + l_idx_1 = np.array(l_box[:, 1]).argsort() + lt = l_box[l_idx_1[0]] + lt[lt < 0] = 0 + lb = l_box[l_idx_1[1]] + r_idx_1 = np.array(r_box[:, 1]).argsort() + rt = r_box[r_idx_1[0]] + rt[rt < 0] = 0 + rb = r_box[r_idx_1[1]] + cls_box = box.copy() + cls_box[:, 0] = cls_box[:, 0] - xmin_c + cls_box[:, 1] = cls_box[:, 1] - ymin_c + cv2.polylines(cls_img, [np.array(cls_box).astype(np.int32).reshape((-1, 1, 2))], True, + color=(255, 0, 255), thickness=5) + + cls_img = process_image1(cls_img, 'test', True) + cls_img = paddle.to_tensor(cls_img) + cls_img = cls_img.unsqueeze(0) + with paddle.no_grad(): + # print(cls_img.shape) + label = dir_model(cls_img) + label = label.unsqueeze(0).numpy() + mini_batch_result = np.argsort(label) + mini_batch_result = mini_batch_result[0][-1] # 把这些列标拿出来 + mini_batch_result = mini_batch_result.flatten() # 拉平了,只吐出一个 array + mini_batch_result = mini_batch_result[::-1] # 逆序 + + pred_label = mini_batch_result[0] + + if pred_label == 0: + lt1 = lt + rt1 = rt + rb1 = rb + lb1 = lb + elif pred_label == 1: + lt1 = rt + rt1 = rb + rb1 = lb + lb1 = lt + elif pred_label == 2: + lt1 = rb + rt1 = lb + rb1 = lt + lb1 = rt + elif pred_label == 3: + lt1 = lb + rt1 = lt + rb1 = rt + lb1 = rb + else: + lt1 = lt + rt1 = rt + rb1 = rb + lb1 = lb + draw_box = np.array([lt1, rt1, rb1, lb1]).reshape([-1, 2]) + cv2.circle(temp, (int(lt1[0]), int(lt1[1])), 50, (255, 0, 0), 10) + cv2.rectangle(temp, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 0, 0), 10) + cv2.polylines(temp, 
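The four pred_label branches above amount to a cyclic shift of the corner order; a compact equivalent, as a sketch (the helper name is illustrative):

import numpy as np

def rotate_corners(lt, rt, rb, lb, pred_label):
    # one roll per 90 degrees of detected rotation; labels outside 0-3
    # fall back to the identity, matching the else branch above
    if pred_label not in (1, 2, 3):
        pred_label = 0
    lt1, rt1, rb1, lb1 = np.roll(np.array([lt, rt, rb, lb]), -pred_label, axis=0)
    return lt1, rt1, rb1, lb1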
[np.array(draw_box).astype(np.int32).reshape((-1, 1, 2))], True, + color=(255, 0, 255), thickness=6) + + result[filename].append({ + "box": [int(xmin), int(ymin), int(xmax), int(ymax)], + "lb": [int(lb1[0]), int(lb1[1])], + "lt": [int(lt1[0]), int(lt1[1])], + "rt": [int(rt1[0]), int(rt1[1])], + "rb": [int(rb1[0]), int(rb1[1])], + }) + print(f"{image_path} process done!") + + save_p = os.path.join(save_dir, filename) + h, w = temp.shape[:-1] + target = 512 + w_p = h / w * target + # temp = cv2.resize(temp, (int(target), int(w_p))) + cv2.imwrite(save_p, temp) + + with open(os.path.join(save_dir, "result.txt"), 'w', encoding="utf-8") as f: + f.write(json.dumps(result)) + + +if __name__ == "__main__": + + src_image_dir = "images" + save_dir = "outputs" + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + process(src_image_dir, save_dir) \ No newline at end of file diff --git a/rapid_table_det_paddle/predictor.py b/paddle_predictor.py similarity index 73% rename from rapid_table_det_paddle/predictor.py rename to paddle_predictor.py index 39cbc72..b58e19d 100644 --- a/rapid_table_det_paddle/predictor.py +++ b/paddle_predictor.py @@ -1,72 +1,75 @@ -import time +import os + +import cv2 import paddle -from rapid_table_det_paddle.utils import * +import math +import itertools +import numpy as np +from PIL import Image +from utils import * MODEL_STAGES_PATTERN = { "PPLCNet": ["blocks2", "blocks3", "blocks4", "blocks5", "blocks6"] } +cur_dir = Path(__file__).resolve().parent +cur_dir_str = str(cur_dir) +obj_model_path = f"{cur_dir_str}/models/obj_det/model" +dbnet_model_path = f"{cur_dir_str}/models/db_net/model" +pplcnet_model_path = f"{cur_dir_str}/models/pplcnet/model" class ObjectDetector: - model_key = "obj_det_paddle" - - def __init__(self, model_path, **kwargs): + def __init__(self, model_path=obj_model_path, **kwargs): self.model = paddle.jit.load(model_path) self.img_loader = LoadImage() self.resize_shape = [928, 928] def __call__(self, img, **kwargs): - start = time.time() score = kwargs.get("score", 0.4) img = self.img_loader(img) ori_h, ori_w = img.shape[:-1] img, im_shape, factor = self.img_preprocess(img, self.resize_shape) - pre = self.model(img, factor) + pre = self.model(img) result = [] for item in pre[0].numpy(): cls, value, xmin, ymin, xmax, ymax = list(item) if value < score: continue - cls, xmin, ymin, xmax, ymax = [ - int(x) for x in [cls, xmin, ymin, xmax, ymax] - ] + cls, xmin, ymin, xmax, ymax = [int(x) for x in [cls, xmin, ymin, xmax, ymax]] xmin = max(xmin, 0) ymin = max(ymin, 0) xmax = min(xmax, ori_w) ymax = min(ymax, ori_h) result.append([value, np.array([xmin, ymin, xmax, ymax])]) - return result, time.time() - start + return result def img_preprocess(self, img, resize_shape=[928, 928]): + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) im_info = { - "scale_factor": np.array([1.0, 1.0], dtype=np.float32), - "im_shape": np.array(img.shape[:2], dtype=np.float32), + 'scale_factor': np.array( + [1., 1.], dtype=np.float32), + 'im_shape': None, } + im_info['im_shape'] = np.array(img.shape[:2], dtype=np.float32) + im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32) im, im_info = resize(img, im_info, resize_shape, False) im, im_info = pad(im, im_info, resize_shape) im = im / 255.0 im = im.transpose((2, 0, 1)).copy() - im = paddle.to_tensor(im, dtype="float32") + im = paddle.to_tensor(im, dtype='float32') im = im.unsqueeze(0) - factor = ( - paddle.to_tensor(im_info["scale_factor"]).reshape((1, 2)).astype("float32") - ) - im_shape = paddle.to_tensor( - 
im_info["im_shape"].reshape((1, 2)), dtype="float32" - ) + factor = paddle.to_tensor(im_info['scale_factor']).reshape((1, 2)).astype('float32') + im_shape = paddle.to_tensor(im_info['im_shape'].reshape((1, 2)), dtype='float32') return im, im_shape, factor class DbNet: - model_key = "edge_det_paddle" - - def __init__(self, model_path, **kwargs): + def __init__(self, model_path=dbnet_model_path, **kwargs): self.model = paddle.jit.load(model_path) self.img_loader = LoadImage() self.resize_shape = [800, 800] def __call__(self, img, **kwargs): - start = time.time() img = self.img_loader(img) destHeight, destWidth = img.shape[:-1] img, resize_h, resize_w, left, top = self.img_preprocess(img, self.resize_shape) @@ -81,18 +84,14 @@ def __call__(self, img, **kwargs): # todo 注意还有crop的偏移 if box is not None: # 根据缩放调整坐标适配输入的img大小 - adjusted_box = self.adjust_coordinates( - box, left, top, resize_w, resize_h, destWidth, destHeight - ) + adjusted_box = self.adjust_coordinates(box, left, top, resize_w, resize_h, destWidth, destHeight) # 排序并裁剪负值 lt, lb, rt, rb = self.sort_and_clip_coordinates(adjusted_box) - return box, lt, lb, rt, rb, time.time() - start + return box, lt, lb, rt, rb else: - return None, None, None, None, None, time.time() - start + return None - def adjust_coordinates( - self, box, left, top, resize_w, resize_h, destWidth, destHeight - ): + def adjust_coordinates(self, box, left, top, resize_w, resize_h, destWidth, destHeight): """ 调整边界框坐标,确保它们在合理范围内。 @@ -112,13 +111,11 @@ def adjust_coordinates( """ # 调整横坐标 box[:, 0] = np.clip( - (np.round(box[:, 0] - left) / resize_w * destWidth), 0, destWidth - ) + (np.round(box[:, 0] - left) / resize_w * destWidth), 0, destWidth) # 调整纵坐标 box[:, 1] = np.clip( - (np.round(box[:, 1] - top) / resize_h * destHeight), 0, destHeight - ) + (np.round(box[:, 1] - top) / resize_h * destHeight), 0, destHeight) return box def sort_and_clip_coordinates(self, box): @@ -156,27 +153,27 @@ def sort_and_clip_coordinates(self, box): return lt, lb, rt, rb def img_preprocess(self, img, resize_shape=[800, 800]): + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) im, new_w, new_h, left, top = ResizePad(img, resize_shape[0]) im = im / 255.0 im = im.transpose((2, 0, 1)).copy() - im = paddle.to_tensor(im, dtype="float32") + im = paddle.to_tensor(im, dtype='float32') im = im.unsqueeze(0) return im, new_h, new_w, left, top class PPLCNet: - model_key = "cls_det_paddle" - - def __init__(self, model_path, **kwargs): + def __init__(self, model_path=dbnet_model_path, **kwargs): self.model = paddle.jit.load(model_path) self.img_loader = LoadImage() self.resize_shape = [624, 624] def __call__(self, img, **kwargs): - start = time.time() img = self.img_loader(img) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) img = self.img_preprocess(img, self.resize_shape) with paddle.no_grad(): + # print(cls_img.shape) label = self.model(img) label = label.unsqueeze(0).numpy() mini_batch_result = np.argsort(label) @@ -184,10 +181,11 @@ def __call__(self, img, **kwargs): mini_batch_result = mini_batch_result.flatten() # 拉平了,只吐出一个 array mini_batch_result = mini_batch_result[::-1] # 逆序 pred_label = mini_batch_result[0] - return pred_label, time.time() - start + return pred_label def img_preprocess(self, img, resize_shape=[624, 624]): + # resize_width = 624 im, new_w, new_h, left, top = ResizePad(img, resize_shape[0]) - im = np.array(im).astype("float32").transpose((2, 0, 1)) / 255.0 + im = np.array(img).astype('float32').transpose((2, 0, 1)) / 255.0 im = paddle.to_tensor(im) return im.unsqueeze(0) diff --git 
a/rapid_table_det/__init__.py b/rapid_table_det/__init__.py deleted file mode 100644 index 9d4f418..0000000 --- a/rapid_table_det/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: Jocker1212 -# @Contact: xinyijianggo@gmail.com -from .inference import TableDetector diff --git a/rapid_table_det/inference.py b/rapid_table_det/inference.py deleted file mode 100644 index b72d36a..0000000 --- a/rapid_table_det/inference.py +++ /dev/null @@ -1,233 +0,0 @@ -import os -from pathlib import Path -from typing import Union - -import cv2 -import numpy as np - -from .predictor import DbNet, PaddleYoloEDet, PPLCNet, YoloSeg, YoloDet -from .utils.download_model import DownloadModel - -from .utils.logger import get_logger -from .utils.load_image import LoadImage - -root_dir = Path(__file__).resolve().parent -model_dir = os.path.join(root_dir, "models") - -ROOT_DIR = Path(__file__).resolve().parent -logger = get_logger("rapid_layout") - -ROOT_URL = "https://www.modelscope.cn/models/jockerK/TableExtractor/resolve/master/rapid_table_det/models/" -KEY_TO_MODEL_URL = { - "yolo_obj_det": f"{ROOT_URL}/yolo_obj_det.onnx", - "yolo_edge_det": f"{ROOT_URL}/yolo_edge_det.onnx", - "yolo_edge_det_s": f"{ROOT_URL}/yolo_edge_det_s.onnx", - "paddle_obj_det": f"{ROOT_URL}/paddle_obj_det.onnx", - "paddle_obj_det_s": f"{ROOT_URL}/paddle_obj_det_s.onnx", - "paddle_edge_det": f"{ROOT_URL}/paddle_edge_det.onnx", - "paddle_edge_det_s": f"{ROOT_URL}/paddle_edge_det_s.onnx", - "paddle_cls_det": f"{ROOT_URL}/paddle_cls_det.onnx", -} - - -class TableDetector: - def __init__( - self, - use_cuda=False, - use_dml=False, - obj_model_path=None, - edge_model_path=None, - cls_model_path=None, - obj_model_type="yolo_obj_det", - edge_model_type="yolo_edge_det", - cls_model_type="paddle_cls_det", - ): - self.img_loader = LoadImage() - obj_det_config = { - "model_path": self.get_model_path(obj_model_type, obj_model_path), - "use_cuda": use_cuda, - "use_dml": use_dml, - } - edge_det_config = { - "model_path": self.get_model_path(edge_model_type, edge_model_path), - "use_cuda": use_cuda, - "use_dml": use_dml, - } - cls_det_config = { - "model_path": self.get_model_path(cls_model_type, cls_model_path), - "use_cuda": use_cuda, - "use_dml": use_dml, - } - if "yolo" in obj_model_type: - self.obj_detector = YoloDet(obj_det_config) - else: - self.obj_detector = PaddleYoloEDet(obj_det_config) - if "yolo" in edge_model_type: - self.dbnet = YoloSeg(edge_det_config) - else: - self.dbnet = DbNet(edge_det_config) - if "yolo" in cls_model_type: - self.pplcnet = PPLCNet(cls_det_config) - else: - self.pplcnet = PPLCNet(cls_det_config) - - def __call__( - self, - img, - det_accuracy=0.7, - use_obj_det=True, - use_edge_det=True, - use_cls_det=True, - ): - img = self.img_loader(img) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img_mask = img.copy() - h, w = img.shape[:-1] - obj_det_res, pred_label = self.init_default_output(h, w) - result = [] - obj_det_elapse, edge_elapse, rotate_det_elapse = 0, 0, 0 - if use_obj_det: - obj_det_res, obj_det_elapse = self.obj_detector(img, score=det_accuracy) - for i in range(len(obj_det_res)): - det_res = obj_det_res[i] - score, box = det_res - xmin, ymin, xmax, ymax = box - edge_box = box.reshape([-1, 2]) - lb, lt, rb, rt = self.get_box_points(box) - if use_edge_det: - xmin_edge, ymin_edge, xmax_edge, ymax_edge = self.pad_box_points( - h, w, xmax, xmin, ymax, ymin, 10 - ) - crop_img = img_mask[ymin_edge:ymax_edge, xmin_edge:xmax_edge, :] - edge_box, lt, lb, rt, rb, tmp_edge_elapse = 
self.dbnet(crop_img) - edge_elapse += tmp_edge_elapse - if edge_box is None: - continue - lb, lt, rb, rt = self.adjust_edge_points_axis( - edge_box, lb, lt, rb, rt, xmin_edge, ymin_edge - ) - if use_cls_det: - xmin_cls, ymin_cls, xmax_cls, ymax_cls = self.pad_box_points( - h, w, xmax, xmin, ymax, ymin, 5 - ) - cls_img = img_mask[ymin_cls:ymax_cls, xmin_cls:xmax_cls, :] - # 增加先验信息 - self.add_pre_info_for_cls(cls_img, edge_box, xmin_cls, ymin_cls) - pred_label, tmp_rotate_det_elapse = self.pplcnet(cls_img) - rotate_det_elapse += tmp_rotate_det_elapse - lb1, lt1, rb1, rt1 = self.get_real_rotated_points( - lb, lt, pred_label, rb, rt - ) - result.append( - { - "box": [int(xmin), int(ymin), int(xmax), int(ymax)], - "lb": [int(lb1[0]), int(lb1[1])], - "lt": [int(lt1[0]), int(lt1[1])], - "rt": [int(rt1[0]), int(rt1[1])], - "rb": [int(rb1[0]), int(rb1[1])], - } - ) - elapse = [obj_det_elapse, edge_elapse, rotate_det_elapse] - return result, elapse - - def init_default_output(self, h, w): - img_box = np.array([0, 0, w, h]) - # 初始化默认值 - obj_det_res, edge_box, pred_label = ( - [[1.0, img_box]], - img_box.reshape([-1, 2]), - 0, - ) - return obj_det_res, pred_label - - def add_pre_info_for_cls(self, cls_img, edge_box, xmin_cls, ymin_cls): - """ - Args: - cls_img: - edge_box: - xmin_cls: - ymin_cls: - - Returns: 带边缘划线的图片,给方向分类提供先验信息 - - """ - cls_box = edge_box.copy() - cls_box[:, 0] = cls_box[:, 0] - xmin_cls - cls_box[:, 1] = cls_box[:, 1] - ymin_cls - # 画框增加先验信息,辅助方向label识别 - cv2.polylines( - cls_img, - [np.array(cls_box).astype(np.int32).reshape((-1, 1, 2))], - True, - color=(255, 0, 255), - thickness=5, - ) - - def adjust_edge_points_axis(self, edge_box, lb, lt, rb, rt, xmin_edge, ymin_edge): - edge_box[:, 0] += xmin_edge - edge_box[:, 1] += ymin_edge - lt, lb, rt, rb = ( - lt + [xmin_edge, ymin_edge], - lb + [xmin_edge, ymin_edge], - rt + [xmin_edge, ymin_edge], - rb + [xmin_edge, ymin_edge], - ) - return lb, lt, rb, rt - - def get_box_points(self, img_box): - x1, y1, x2, y2 = img_box - lt = np.array([x1, y1]) # 左上角 - rt = np.array([x2, y1]) # 右上角 - rb = np.array([x2, y2]) # 右下角 - lb = np.array([x1, y2]) # 左下角 - return lb, lt, rb, rt - - def get_real_rotated_points(self, lb, lt, pred_label, rb, rt): - if pred_label == 0: - lt1 = lt - rt1 = rt - rb1 = rb - lb1 = lb - elif pred_label == 1: - lt1 = rt - rt1 = rb - rb1 = lb - lb1 = lt - elif pred_label == 2: - lt1 = rb - rt1 = lb - rb1 = lt - lb1 = rt - elif pred_label == 3: - lt1 = lb - rt1 = lt - rb1 = rt - lb1 = rb - else: - lt1 = lt - rt1 = rt - rb1 = rb - lb1 = lb - return lb1, lt1, rb1, rt1 - - def pad_box_points(self, h, w, xmax, xmin, ymax, ymin, pad): - ymin_edge = max(ymin - pad, 0) - xmin_edge = max(xmin - pad, 0) - ymax_edge = min(ymax + pad, h) - xmax_edge = min(xmax + pad, w) - return xmin_edge, ymin_edge, xmax_edge, ymax_edge - - @staticmethod - def get_model_path(model_type: str, model_path: Union[str, Path, None]) -> str: - if model_path is not None: - return model_path - - model_url = KEY_TO_MODEL_URL.get(model_type, None) - if model_url: - model_path = DownloadModel.download(model_url) - return model_path - - logger.info( - "model url is None, using the default download model %s", model_path - ) - return model_path diff --git a/rapid_table_det/models/.gitkeep b/rapid_table_det/models/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/rapid_table_det/predictor.py b/rapid_table_det/predictor.py deleted file mode 100644 index 0c61de1..0000000 --- a/rapid_table_det/predictor.py +++ /dev/null @@ -1,317 +0,0 @@ -import 
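A usage sketch of the onnx-based API as inference.py above defines it (models are fetched on first use via KEY_TO_MODEL_URL; the test image ships with the repository):

from rapid_table_det import TableDetector

table_det = TableDetector(obj_model_type="yolo_obj_det", edge_model_type="yolo_edge_det")
result, elapse = table_det("tests/test_files/doc.png", det_accuracy=0.7)
for table in result:
    print(table["box"], table["lt"], table["rt"], table["rb"], table["lb"])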
time -from pathlib import Path - -import cv2 -import numpy as np -from typing import Dict, Any - -from .utils.infer_engine import OrtInferSession -from .utils.load_image import LoadImage -from .utils.transform import ( - custom_NMSBoxes, - resize, - pad, - ResizePad, - sigmoid, - get_max_adjacent_bbox, -) - -MODEL_STAGES_PATTERN = { - "PPLCNet": ["blocks2", "blocks3", "blocks4", "blocks5", "blocks6"] -} -root_dir = Path(__file__).resolve().parent -root_dir_str = str(root_dir) - - -class PaddleYoloEDet: - model_key = "obj_det" - - def __init__(self, config: Dict[str, Any]): - self.model = OrtInferSession(config) - self.img_loader = LoadImage() - self.resize_shape = [928, 928] - - def __call__(self, img, **kwargs): - start = time.time() - score = kwargs.get("score", 0.4) - img = self.img_loader(img) - ori_h, ori_w = img.shape[:-1] - img, im_shape, factor = self.img_preprocess(img, self.resize_shape) - pre = self.model([img, factor]) - result = self.img_postprocess(ori_h, ori_w, pre, score) - return result, time.time() - start - - def img_postprocess(self, ori_h, ori_w, pre, score): - result = [] - for item in pre[0]: - cls, value, xmin, ymin, xmax, ymax = list(item) - if value < score: - continue - cls, xmin, ymin, xmax, ymax = [ - int(x) for x in [cls, xmin, ymin, xmax, ymax] - ] - xmin = max(xmin, 0) - ymin = max(ymin, 0) - xmax = min(xmax, ori_w) - ymax = min(ymax, ori_h) - result.append([value, np.array([xmin, ymin, xmax, ymax])]) - return result - - def img_preprocess(self, img, resize_shape=[928, 928]): - im_info = { - "scale_factor": np.array([1.0, 1.0], dtype=np.float32), - "im_shape": np.array(img.shape[:2], dtype=np.float32), - } - im, im_info = resize(img, im_info, resize_shape, False) - im, im_info = pad(im, im_info, resize_shape) - im = im / 255.0 - im = im.transpose((2, 0, 1)).copy() - im = im[None, :] - factor = im_info["scale_factor"].reshape((1, 2)) - im_shape = im_info["im_shape"].reshape((1, 2)) - return im, im_shape, factor - - -class YoloDet: - def __init__(self, config: Dict[str, Any]): - self.model = OrtInferSession(config) - self.img_loader = LoadImage() - self.resize_shape = [928, 928] - - def __call__(self, img, **kwargs): - start = time.time() - score = kwargs.get("score", 0.4) - img = self.img_loader(img) - ori_h, ori_w = img.shape[:-1] - img, new_w, new_h, left, top = self.img_preprocess(img, self.resize_shape) - pre = self.model([img]) - result = self.img_postprocess( - pre, ori_w / new_w, ori_h / new_h, left, top, score - ) - return result, time.time() - start - - def img_preprocess(self, img, resize_shape=[928, 928]): - im, new_w, new_h, left, top = ResizePad(img, resize_shape[0]) - im = im / 255.0 - im = im.transpose((2, 0, 1)).copy() - im = im[None, :].astype("float32") - return im, new_w, new_h, left, top - - def img_postprocess(self, predict_maps, x_factor, y_factor, left, top, score): - result = [] - # 转置和压缩输出以匹配预期的形状 - outputs = np.transpose(np.squeeze(predict_maps[0])) - # 获取输出数组的行数 - rows = outputs.shape[0] - # 用于存储检测的边界框、得分和类别ID的列表 - boxes = [] - scores = [] - # 遍历输出数组的每一行 - for i in range(rows): - # 找到类别得分中的最大得分 - max_score = outputs[i][4] - # 如果最大得分高于置信度阈值 - if max_score >= score: - # 从当前行提取边界框坐标 - x, y, w, h = outputs[i][0], outputs[i][1], outputs[i][2], outputs[i][3] - # 计算边界框的缩放坐标 - xmin = max(int((x - w / 2 - left) * x_factor), 0) - ymin = max(int((y - h / 2 - top) * y_factor), 0) - xmax = xmin + int(w * x_factor) - ymax = ymin + int(h * y_factor) - # 将类别ID、得分和框坐标添加到各自的列表中 - boxes.append([xmin, ymin, xmax, ymax]) - scores.append(max_score) - # 
应用非最大抑制过滤重叠的边界框 - indices = custom_NMSBoxes(boxes, scores) - for i in indices: - result.append([scores[i], np.array(boxes[i])]) - return result - - -class DbNet: - model_key = "edge_det" - - def __init__(self, config: Dict[str, Any]): - self.model = OrtInferSession(config) - self.img_loader = LoadImage() - self.resize_shape = [800, 800] - - def __call__(self, img, **kwargs): - start = time.time() - img = self.img_loader(img) - destHeight, destWidth = img.shape[:-1] - img, resize_h, resize_w, left, top = self.img_preprocess(img, self.resize_shape) - # with paddle.no_grad(): - predict_maps = self.model([img]) - pred = self.img_postprocess(predict_maps) - if pred is None: - return None, None, None, None, None, time.time() - start - segmentation = pred > 0.8 - mask = np.array(segmentation).astype(np.uint8) - # 找到最佳边缘box shape(4, 2) - box = get_max_adjacent_bbox(mask) - # todo 注意还有crop的偏移 - if box is not None: - # 根据缩放调整坐标适配输入的img大小 - adjusted_box = self.adjust_coordinates( - box, left, top, resize_w, resize_h, destWidth, destHeight - ) - # 排序并裁剪负值 - lt, lb, rt, rb = self.sort_and_clip_coordinates(adjusted_box) - return box, lt, lb, rt, rb, time.time() - start - else: - return None, None, None, None, None, time.time() - start - - def img_postprocess(self, predict_maps): - pred = np.squeeze(predict_maps[0]) - return pred - - def adjust_coordinates( - self, box, left, top, resize_w, resize_h, destWidth, destHeight - ): - """ - 调整边界框坐标,确保它们在合理范围内。 - - 参数: - box (numpy.ndarray): 原始边界框坐标 (shape: (4, 2)) - left (int): 左侧偏移量 - top (int): 顶部偏移量 - resize_w (int): 缩放宽度 - resize_h (int): 缩放高度 - destWidth (int): 目标宽度 - destHeight (int): 目标高度 - xmin_a (int): 目标左上角横坐标 - ymin_a (int): 目标左上角纵坐标 - - 返回: - numpy.ndarray: 调整后的边界框坐标 - """ - # 调整横坐标 - box[:, 0] = np.clip( - (np.round(box[:, 0] - left) / resize_w * destWidth), 0, destWidth - ) - - # 调整纵坐标 - box[:, 1] = np.clip( - (np.round(box[:, 1] - top) / resize_h * destHeight), 0, destHeight - ) - return box - - def sort_and_clip_coordinates(self, box): - """ - 对边界框坐标进行排序并裁剪负值。 - - 参数: - box (numpy.ndarray): 边界框坐标 (shape: (4, 2)) - - 返回: - tuple: 左上角、左下角、右上角、右下角坐标 - """ - # 按横坐标排序 - x = box[:, 0] - l_idx = x.argsort() - l_box = np.array([box[l_idx[0]], box[l_idx[1]]]) - r_box = np.array([box[l_idx[2]], box[l_idx[3]]]) - - # 左侧坐标按纵坐标排序 - l_idx_1 = np.array(l_box[:, 1]).argsort() - lt = l_box[l_idx_1[0]] - lb = l_box[l_idx_1[1]] - - # 右侧坐标按纵坐标排序 - r_idx_1 = np.array(r_box[:, 1]).argsort() - rt = r_box[r_idx_1[0]] - rb = r_box[r_idx_1[1]] - - # 裁剪负值 - lt[lt < 0] = 0 - lb[lb < 0] = 0 - rt[rt < 0] = 0 - rb[rb < 0] = 0 - - return lt, lb, rt, rb - - def img_preprocess(self, img, resize_shape=[800, 800]): - im, new_w, new_h, left, top = ResizePad(img, resize_shape[0]) - im = im / 255.0 - im = im.transpose((2, 0, 1)).copy() - im = im[None, :].astype("float32") - return im, new_h, new_w, left, top - - -class YoloSeg(DbNet): - model_key = "edge_det" - - def img_postprocess(self, predict_maps): - box_output = predict_maps[0] - mask_output = predict_maps[1] - predictions = np.squeeze(box_output).T - # Filter out object confidence scores below threshold - scores = predictions[:, 4] - # 获取得分最高的索引 - highest_score_index = scores.argmax() - # 获取得分最高的预测结果 - highest_score_prediction = predictions[highest_score_index] - x, y, w, h = highest_score_prediction[:4] - highest_score = highest_score_prediction[4] - if highest_score < 0.7: - return None - mask_predictions = highest_score_prediction[5:] - mask_predictions = np.expand_dims(mask_predictions, axis=0) - mask_output = 
np.squeeze(mask_output) - # Calculate the mask maps for each box - num_mask, mask_height, mask_width = mask_output.shape # CHW - masks = sigmoid(mask_predictions @ mask_output.reshape((num_mask, -1))) - masks = masks.reshape((-1, mask_height, mask_width)) - # 提取第一个通道 - mask = masks[0] - - # 计算缩小后的区域边界 - small_w = 200 - small_h = 200 - small_x_min = max(0, int((x - w / 2) * small_w / 800)) - small_x_max = min(small_w, int((x + w / 2) * small_w / 800)) - small_y_min = max(0, int((y - h / 2) * small_h / 800)) - small_y_max = min(small_h, int((y + h / 2) * small_h / 800)) - - # 创建一个全零的掩码 - filtered_mask = np.zeros((small_h, small_w), dtype=np.float32) - - # 将区域内的值复制到过滤后的掩码中 - filtered_mask[small_y_min:small_y_max, small_x_min:small_x_max] = mask[ - small_y_min:small_y_max, small_x_min:small_x_max - ] - - # 使用 OpenCV 进行放大,保持边缘清晰 - resized_mask = cv2.resize( - filtered_mask, (800, 800), interpolation=cv2.INTER_CUBIC - ) - return resized_mask - - -class PPLCNet: - model_key = "cls_det" - - def __init__(self, config: Dict[str, Any]): - self.model = OrtInferSession(config) - self.img_loader = LoadImage() - self.resize_shape = [624, 624] - - def __call__(self, img, **kwargs): - start = time.time() - img = self.img_loader(img) - img = self.img_preprocess(img, self.resize_shape) - label = self.model([img])[0] - label = label[None, :] - mini_batch_result = np.argsort(label) - mini_batch_result = mini_batch_result[0][-1] # 把这些列标拿出来 - mini_batch_result = mini_batch_result.flatten() # 拉平了,只吐出一个 array - mini_batch_result = mini_batch_result[::-1] # 逆序 - pred_label = mini_batch_result[0] - return pred_label, time.time() - start - - def img_preprocess(self, img, resize_shape=[624, 624]): - im, new_w, new_h, left, top = ResizePad(img, resize_shape[0]) - im = np.array(im).transpose((2, 0, 1)) / 255.0 - return im[None, :].astype("float32") diff --git a/rapid_table_det/requirments.txt b/rapid_table_det/requirments.txt deleted file mode 100644 index 05f1a3d..0000000 --- a/rapid_table_det/requirments.txt +++ /dev/null @@ -1,5 +0,0 @@ -numpy -Pillow -opencv-python -onnxruntime -requests diff --git a/rapid_table_det/utils/__init__.py b/rapid_table_det/utils/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/rapid_table_det/utils/download_model.py b/rapid_table_det/utils/download_model.py deleted file mode 100644 index d694fd1..0000000 --- a/rapid_table_det/utils/download_model.py +++ /dev/null @@ -1,59 +0,0 @@ -import io -from pathlib import Path -from typing import Optional, Union - -import requests -from tqdm import tqdm - -from .logger import get_logger - -logger = get_logger("DownloadModel") -CUR_DIR = Path(__file__).resolve() -PROJECT_DIR = CUR_DIR.parent.parent - - -class DownloadModel: - cur_dir = PROJECT_DIR - - @classmethod - def download(cls, model_full_url: Union[str, Path]) -> str: - save_dir = cls.cur_dir / "models" - save_dir.mkdir(parents=True, exist_ok=True) - - model_name = Path(model_full_url).name - save_file_path = save_dir / model_name - if save_file_path.exists(): - logger.debug("%s already exists", save_file_path) - return str(save_file_path) - - try: - logger.info("Download %s to %s", model_full_url, save_dir) - file = cls.download_as_bytes_with_progress(model_full_url, model_name) - cls.save_file(save_file_path, file) - except Exception as exc: - raise DownloadModelError from exc - return str(save_file_path) - - @staticmethod - def download_as_bytes_with_progress( - url: Union[str, Path], name: Optional[str] = None - ) -> bytes: - resp = requests.get(str(url), 
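The argsort/flatten/reverse sequence in PPLCNet.__call__ above simply picks the top-scoring class; a one-line equivalent, with hypothetical scores:

import numpy as np

label = np.array([0.1, 0.7, 0.15, 0.05])   # hypothetical class scores
assert int(np.argmax(label)) == 1          # same label as argsort(...)[0][-1]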
stream=True, allow_redirects=True, timeout=180) - total = int(resp.headers.get("content-length", 0)) - bio = io.BytesIO() - with tqdm( - desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024 - ) as pbar: - for chunk in resp.iter_content(chunk_size=65536): - pbar.update(len(chunk)) - bio.write(chunk) - return bio.getvalue() - - @staticmethod - def save_file(save_path: Union[str, Path], file: bytes): - with open(save_path, "wb") as f: - f.write(file) - - -class DownloadModelError(Exception): - pass diff --git a/rapid_table_det/utils/infer_engine.py b/rapid_table_det/utils/infer_engine.py deleted file mode 100644 index ae02303..0000000 --- a/rapid_table_det/utils/infer_engine.py +++ /dev/null @@ -1,227 +0,0 @@ -from .logger import get_logger -import os -import platform -import traceback -from enum import Enum -from pathlib import Path -from typing import Any, Dict, List, Tuple, Union - -import numpy as np -from onnxruntime import ( - GraphOptimizationLevel, - InferenceSession, - SessionOptions, - get_available_providers, - get_device, -) - - -class EP(Enum): - CPU_EP = "CPUExecutionProvider" - CUDA_EP = "CUDAExecutionProvider" - DIRECTML_EP = "DmlExecutionProvider" - - -class OrtInferSession: - def __init__(self, config: Dict[str, Any]): - self.logger = get_logger("OrtInferSession") - - model_path = config.get("model_path", None) - self._verify_model(model_path) - - self.cfg_use_cuda = config.get("use_cuda", None) - self.cfg_use_dml = config.get("use_dml", None) - - self.had_providers: List[str] = get_available_providers() - EP_list = self._get_ep_list() - - sess_opt = self._init_sess_opts(config) - self.session = InferenceSession( - model_path, - sess_options=sess_opt, - providers=EP_list, - ) - self._verify_providers() - - @staticmethod - def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions: - sess_opt = SessionOptions() - sess_opt.log_severity_level = 4 - sess_opt.enable_cpu_mem_arena = False - sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL - - cpu_nums = os.cpu_count() - intra_op_num_threads = config.get("intra_op_num_threads", -1) - if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums: - sess_opt.intra_op_num_threads = intra_op_num_threads - - inter_op_num_threads = config.get("inter_op_num_threads", -1) - if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums: - sess_opt.inter_op_num_threads = inter_op_num_threads - - return sess_opt - - def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]: - cpu_provider_opts = { - "arena_extend_strategy": "kSameAsRequested", - } - EP_list = [(EP.CPU_EP.value, cpu_provider_opts)] - - cuda_provider_opts = { - "device_id": 0, - "arena_extend_strategy": "kNextPowerOfTwo", - "cudnn_conv_algo_search": "EXHAUSTIVE", - "do_copy_in_default_stream": True, - } - self.use_cuda = self._check_cuda() - if self.use_cuda: - EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts)) - - self.use_directml = self._check_dml() - if self.use_directml: - self.logger.info( - "Windows 10 or above detected, try to use DirectML as primary provider" - ) - directml_options = ( - cuda_provider_opts if self.use_cuda else cpu_provider_opts - ) - EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options)) - return EP_list - - def _check_cuda(self) -> bool: - if not self.cfg_use_cuda: - return False - - cur_device = get_device() - if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers: - return True - - self.logger.warning( - "%s is not in available providers (%s). 
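A usage sketch for DownloadModel, assuming the KEY_TO_MODEL_URL table defined in inference.py above; a network connection is required, and the file is cached under rapid_table_det/models on later calls.

from rapid_table_det.inference import KEY_TO_MODEL_URL
from rapid_table_det.utils.download_model import DownloadModel

path = DownloadModel.download(KEY_TO_MODEL_URL["yolo_obj_det"])
print(path)   # local cached path; subsequent calls skip the download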
Use %s inference by default.", - EP.CUDA_EP.value, - self.had_providers, - self.had_providers[0], - ) - self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.") - self.logger.info( - "(For reference only) If you want to use GPU acceleration, you must do:" - ) - self.logger.info( - "First, uninstall all onnxruntime pakcages in current environment." - ) - self.logger.info( - "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`." - ) - self.logger.info( - "\tNote the onnxruntime-gpu version must match your cuda and cudnn version." - ) - self.logger.info( - "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html" - ) - self.logger.info( - "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']", - EP.CUDA_EP.value, - ) - return False - - def _check_dml(self) -> bool: - if not self.cfg_use_dml: - return False - - cur_os = platform.system() - if cur_os != "Windows": - self.logger.warning( - "DirectML is only supported in Windows OS. The current OS is %s. Use %s inference by default.", - cur_os, - self.had_providers[0], - ) - return False - - cur_window_version = int(platform.release().split(".")[0]) - if cur_window_version < 10: - self.logger.warning( - "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s. Use %s inference by default.", - cur_window_version, - self.had_providers[0], - ) - return False - - if EP.DIRECTML_EP.value in self.had_providers: - return True - - self.logger.warning( - "%s is not in available providers (%s). Use %s inference by default.", - EP.DIRECTML_EP.value, - self.had_providers, - self.had_providers[0], - ) - self.logger.info("If you want to use DirectML acceleration, you must do:") - self.logger.info( - "First, uninstall all onnxruntime pakcages in current environment." - ) - self.logger.info( - "Second, install onnxruntime-directml by `pip install onnxruntime-directml`" - ) - self.logger.info( - "Third, ensure %s is in available providers list. e.g. 
['DmlExecutionProvider', 'CPUExecutionProvider']", - EP.DIRECTML_EP.value, - ) - return False - - def _verify_providers(self): - session_providers = self.session.get_providers() - first_provider = session_providers[0] - - if self.use_cuda and first_provider != EP.CUDA_EP.value: - self.logger.warning( - "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.", - EP.CUDA_EP.value, - first_provider, - ) - - if self.use_directml and first_provider != EP.DIRECTML_EP.value: - self.logger.warning( - "%s is not available for current env, the inference part is automatically shifted to be executed under %s.", - EP.DIRECTML_EP.value, - first_provider, - ) - - def __call__(self, input_content: List[np.ndarray]) -> np.ndarray: - input_dict = dict(zip(self.get_input_names(), input_content)) - try: - return self.session.run(None, input_dict) - except Exception as e: - error_info = traceback.format_exc() - raise ONNXRuntimeError(error_info) from e - - def get_input_names(self) -> List[str]: - return [v.name for v in self.session.get_inputs()] - - def get_output_names(self) -> List[str]: - return [v.name for v in self.session.get_outputs()] - - def get_character_list(self, key: str = "character") -> List[str]: - meta_dict = self.session.get_modelmeta().custom_metadata_map - return meta_dict[key].splitlines() - - def have_key(self, key: str = "character") -> bool: - meta_dict = self.session.get_modelmeta().custom_metadata_map - if key in meta_dict.keys(): - return True - return False - - @staticmethod - def _verify_model(model_path: Union[str, Path, None]): - if model_path is None: - raise ValueError("model_path is None!") - - model_path = Path(model_path) - if not model_path.exists(): - raise FileNotFoundError(f"{model_path} does not exists.") - - if not model_path.is_file(): - raise FileExistsError(f"{model_path} is not a file.") - - -class ONNXRuntimeError(Exception): - pass diff --git a/rapid_table_det/utils/load_image.py b/rapid_table_det/utils/load_image.py deleted file mode 100644 index f34b549..0000000 --- a/rapid_table_det/utils/load_image.py +++ /dev/null @@ -1,123 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: SWHL -# @Contact: liekkaskono@163.com -from io import BytesIO -from pathlib import Path -from typing import Any, Union - -import cv2 -import numpy as np -from PIL import Image, UnidentifiedImageError - -root_dir = Path(__file__).resolve().parent -InputType = Union[str, np.ndarray, bytes, Path, Image.Image] - - -class LoadImage: - def __init__(self): - pass - - def __call__(self, img: InputType) -> np.ndarray: - if not isinstance(img, InputType.__args__): - raise LoadImageError( - f"The img type {type(img)} does not in {InputType.__args__}" - ) - - origin_img_type = type(img) - img = self.load_img(img) - img = self.convert_img(img, origin_img_type) - return img - - def load_img(self, img: InputType) -> np.ndarray: - if isinstance(img, (str, Path)): - self.verify_exist(img) - try: - img = self.img_to_ndarray(Image.open(img)) - except UnidentifiedImageError as e: - raise LoadImageError(f"cannot identify image file {img}") from e - return img - - if isinstance(img, bytes): - img = self.img_to_ndarray(Image.open(BytesIO(img))) - return img - - if isinstance(img, np.ndarray): - return img - - if isinstance(img, Image.Image): - return self.img_to_ndarray(img) - - raise LoadImageError(f"{type(img)} is not supported!") - - def img_to_ndarray(self, img: Image.Image) -> np.ndarray: - if img.mode == "1": - img = img.convert("L") - return np.array(img) - 
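A sketch of the config dict OrtInferSession consumes, using only the keys the class above reads; the model path and input shape are placeholders.

import numpy as np
from rapid_table_det.utils.infer_engine import OrtInferSession

config = {
    "model_path": "rapid_table_det/models/yolo_obj_det.onnx",  # placeholder path
    "use_cuda": False,            # CUDAExecutionProvider only if actually available
    "use_dml": False,             # DirectML is only attempted on Windows 10+
    "intra_op_num_threads": -1,   # -1 keeps onnxruntime's defaults
}
session = OrtInferSession(config)
# one np.ndarray per model input, shapes must match the model
outputs = session([np.zeros((1, 3, 928, 928), dtype=np.float32)])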
return np.array(img) - - def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray: - if img.ndim == 2: - return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - if img.ndim == 3: - channel = img.shape[2] - if channel == 1: - return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) - - if channel == 2: - return self.cvt_two_to_three(img) - - if channel == 3: - if issubclass(origin_img_type, (str, Path, bytes, Image.Image)): - return cv2.cvtColor(img, cv2.COLOR_RGB2BGR) - return img - - if channel == 4: - return self.cvt_four_to_three(img) - - raise LoadImageError( - f"The channel({channel}) of the img is not in [1, 2, 3, 4]" - ) - - raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]") - - @staticmethod - def cvt_two_to_three(img: np.ndarray) -> np.ndarray: - """gray + alpha → BGR""" - img_gray = img[..., 0] - img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR) - - img_alpha = img[..., 1] - not_a = cv2.bitwise_not(img_alpha) - not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) - - new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha) - new_img = cv2.add(new_img, not_a) - return new_img - - @staticmethod - def cvt_four_to_three(img: np.ndarray) -> np.ndarray: - """RGBA → BGR""" - r, g, b, a = cv2.split(img) - new_img = cv2.merge((b, g, r)) - - not_a = cv2.bitwise_not(a) - not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR) - - new_img = cv2.bitwise_and(new_img, new_img, mask=a) - - mean_color = np.mean(new_img) - if mean_color <= 0.0: - new_img = cv2.add(new_img, not_a) - else: - new_img = cv2.bitwise_not(new_img) - return new_img - - @staticmethod - def verify_exist(file_path: Union[str, Path]): - if not Path(file_path).exists(): - raise LoadImageError(f"{file_path} does not exist.") - - -class LoadImageError(Exception): - pass diff --git a/rapid_table_det/utils/logger.py b/rapid_table_det/utils/logger.py deleted file mode 100644 index 2950987..0000000 --- a/rapid_table_det/utils/logger.py +++ /dev/null @@ -1,21 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: Jocker1212 -# @Contact: xinyijianggo@gmail.com -import logging -from functools import lru_cache - - -@lru_cache(maxsize=32) -def get_logger(name: str) -> logging.Logger: - logger = logging.getLogger(name) - logger.setLevel(logging.DEBUG) - - fmt = "%(asctime)s - %(name)s - %(levelname)s: %(message)s" - format_str = logging.Formatter(fmt) - - sh = logging.StreamHandler() - sh.setLevel(logging.DEBUG) - - logger.addHandler(sh) - sh.setFormatter(format_str) - return logger diff --git a/rapid_table_det/utils/transform.py b/rapid_table_det/utils/transform.py deleted file mode 100644 index 3c1f845..0000000 --- a/rapid_table_det/utils/transform.py +++ /dev/null @@ -1,314 +0,0 @@ -import math -import itertools -import cv2 -import numpy as np - - -def generate_scale(im, resize_shape, keep_ratio): - """ - Args: - im (np.ndarray): image (np.ndarray) - Returns: - im_scale_x: the resize ratio of X - im_scale_y: the resize ratio of Y - """ - target_size = (resize_shape[0], resize_shape[1]) - # target_size = (800, 1333) - origin_shape = im.shape[:2] - - if keep_ratio: - im_size_min = np.min(origin_shape) - im_size_max = np.max(origin_shape) - target_size_min = np.min(target_size) - target_size_max = np.max(target_size) - im_scale = float(target_size_min) / float(im_size_min) - if np.round(im_scale * im_size_max) > target_size_max: - im_scale = float(target_size_max) / float(im_size_max) - im_scale_x = im_scale - im_scale_y = im_scale - else: - resize_h, resize_w = target_size - im_scale_y = resize_h / float(origin_shape[0]) - 
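LoadImage accepts str, Path, bytes, np.ndarray, or PIL.Image and always returns a three-channel BGR ndarray; a short sketch using the bundled test image:

from rapid_table_det.utils.load_image import LoadImage

loader = LoadImage()
img = loader("tests/test_files/doc.png")   # -> np.ndarray, HxWx3, BGR
print(img.shape)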
im_scale_x = resize_w / float(origin_shape[1]) - return im_scale_y, im_scale_x - - -def resize(im, im_info, resize_shape, keep_ratio, interp=2): - im_scale_y, im_scale_x = generate_scale(im, resize_shape, keep_ratio) - im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=interp) - im_info["im_shape"] = np.array(im.shape[:2]).astype("float32") - im_info["scale_factor"] = np.array([im_scale_y, im_scale_x]).astype("float32") - - return im, im_info - - -def pad(im, im_info, resize_shape): - im_h, im_w = im.shape[:2] - fill_value = [114.0, 114.0, 114.0] - h, w = resize_shape[0], resize_shape[1] - if h == im_h and w == im_w: - im = im.astype(np.float32) - return im, im_info - - canvas = np.ones((h, w, 3), dtype=np.float32) - canvas *= np.array(fill_value, dtype=np.float32) - canvas[0:im_h, 0:im_w, :] = im.astype(np.float32) - im = canvas - return im, im_info - - -def ResizePad(img, target_size): - h, w = img.shape[:2] - m = max(h, w) - ratio = target_size / m - new_w, new_h = int(ratio * w), int(ratio * h) - img = cv2.resize(img, (new_w, new_h), cv2.INTER_LINEAR) - top = (target_size - new_h) // 2 - bottom = (target_size - new_h) - top - left = (target_size - new_w) // 2 - right = (target_size - new_w) - left - img1 = cv2.copyMakeBorder( - img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(114, 114, 114) - ) - return img1, new_w, new_h, left, top - - -def get_mini_boxes(contour): - bounding_box = cv2.minAreaRect(contour) - points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0]) - - index_1, index_2, index_3, index_4 = 0, 1, 2, 3 - if points[1][1] > points[0][1]: - index_1 = 0 - index_4 = 1 - else: - index_1 = 1 - index_4 = 0 - if points[3][1] > points[2][1]: - index_2 = 2 - index_3 = 3 - else: - index_2 = 3 - index_3 = 2 - - box = [points[index_1], points[index_2], points[index_3], points[index_4]] - return box, min(bounding_box[1]) - - -def minboundquad(hull): - len_hull = len(hull) - xy = np.array(hull).reshape([-1, 2]) - idx = np.arange(0, len_hull) - idx_roll = np.roll(idx, -1, axis=0) - edges = np.array([idx, idx_roll]).reshape([2, -1]) - edges = np.transpose(edges, [1, 0]) - edgeangles1 = [] - for i in range(len_hull): - y = xy[edges[i, 1], 1] - xy[edges[i, 0], 1] - x = xy[edges[i, 1], 0] - xy[edges[i, 0], 0] - angle = math.atan2(y, x) - if angle < 0: - angle = angle + 2 * math.pi - edgeangles1.append([angle, i]) - edgeangles1_idx = sorted(list(edgeangles1), key=lambda x: x[0]) - edges1 = [] - edgeangle1 = [] - for item in edgeangles1_idx: - idx = item[1] - edges1.append(edges[idx, :]) - edgeangle1.append(item[0]) - edgeangles = np.array(edgeangle1) - edges = np.array(edges1) - eps = 2.2204e-16 - angletol = eps * 100 - - k = np.diff(edgeangles) < angletol - idx = np.where(k == 1) - edges = np.delete(edges, idx, 0) - edgeangles = np.delete(edgeangles, idx, 0) - nedges = edges.shape[0] - edgelist = np.array(nchoosek(0, nedges - 1, 1, 4)) - k = edgeangles[edgelist[:, 3]] - edgeangles[edgelist[:, 0]] <= math.pi - k_idx = np.where(k == 1) - edgelist = np.delete(edgelist, k_idx, 0) - - nquads = edgelist.shape[0] - quadareas = math.inf - qxi = np.zeros([5]) - qyi = np.zeros([5]) - cnt = np.zeros([4, 1, 2]) - edgelist = list(edgelist) - edges = list(edges) - xy = list(xy) - - for i in range(nquads): - edgeind = list(edgelist[i]) - edgeind.append(edgelist[i][0]) - edgesi = [] - edgeang = [] - for idx in edgeind: - edgesi.append(edges[idx]) - edgeang.append(edgeangles[idx]) - is_continue = False - for idx in range(len(edgeang) - 1): - diff = edgeang[idx 
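A worked example of ResizePad's letterboxing, assuming the function above is in scope; the sizes follow directly from the code (scale = 928/800, horizontal padding split evenly):

import numpy as np

img = np.zeros((800, 400, 3), dtype=np.uint8)        # h=800, w=400
out, new_w, new_h, left, top = ResizePad(img, 928)
assert out.shape[:2] == (928, 928)
assert (new_w, new_h, left, top) == (464, 928, 232, 0)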
+ 1] - edgeang[idx] - if diff > math.pi: - is_continue = True - if is_continue: - continue - for j in range(4): - jplus1 = j + 1 - shared = np.intersect1d(edgesi[j], edgesi[jplus1]) - if shared.size != 0: - qxi[j] = xy[shared[0]][0] - qyi[j] = xy[shared[0]][1] - else: - A = xy[edgesi[j][0]] - B = xy[edgesi[j][1]] - C = xy[edgesi[jplus1][0]] - D = xy[edgesi[jplus1][1]] - concat = np.hstack(((A - B).reshape([2, -1]), (D - C).reshape([2, -1]))) - div = (A - C).reshape([2, -1]) - inv_result = get_inv(concat) - a = inv_result[0, 0] - b = inv_result[0, 1] - c = inv_result[1, 0] - d = inv_result[1, 1] - e = div[0, 0] - f = div[1, 0] - ts1 = [a * e + b * f, c * e + d * f] - Q = A + (B - A) * ts1[0] - qxi[j] = Q[0] - qyi[j] = Q[1] - - contour = np.array([qxi[:4], qyi[:4]]).astype(np.int32) - contour = np.transpose(contour, [1, 0]) - contour = contour[:, np.newaxis, :] - A_i = cv2.contourArea(contour) - # break - - if A_i < quadareas: - quadareas = A_i - cnt = contour - return cnt - - -def nchoosek(startnum, endnum, step=1, n=1): - c = [] - for i in itertools.combinations(range(startnum, endnum + 1, step), n): - c.append(list(i)) - return c - - -def get_inv(concat): - a = concat[0][0] - b = concat[0][1] - c = concat[1][0] - d = concat[1][1] - det_concat = a * d - b * c - inv_result = np.array( - [[d / det_concat, -b / det_concat], [-c / det_concat, a / det_concat]] - ) - return inv_result - - -def get_max_adjacent_bbox(mask): - contours, _ = cv2.findContours( - (mask * 255).astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE - ) - max_size = 0 - cnt_save = None - # 找到最大边缘邻接矩形 - for cont in contours: - points, sside = get_mini_boxes(cont) - if sside > max_size: - max_size = sside - cnt_save = cont - if cnt_save is not None: - epsilon = 0.01 * cv2.arcLength(cnt_save, True) - box = cv2.approxPolyDP(cnt_save, epsilon, True) - hull = cv2.convexHull(box) - points, sside = get_mini_boxes(cnt_save) - len_hull = len(hull) - - if len_hull == 4: - target_box = np.array(hull) - elif len_hull > 4: - target_box = minboundquad(hull) - else: - target_box = np.array(points) - - return np.array(target_box).reshape([-1, 2]) - - -def sigmoid(x): - return 1 / (1 + np.exp(-x)) - - -def calculate_iou(box, other_boxes): - """ - 计算给定边界框与一组其他边界框之间的交并比(IoU)。 - - 参数: - - box: 单个边界框,格式为 [x1, y1, width, height]。 - - other_boxes: 其他边界框的数组,每个边界框的格式也为 [x1, y1, width, height]。 - - 返回值: - - iou: 一个数组,包含给定边界框与每个其他边界框的IoU值。 - """ - - # 计算交集的左上角坐标 - x1 = np.maximum(box[0], np.array(other_boxes)[:, 0]) - y1 = np.maximum(box[1], np.array(other_boxes)[:, 1]) - # 计算交集的右下角坐标 - x2 = np.minimum(box[2], np.array(other_boxes)[:, 2]) - y2 = np.minimum(box[3], np.array(other_boxes)[:, 3]) - # 计算交集区域的面积 - intersection_area = np.maximum(0, x2 - x1) * np.maximum(0, y2 - y1) - # 计算给定边界框的面积 - box_area = (box[2] - box[0]) * (box[3] - box[1]) - # 计算其他边界框的面积 - other_boxes_area = np.array(other_boxes[:, 2] - other_boxes[:, 0]) * np.array( - other_boxes[:, 3] - other_boxes[:, 1] - ) - # 计算IoU值 - iou = intersection_area / (box_area + other_boxes_area - intersection_area) - return iou - - -def custom_NMSBoxes(boxes, scores, iou_threshold=0.4): - # 如果没有边界框,则直接返回空列表 - if len(boxes) == 0: - return [] - # 将得分和边界框转换为NumPy数组 - scores = np.array(scores) - boxes = np.array(boxes) - # 根据置信度阈值过滤边界框 - # filtered_boxes = boxes[mask] - # filtered_scores = scores[mask] - # 如果过滤后没有边界框,则返回空列表 - if len(boxes) == 0: - return [] - # 根据置信度得分对边界框进行排序 - sorted_indices = np.argsort(scores)[::-1] - # 初始化一个空列表来存储选择的边界框索引 - indices = [] - # 当还有未处理的边界框时,循环继续 - while 
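A worked example for calculate_iou; note that the arithmetic treats boxes as [x1, y1, x2, y2] corners, even though the docstring above says [x1, y1, width, height].

import numpy as np

box = [0, 0, 2, 2]
others = np.array([[1, 1, 3, 3]])
# intersection = 1, union = 4 + 4 - 1 = 7
print(calculate_iou(box, others))   # [0.14285714]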
len(sorted_indices) > 0: - # 选择得分最高的边界框索引 - current_index = sorted_indices[0] - indices.append(current_index) - # 如果只剩一个边界框,则结束循环 - if len(sorted_indices) == 1: - break - # 获取当前边界框和其他边界框 - current_box = boxes[current_index] - other_boxes = boxes[sorted_indices[1:]] - # 计算当前边界框与其他边界框的IoU - iou = calculate_iou(current_box, other_boxes) - # 找到IoU低于阈值的边界框,即与当前边界框不重叠的边界框 - non_overlapping_indices = np.where(iou <= iou_threshold)[0] - # 更新sorted_indices以仅包含不重叠的边界框 - sorted_indices = sorted_indices[non_overlapping_indices + 1] - # 返回选择的边界框索引 - return indices diff --git a/rapid_table_det/utils/visuallize.py b/rapid_table_det/utils/visuallize.py deleted file mode 100644 index fb346cc..0000000 --- a/rapid_table_det/utils/visuallize.py +++ /dev/null @@ -1,65 +0,0 @@ -import cv2 - -from rapid_table_det.utils.load_image import LoadImage -import numpy as np - -img_loader = LoadImage() - - -def visuallize(img, box, lt, rt, rb, lb): - xmin, ymin, xmax, ymax = box - draw_box = np.array([lt, rt, rb, lb]).reshape([-1, 2]) - cv2.circle(img, (int(lt[0]), int(lt[1])), 50, (255, 0, 0), 10) - cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 0, 0), 10) - cv2.polylines( - img, - [np.array(draw_box).astype(np.int32).reshape((-1, 1, 2))], - True, - color=(255, 0, 255), - thickness=6, - ) - return img - - -def extract_table_img(img, lt, rt, rb, lb): - """ - 根据四个角点进行透视变换,并提取出角点区域的图片。 - - 参数: - img (numpy.ndarray): 输入图像 - lt (numpy.ndarray): 左上角坐标 - rt (numpy.ndarray): 右上角坐标 - lb (numpy.ndarray): 左下角坐标 - rb (numpy.ndarray): 右下角坐标 - - 返回: - numpy.ndarray: 提取出的角点区域图片 - """ - # 源点坐标 - src_points = np.float32([lt, rt, lb, rb]) - - # 目标点坐标 - width_a = np.sqrt(((rb[0] - lb[0]) ** 2) + ((rb[1] - lb[1]) ** 2)) - width_b = np.sqrt(((rt[0] - lt[0]) ** 2) + ((rt[1] - lt[1]) ** 2)) - max_width = max(int(width_a), int(width_b)) - - height_a = np.sqrt(((rt[0] - rb[0]) ** 2) + ((rt[1] - rb[1]) ** 2)) - height_b = np.sqrt(((lt[0] - lb[0]) ** 2) + ((lt[1] - lb[1]) ** 2)) - max_height = max(int(height_a), int(height_b)) - - dst_points = np.float32( - [ - [0, 0], - [max_width - 1, 0], - [0, max_height - 1], - [max_width - 1, max_height - 1], - ] - ) - - # 获取透视变换矩阵 - M = cv2.getPerspectiveTransform(src_points, dst_points) - - # 应用透视变换 - warped = cv2.warpPerspective(img, M, (max_width, max_height)) - - return warped diff --git a/rapid_table_det_paddle/__init__.py b/rapid_table_det_paddle/__init__.py deleted file mode 100644 index dcd77e7..0000000 --- a/rapid_table_det_paddle/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: Jocker1212 -# @Contact: xinyijianggo@gmail.com -from .inference import TableDetector -from .utils import img_loader, visuallize, extract_table_img - -__all__ = ["TableDetector", "img_loader", "visuallize", "extract_table_img"] diff --git a/rapid_table_det_paddle/inference.py b/rapid_table_det_paddle/inference.py deleted file mode 100644 index 5393847..0000000 --- a/rapid_table_det_paddle/inference.py +++ /dev/null @@ -1,143 +0,0 @@ -import cv2 -import numpy as np - -from rapid_table_det_paddle.predictor import DbNet, ObjectDetector, PPLCNet -from rapid_table_det_paddle.utils import LoadImage - - -class TableDetector: - def __init__( - self, - edge_model_path=None, - obj_model_path=None, - cls_model_path=None, - use_obj_det=True, - use_edge_det=True, - use_cls_det=True, - ): - self.use_obj_det = use_obj_det - self.use_edge_det = use_edge_det - self.use_cls_det = use_cls_det - self.img_loader = LoadImage() - if self.use_obj_det: - self.obj_detector = 
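A sketch of rectifying a detected table with extract_table_img from the visuallize module above; the corner coordinates are illustrative.

import cv2
import numpy as np
from rapid_table_det.utils.visuallize import extract_table_img

img = cv2.imread("tests/test_files/doc.png")
lt, rt = np.array([10, 20]), np.array([410, 30])
rb, lb = np.array([405, 330]), np.array([15, 320])
warped = extract_table_img(img, lt, rt, rb, lb)   # perspective-corrected crop
cv2.imwrite("outputs/table_crop.png", warped)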
ObjectDetector(obj_model_path) - if self.use_edge_det: - self.dbnet = DbNet(edge_model_path) - if self.use_cls_det: - self.pplcnet = PPLCNet(cls_model_path) - - def __call__(self, img, det_accuracy=0.7): - img = self.img_loader(img) - img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) - img_mask = img.copy() - h, w = img.shape[:-1] - img_box = np.array([0, 0, w, h]) - lb, lt, rb, rt = self.get_box_points(img_box) - # 初始化默认值 - obj_det_res, edge_box, pred_label = ( - [[1.0, img_box]], - img_box.reshape([-1, 2]), - 0, - ) - result = [] - obj_det_elapse, edge_elapse, rotate_det_elapse = 0, 0, 0 - if self.use_obj_det: - obj_det_res, obj_det_elapse = self.obj_detector(img, score=det_accuracy) - for i in range(len(obj_det_res)): - det_res = obj_det_res[i] - score, box = det_res - xmin, ymin, xmax, ymax = box - edge_box = box.reshape([-1, 2]) - lb, lt, rb, rt = self.get_box_points(box) - if self.use_edge_det: - xmin_edge, ymin_edge, xmax_edge, ymax_edge = self.pad_box_points( - h, w, xmax, xmin, ymax, ymin, 10 - ) - crop_img = img_mask[ymin_edge:ymax_edge, xmin_edge:xmax_edge, :] - edge_box, lt, lb, rt, rb, tmp_edge_elapse = self.dbnet(crop_img) - edge_elapse += tmp_edge_elapse - if edge_box is None: - continue - edge_box[:, 0] += xmin_edge - edge_box[:, 1] += ymin_edge - lt, lb, rt, rb = ( - lt + [xmin_edge, ymin_edge], - lb + [xmin_edge, ymin_edge], - rt + [xmin_edge, ymin_edge], - rb + [xmin_edge, ymin_edge], - ) - if self.use_cls_det: - xmin_cls, ymin_cls, xmax_cls, ymax_cls = self.pad_box_points( - h, w, xmax, xmin, ymax, ymin, 5 - ) - cls_box = edge_box.copy() - cls_img = img_mask[ymin_cls:ymax_cls, xmin_cls:xmax_cls, :] - cls_box[:, 0] = cls_box[:, 0] - xmin_cls - cls_box[:, 1] = cls_box[:, 1] - ymin_cls - # 画框增加先验信息,辅助方向label识别 - cv2.polylines( - cls_img, - [np.array(cls_box).astype(np.int32).reshape((-1, 1, 2))], - True, - color=(255, 0, 255), - thickness=5, - ) - pred_label, tmp_rotate_det_elapse = self.pplcnet(cls_img) - rotate_det_elapse += tmp_rotate_det_elapse - lb1, lt1, rb1, rt1 = self.get_real_rotated_points( - lb, lt, pred_label, rb, rt - ) - result.append( - { - "box": [int(xmin), int(ymin), int(xmax), int(ymax)], - "lb": [int(lb1[0]), int(lb1[1])], - "lt": [int(lt1[0]), int(lt1[1])], - "rt": [int(rt1[0]), int(rt1[1])], - "rb": [int(rb1[0]), int(rb1[1])], - } - ) - elapse = [obj_det_elapse, edge_elapse, rotate_det_elapse] - return result, elapse - - def get_box_points(self, img_box): - x1, y1, x2, y2 = img_box - lt = np.array([x1, y1]) # 左上角 - rt = np.array([x2, y1]) # 右上角 - rb = np.array([x2, y2]) # 右下角 - lb = np.array([x1, y2]) # 左下角 - return lb, lt, rb, rt - - def get_real_rotated_points(self, lb, lt, pred_label, rb, rt): - if pred_label == 0: - lt1 = lt - rt1 = rt - rb1 = rb - lb1 = lb - elif pred_label == 1: - lt1 = rt - rt1 = rb - rb1 = lb - lb1 = lt - elif pred_label == 2: - lt1 = rb - rt1 = lb - rb1 = lt - lb1 = rt - elif pred_label == 3: - lt1 = lb - rt1 = lt - rb1 = rt - lb1 = rb - else: - lt1 = lt - rt1 = rt - rb1 = rb - lb1 = lb - return lb1, lt1, rb1, rt1 - - def pad_box_points(self, h, w, xmax, xmin, ymax, ymin, pad): - ymin_edge = max(ymin - pad, 0) - xmin_edge = max(xmin - pad, 0) - ymax_edge = min(ymax + pad, h) - xmax_edge = min(xmax + pad, w) - return xmin_edge, ymin_edge, xmax_edge, ymax_edge diff --git a/rapid_table_det_paddle/models/.gitkeep b/rapid_table_det_paddle/models/.gitkeep deleted file mode 100644 index e69de29..0000000 diff --git a/rapid_table_det_paddle/requirments.txt b/rapid_table_det_paddle/requirments.txt deleted file mode 100644 index 
9875f90..0000000 --- a/rapid_table_det_paddle/requirments.txt +++ /dev/null @@ -1,6 +0,0 @@ -paddlepaddle -numpy -Pillow -opencv-python -onnxruntime -requests diff --git a/readme_resource/res_show.jpg b/readme_resource/res_show.jpg deleted file mode 100644 index 34210e5..0000000 Binary files a/readme_resource/res_show.jpg and /dev/null differ diff --git a/readme_resource/res_show2.jpg b/readme_resource/res_show2.jpg deleted file mode 100644 index 119ac71..0000000 Binary files a/readme_resource/res_show2.jpg and /dev/null differ diff --git a/readme_resource/structure.png b/readme_resource/structure.png deleted file mode 100644 index da053c2..0000000 Binary files a/readme_resource/structure.png and /dev/null differ diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 5ef9652..0000000 --- a/requirements.txt +++ /dev/null @@ -1,6 +0,0 @@ -numpy -tqdm -Pillow -opencv-python -onnxruntime -requests diff --git a/setup_rapid_table_det.py b/setup_rapid_table_det.py deleted file mode 100644 index 52f66af..0000000 --- a/setup_rapid_table_det.py +++ /dev/null @@ -1,57 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: Jocker1212 -# @Contact: xinyijianggo@gmail.com -import sys -from typing import List, Union -from pathlib import Path -from get_pypi_latest_version import GetPyPiLatestVersion - -import setuptools - - -def read_txt(txt_path: Union[Path, str]) -> List[str]: - with open(txt_path, "r", encoding="utf-8") as f: - data = [v.rstrip("\n") for v in f] - return data - - -MODULE_NAME = "rapid_table_det" - -obtainer = GetPyPiLatestVersion() -try: - latest_version = obtainer(MODULE_NAME) -except Exception: - latest_version = "0.0.0" - -VERSION_NUM = obtainer.version_add_one(latest_version) - -if len(sys.argv) > 2: - match_str = " ".join(sys.argv[2:]) - matched_versions = obtainer.extract_version(match_str) - if matched_versions: - VERSION_NUM = matched_versions -sys.argv = sys.argv[:2] - -setuptools.setup( - name=MODULE_NAME, - version=VERSION_NUM, - platforms="Any", - description="table detection with onnx model", - long_description="table detection with onnx model", - author="jockerK", - author_email="xinyijianggo@gmail.com", - url="https://github.com/Joker1212/RapidTableDetection", - license="Apache-2.0", - install_requires=read_txt("requirements.txt"), - include_package_data=False, - packages=[MODULE_NAME, f"{MODULE_NAME}.models", f"{MODULE_NAME}.utils"], - package_data={"": [".gitkeep"]}, - keywords=["obj detection,ocr,table-recognition"], - classifiers=[ - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - ], - python_requires=">=3.8,<3.13", -) diff --git a/setup_rapid_table_det_paddle.py b/setup_rapid_table_det_paddle.py deleted file mode 100644 index 3be1fe8..0000000 --- a/setup_rapid_table_det_paddle.py +++ /dev/null @@ -1,56 +0,0 @@ -# -*- encoding: utf-8 -*- -# @Author: Jocker1212 -# @Contact: xinyijianggo@gmail.com -import sys -from typing import List, Union -from pathlib import Path -from get_pypi_latest_version import GetPyPiLatestVersion - -import setuptools - - -def read_txt(txt_path: Union[Path, str]) -> List[str]: - with open(txt_path, "r", encoding="utf-8") as f: - data = [v.rstrip("\n") for v in f] - return data - - -MODULE_NAME = "rapid_table_det_paddle" - -obtainer = GetPyPiLatestVersion() -try: - latest_version = obtainer(MODULE_NAME) -except Exception: - latest_version = "0.0.0" - -VERSION_NUM = obtainer.version_add_one(latest_version) 
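Both setup scripts resolve the package version from the pushed CI tag; roughly, with an illustrative tag value:

# CI invokes: python setup_rapid_table_det.py bdist_wheel "rapid_table_det_v1.0.3"
# obtainer.extract_version(...) presumably pulls "1.0.3" out of the tag and
# overrides the PyPI-bumped default; sys.argv is then trimmed so setuptools
# only sees ["setup_rapid_table_det.py", "bdist_wheel"].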
diff --git a/setup_rapid_table_det_paddle.py b/setup_rapid_table_det_paddle.py
deleted file mode 100644
index 3be1fe8..0000000
--- a/setup_rapid_table_det_paddle.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# -*- encoding: utf-8 -*-
-# @Author: Jocker1212
-# @Contact: xinyijianggo@gmail.com
-import sys
-from typing import List, Union
-from pathlib import Path
-from get_pypi_latest_version import GetPyPiLatestVersion
-
-import setuptools
-
-
-def read_txt(txt_path: Union[Path, str]) -> List[str]:
-    with open(txt_path, "r", encoding="utf-8") as f:
-        data = [v.rstrip("\n") for v in f]
-    return data
-
-
-MODULE_NAME = "rapid_table_det_paddle"
-
-obtainer = GetPyPiLatestVersion()
-try:
-    latest_version = obtainer(MODULE_NAME)
-except Exception:
-    latest_version = "0.0.0"
-
-VERSION_NUM = obtainer.version_add_one(latest_version)
-
-if len(sys.argv) > 2:
-    match_str = " ".join(sys.argv[2:])
-    matched_versions = obtainer.extract_version(match_str)
-    if matched_versions:
-        VERSION_NUM = matched_versions
-sys.argv = sys.argv[:2]
-
-setuptools.setup(
-    name=MODULE_NAME,
-    version=VERSION_NUM,
-    platforms="Any",
-    description="table detection with original paddle model",
-    long_description="table detection with original paddle model",
-    author="jockerK",
-    author_email="xinyijianggo@gmail.com",
-    url="https://github.com/Joker1212/RapidTableDetection",
-    license="Apache-2.0",
-    install_requires=read_txt("requirements.txt"),
-    include_package_data=True,
-    packages=[MODULE_NAME, f"{MODULE_NAME}.models"],
-    keywords=["obj detection,ocr,table-recognition"],
-    classifiers=[
-        "Programming Language :: Python :: 3.8",
-        "Programming Language :: Python :: 3.9",
-        "Programming Language :: Python :: 3.10",
-        "Programming Language :: Python :: 3.11",
-    ],
-    python_requires=">=3.8,<3.13",
-)
diff --git a/tests/__init__.py b/tests/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/test_files/chip.jpg b/tests/test_files/chip.jpg
deleted file mode 100644
index f87c153..0000000
Binary files a/tests/test_files/chip.jpg and /dev/null differ
diff --git a/tests/test_files/chip2.jpg b/tests/test_files/chip2.jpg
deleted file mode 100644
index 8a704ee..0000000
Binary files a/tests/test_files/chip2.jpg and /dev/null differ
diff --git a/tests/test_files/doc.png b/tests/test_files/doc.png
deleted file mode 100644
index 54d302c..0000000
Binary files a/tests/test_files/doc.png and /dev/null differ
diff --git a/tests/test_table_det.py b/tests/test_table_det.py
deleted file mode 100644
index 24c1358..0000000
--- a/tests/test_table_det.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import sys
-from pathlib import Path
-
-import pytest
-
-
-cur_dir = Path(__file__).resolve().parent
-root_dir = cur_dir.parent
-
-sys.path.append(str(root_dir))
-test_file_dir = cur_dir / "test_files"
-
-
-@pytest.mark.parametrize(
-    "img_path, expected",
-    [("chip.jpg", 1), ("doc.png", 2)],
-)
-def test_input_normal(img_path, expected):
-    from rapid_table_det import TableDetector
-
-    table_det = TableDetector()
-    img_path = test_file_dir / img_path
-    result, elapse = table_det(img_path)
-    assert len(result) == expected
diff --git a/tests/test_table_det_paddle.py b/tests/test_table_det_paddle.py
deleted file mode 100644
index 41e49fc..0000000
--- a/tests/test_table_det_paddle.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import sys
-from pathlib import Path
-
-import pytest
-
-cur_dir = Path(__file__).resolve().parent
-root_dir = cur_dir.parent
-
-sys.path.append(str(root_dir))
-test_file_dir = cur_dir / "test_files"
-
-
-@pytest.mark.parametrize(
-    "img_path, expected",
-    [("chip.jpg", 1), ("doc.png", 2)],
-)
-def test_input_normal(img_path, expected):
-    from rapid_table_det_paddle.inference import TableDetector
-
-    table_det = TableDetector(
-        obj_model_path=f"{root_dir}/rapid_table_det_paddle/models/obj_det_paddle",
-        edge_model_path=f"{root_dir}/rapid_table_det_paddle/models/edge_det_paddle",
-        cls_model_path=f"{root_dir}/rapid_table_det_paddle/models/cls_det_paddle",
-        use_obj_det=True,
-        use_edge_det=True,
-        use_cls_det=True,
-    )
-    img_path = test_file_dir / img_path
-    result, elapse = table_det(img_path)
-    assert len(result) == expected
diff --git a/tools/__init__.py b/tools/__init__.py
deleted file mode 100644
index e69de29..0000000
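The next deletion, tools/fix_onnx.py, patches the graph that paddle2onnx exports: two Squeeze nodes are removed and rebuilt in the opset-13+ form, where the squeeze axes are passed as a second INT64 initializer input rather than as an attribute. A minimal sketch of that node-building pattern; the tensor names here are illustrative, not taken from the exported graph:

```python
# Sketch of the Squeeze-rebuilding pattern used by tools/fix_onnx.py below.
# Since ONNX opset 13, Squeeze takes its axes as a tensor input, so the
# script creates an INT64 constant and wires it in as the second input.
import onnx
from onnx import helper

# Constant initializer holding the axes to squeeze (axis 1 here).
axes_tensor = helper.make_tensor(
    name="squeeze_axes", data_type=onnx.TensorProto.INT64, dims=[1], vals=[1]
)
# New-style Squeeze node: data tensor plus the axes initializer.
squeeze_node = helper.make_node(
    "Squeeze",
    inputs=["gather_out", "squeeze_axes"],  # "gather_out" is a hypothetical name
    outputs=["gather_out_squeezed"],
    name="patched_Squeeze",
)
print(helper.printable_node(squeeze_node))
# In the real script these are appended to model.graph.initializer and
# model.graph.node, and downstream node inputs are rewired accordingly.
```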
diff --git a/tools/fix_onnx.py b/tools/fix_onnx.py
deleted file mode 100644
index 5b6844d..0000000
--- a/tools/fix_onnx.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import onnx
-from onnx import helper
-
-
-def create_constant_tensor(name, value):
-    tensor = helper.make_tensor(
-        name=name, data_type=onnx.TensorProto.INT64, dims=[len(value)], vals=value
-    )
-    return tensor
-
-
-# Create a new Squeeze node, with its axes passed in as a constant input
-def create_squeeze_node(name, input_name, axes):
-    const_tensor = create_constant_tensor(name + "_constant", axes)
-    squeeze_node = helper.make_node(
-        "Squeeze",
-        inputs=[input_name, name + "_constant"],
-        outputs=[input_name + "_squeezed"],
-        name=name + "_Squeeze",
-    )
-    return const_tensor, squeeze_node
-
-
-def fix_onnx(model_path):
-    model = onnx.load(model_path)
-
-    # Remove the specified nodes
-    deleted_nodes = ["p2o.Squeeze.3", "p2o.Squeeze.5"]
-    nodes_to_delete = [node for node in model.graph.node if node.name in deleted_nodes]
-    for node in nodes_to_delete:
-        model.graph.node.remove(node)
-    # Track the outputs of gather8 and gather10
-    gather8_output = None
-    gather10_output = None
-    # Locate the 'p2o.Gather.0' and 'p2o.Gather.2' nodes
-    for node in model.graph.node:
-        if node.name == "p2o.Gather.0":
-            gather_output_name = node.output[0]
-            # break
-        if node.name == "p2o.Gather.2":
-            gather_output_name1 = node.output[0]
-            # break
-
-    for node in model.graph.node:
-        if node.name == "p2o.Gather.8":
-            node.input[0] = gather_output_name
-            gather8_output = node.output[0]
-        elif node.name == "p2o.Gather.10":
-            node.input[0] = gather_output_name1
-            gather10_output = node.output[0]
-
-    if gather8_output:
-        new_squeeze_components_8 = create_squeeze_node(
-            "p2o.Gather.8", gather8_output, [1]
-        )
-        model.graph.initializer.append(new_squeeze_components_8[0])  # add the constant tensor
-        model.graph.node.append(new_squeeze_components_8[1])  # add the Squeeze node
-
-    if gather10_output:
-        new_squeeze_components_10 = create_squeeze_node(
-            "p2o.Gather.10", gather10_output, [1]
-        )
-        model.graph.initializer.append(new_squeeze_components_10[0])  # add the constant tensor
-        model.graph.node.append(new_squeeze_components_10[1])  # add the Squeeze node
-
-    # Update the inputs of the nodes that depend on gather8 and gather10
-    for node in model.graph.node:
-        if node.name == "p2o.Cast.0":
-            node.input[0] = new_squeeze_components_8[1].output[0]
-
-        if node.name == "p2o.Gather.12":
-            node.input[1] = new_squeeze_components_10[1].output[0]
-
-    # Save the patched model
-    onnx.save(model, model_path)
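The notebook deleted next drives the conversion pipeline with shell magics: paddle2onnx export, an onnxslim pass, dynamic quantization, and the fix_onnx patch above. A plain-Python sketch of the export step, using the same flags the notebook uses (shown for obj_det only; the paths assume the repo layout):

```python
# Standalone sketch of the paddle2onnx export the notebook below runs via
# "!" shell magics; one call per model (obj_det shown here).
import subprocess

subprocess.run(
    [
        "paddle2onnx",
        "--model_dir", "rapid_table_det_paddle/models",
        "--model_filename", "obj_det_paddle.pdmodel",
        "--params_filename", "obj_det_paddle.pdiparams",
        "--save_file", "rapid_table_det/models/obj_det.onnx",
        "--opset_version", "16",
        "--enable_onnx_checker", "True",
    ],
    check=True,  # raise if the export fails
)
```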
diff --git a/tools/onnx_transform.ipynb b/tools/onnx_transform.ipynb
deleted file mode 100644
index 9452a43..0000000
--- a/tools/onnx_transform.ipynb
+++ /dev/null
@@ -1,156 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": 3,
-   "metadata": {},
-   "source": [
-    "!pip install paddle2onnx onnxruntime onnxslim onnxruntime-tools onnx pickleshare -i https://pypi.tuna.tsinghua.edu.cn/simple"
-   ],
-   "outputs": []
-  },
-  {
-   "cell_type": "code",
-   "metadata": {
-    "collapsed": false,
-    "jupyter": {
-     "outputs_hidden": false
-    },
-    "pycharm": {
-     "name": "#%%\n"
-    },
-    "ExecuteTime": {
-     "end_time": "2024-10-19T13:58:51.289307Z",
-     "start_time": "2024-10-19T13:58:31.510101Z"
-    }
-   },
-   "source": [
-    "!paddle2onnx --model_dir ../rapid_table_det_paddle/models --model_filename obj_det_paddle.pdmodel --params_filename obj_det_paddle.pdiparams --save_file ../rapid_table_det/models/obj_det.onnx --opset_version 16 --enable_onnx_checker True\n",
-    "!paddle2onnx --model_dir ../rapid_table_det_paddle/models --model_filename edge_det_paddle.pdmodel --params_filename edge_det_paddle.pdiparams --save_file ../rapid_table_det/models/edge_det.onnx --opset_version 16 --enable_onnx_checker True\n",
-    "!paddle2onnx --model_dir ../rapid_table_det_paddle/models --model_filename cls_det_paddle.pdmodel --params_filename cls_det_paddle.pdiparams --save_file ../rapid_table_det/models/cls_det.onnx --opset_version 16 --enable_onnx_checker True\n",
-    "\n",
-    "!onnxslim ../rapid_table_det/models/obj_det.onnx ../rapid_table_det/models/obj_det.onnx\n",
-    "!onnxslim ../rapid_table_det/models/edge_det.onnx ../rapid_table_det/models/edge_det.onnx\n",
-    "!onnxslim ../rapid_table_det/models/cls_det.onnx ../rapid_table_det/models/cls_det.onnx"
-   ],
-   "execution_count": 1,
-   "outputs": []
-  },
-  {
-   "cell_type": "code",
-   "metadata": {
-    "collapsed": false,
-    "jupyter": {
-     "outputs_hidden": false
-    },
-    "pycharm": {
-     "name": "#%%\n"
-    },
-    "ExecuteTime": {
-     "end_time": "2024-10-19T13:58:56.174983Z",
-     "start_time": "2024-10-19T13:58:55.580038Z"
-    }
-   },
-   "source": [
-    "from pathlib import Path\n",
-    "import onnx\n",
-    "from onnxruntime.quantization import quantize_dynamic, QuantType, quant_pre_process\n",
-    "def quantize_model(root_dir_str, model_dir, pre_fix):\n",
-    "\n",
-    "    original_model_path = f\"{pre_fix}.onnx\"\n",
-    "    quantized_model_path = f\"{pre_fix}_quantized.onnx\"\n",
-    "    # quantized_model_path = original_model_path\n",
-    "    original_model_path = f\"{root_dir_str}/{model_dir}/{original_model_path}\"\n",
-    "    quantized_model_path = f\"{root_dir_str}/{model_dir}/{quantized_model_path}\"\n",
-    "    quant_pre_process(original_model_path, quantized_model_path, auto_merge=True)\n",
-    "    # Run dynamic quantization\n",
-    "    quantize_dynamic(\n",
-    "        model_input=quantized_model_path,\n",
-    "        model_output=quantized_model_path,\n",
-    "        weight_type=QuantType.QUInt8\n",
-    "    )\n",
-    "\n",
-    "    # Validate the quantized model\n",
-    "    quantized_model = onnx.load(quantized_model_path)\n",
-    "    onnx.checker.check_model(quantized_model)\n",
-    "    print(\"Quantized model is valid.\")"
-   ],
-   "execution_count": 2,
-   "outputs": []
-  },
-  {
-   "cell_type": "code",
-   "metadata": {
-    "collapsed": false,
-    "jupyter": {
-     "outputs_hidden": false
-    },
-    "pycharm": {
-     "name": "#%%\n"
-    },
-    "ExecuteTime": {
-     "end_time": "2024-10-19T13:59:14.149803Z",
-     "start_time": "2024-10-19T13:58:59.542092Z"
-    }
-   },
-   "source": [
-    "root_dir_str = \"..\"\n",
-    "model_dir = f\"rapid_table_det/models\"\n",
-    "quantize_model(root_dir_str, model_dir, \"obj_det\")\n",
-    "quantize_model(root_dir_str, model_dir, \"edge_det\")"
-   ],
-   "execution_count": 3,
-   "outputs": []
-  },
-  {
-   "cell_type": "code",
-   "metadata": {
-    "ExecuteTime": {
-     "end_time": "2024-10-19T13:59:19.984452Z",
-     "start_time": "2024-10-19T13:59:18.181521Z"
-    }
-   },
-   "source": [
-    "from fix_onnx import fix_onnx\n",
-    "import os\n",
-    "# Specify the directory path\n",
-    "model_dir = \"../rapid_table_det/models\"\n",
-    "# Load the existing ONNX models and patch them\n",
-    "model_path = os.path.join(model_dir, \"obj_det.onnx\")\n",
-    "fix_onnx(model_path)\n",
-    "model_path = os.path.join(model_dir, \"obj_det_quantized.onnx\")\n",
-    "fix_onnx(model_path)"
-   ],
-   "execution_count": 4,
-   "outputs": []
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "source": [],
-   "outputs": []
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 3
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython3",
-   "version": "3.10.14"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
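For reference, the quantization step from the notebook above as a standalone sketch, using the same onnxruntime.quantization calls; the paths are examples:

```python
# Standalone version of the notebook's quantization step: preprocess the
# graph (shape inference / cleanup), then quantize weights to QUInt8 in place.
import onnx
from onnxruntime.quantization import QuantType, quant_pre_process, quantize_dynamic

src = "rapid_table_det/models/obj_det.onnx"            # example input path
dst = "rapid_table_det/models/obj_det_quantized.onnx"  # example output path

quant_pre_process(src, dst, auto_merge=True)
quantize_dynamic(model_input=dst, model_output=dst, weight_type=QuantType.QUInt8)

onnx.checker.check_model(onnx.load(dst))  # sanity-check the result
```

The final change below renames rapid_table_det_paddle/utils.py to a top-level utils.py and reflows its formatting; visuallize now writes its result to disk, and the corner arguments of visuallize and extract_table_img change order from (lt, rt, rb, lb) to (lt, rt, lb, rb).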
diff --git a/rapid_table_det_paddle/utils.py b/utils.py
similarity index 90%
rename from rapid_table_det_paddle/utils.py
rename to utils.py
index 2aff817..e13483d 100644
--- a/rapid_table_det_paddle/utils.py
+++ b/utils.py
@@ -1,7 +1,7 @@
 import math
 from io import BytesIO
 from pathlib import Path
-from typing import Union
+from typing import List, Union
 import itertools
 import cv2
 import numpy as np
@@ -13,7 +13,7 @@
 class LoadImage:
     def __init__(
-        self,
+            self,
     ):
         pass
@@ -102,9 +102,6 @@ def verify_exist(file_path: Union[str, Path]):
         raise LoadImageError(f"{file_path} does not exist.")
 
 
-img_loader = LoadImage()
-
-
 class LoadImageError(Exception):
     pass
@@ -140,9 +137,16 @@ def generate_scale(im, resize_shape, keep_ratio):
 
 def resize(im, im_info, resize_shape, keep_ratio, interp=2):
     im_scale_y, im_scale_x = generate_scale(im, resize_shape, keep_ratio)
-    im = cv2.resize(im, None, None, fx=im_scale_x, fy=im_scale_y, interpolation=interp)
-    im_info["im_shape"] = np.array(im.shape[:2]).astype("float32")
-    im_info["scale_factor"] = np.array([im_scale_y, im_scale_x]).astype("float32")
+    im = cv2.resize(
+        im,
+        None,
+        None,
+        fx=im_scale_x,
+        fy=im_scale_y,
+        interpolation=interp)
+    im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
+    im_info['scale_factor'] = np.array(
+        [im_scale_y, im_scale_x]).astype('float32')
     return im, im_info
@@ -172,9 +176,7 @@ def ResizePad(img, target_size):
     bottom = (target_size - new_h) - top
     left = (target_size - new_w) // 2
     right = (target_size - new_w) - left
-    img1 = cv2.copyMakeBorder(
-        img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(255, 255, 255)
-    )
+    img1 = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=(255, 255, 255))
     return img1, new_w, new_h, left, top
@@ -196,7 +198,9 @@ def get_mini_boxes(contour):
         index_2 = 3
         index_3 = 2
 
-    box = [points[index_1], points[index_2], points[index_3], points[index_4]]
+    box = [
+        points[index_1], points[index_2], points[index_3], points[index_4]
+    ]
     return box, min(bounding_box[1])
@@ -311,16 +315,13 @@ def get_inv(concat):
     c = concat[1][0]
     d = concat[1][1]
     det_concat = a * d - b * c
-    inv_result = np.array(
-        [[d / det_concat, -b / det_concat], [-c / det_concat, a / det_concat]]
-    )
+    inv_result = np.array([[d / det_concat, -b / det_concat],
+                           [-c / det_concat, a / det_concat]])
     return inv_result
 
 
 def get_max_adjacent_bbox(mask):
-    contours, _ = cv2.findContours(
-        (mask * 255).astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
-    )
+    contours, _ = cv2.findContours((mask * 255).astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
     max_size = 0
     cnt_save = None
     # Find the largest edge-adjacent bounding quadrilateral
@@ -346,22 +347,16 @@
     return np.array(target_box).reshape([-1, 2])
 
 
-def visuallize(img, box, lt, rt, rb, lb):
+def visuallize(img, box, lt, rt, lb, rb, save_path):
     xmin, ymin, xmax, ymax = box
     draw_box = np.array([lt, rt, rb, lb]).reshape([-1, 2])
     cv2.circle(img, (int(lt[0]), int(lt[1])), 50, (255, 0, 0), 10)
     cv2.rectangle(img, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 0, 0), 10)
-    cv2.polylines(
-        img,
-        [np.array(draw_box).astype(np.int32).reshape((-1, 1, 2))],
-        True,
-        color=(255, 0, 255),
-        thickness=6,
-    )
-    return img
-
+    cv2.polylines(img, [np.array(draw_box).astype(np.int32).reshape((-1, 1, 2))], True,
+                  color=(255, 0, 255), thickness=6)
+    cv2.imwrite(save_path, img)
 
-def extract_table_img(img, lt, rt, rb, lb):
+def extract_table_img(img, lt, rt, lb, rb):
     """
     Perspective-transform the region bounded by the four corner points and crop that region out of the image.
@@ -387,18 +382,14 @@
     height_b = np.sqrt(((lt[0] - lb[0]) ** 2) + ((lt[1] - lb[1]) ** 2))
     max_height = max(int(height_a), int(height_b))
 
-    dst_points = np.float32(
-        [
-            [0, 0],
-            [max_width - 1, 0],
-            [0, max_height - 1],
-            [max_width - 1, max_height - 1],
-        ]
-    )
+    dst_points = np.float32([[0, 0], [max_width - 1, 0], [0, max_height - 1], [max_width - 1, max_height - 1]])
 
     # Get the perspective-transform matrix
     M = cv2.getPerspectiveTransform(src_points, dst_points)
     # Apply the perspective transform
     warped = cv2.warpPerspective(img, M, (max_width, max_height))
+
     return warped
+
+
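The reworked extract_table_img is a standard four-point perspective crop. A self-contained illustration of the same technique; the image and corner coordinates here are made up:

```python
# Demo of the four-point perspective crop that extract_table_img implements;
# the corner coordinates are arbitrary, not from a real detection.
import cv2
import numpy as np

img = np.full((400, 600, 3), 255, dtype=np.uint8)  # blank white "page"
lt, rt, lb, rb = (50, 40), (520, 60), (70, 350), (540, 330)

src = np.float32([lt, rt, lb, rb])
# Output size: the longest top/bottom edge and the longest left/right edge.
w = int(max(np.hypot(*np.subtract(rt, lt)), np.hypot(*np.subtract(rb, lb))))
h = int(max(np.hypot(*np.subtract(lb, lt)), np.hypot(*np.subtract(rb, rt))))
dst = np.float32([[0, 0], [w - 1, 0], [0, h - 1], [w - 1, h - 1]])

M = cv2.getPerspectiveTransform(src, dst)    # 3x3 homography
warped = cv2.warpPerspective(img, M, (w, h))  # rectified crop
print(warped.shape)
```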