diff --git a/README.md b/README.md index 160b738..ba19785 100644 --- a/README.md +++ b/README.md @@ -210,7 +210,7 @@ results.vis(save_dir="outputs", save_name="vis") #### πŸ“¦ 终端运葌 ```bash -rapid_table test_images/table.jpg -v +rapid_table https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg -v ``` ### πŸ“ η»“ζžœ diff --git a/rapid_table/main.py b/rapid_table/main.py index 8bd6fd7..0f37b6a 100644 --- a/rapid_table/main.py +++ b/rapid_table/main.py @@ -109,7 +109,7 @@ def get_table_matcher(self, pred_structures, cell_bboxes, dt_boxes, rec_res): def parse_args(arg_list: Optional[List[str]] = None): parser = argparse.ArgumentParser() - parser.add_argument("img_path", type=Path, help="Path to image for layout.") + parser.add_argument("img_path", type=str, help="the image path or URL of the table") parser.add_argument( "-m", "--model_type", @@ -145,8 +145,8 @@ def main(arg_list: Optional[List[str]] = None): print(table_results.pred_html) if args.vis: - save_dir = img_path.resolve().parent - table_results.vis(save_dir, save_name=img_path.stem) + save_dir = Path(img_path).resolve().parent + table_results.vis(save_dir, save_name=Path(img_path).stem) if __name__ == "__main__": diff --git a/rapid_table/table_structure/__init__.py b/rapid_table/table_structure/__init__.py index 7a6c8bd..1d603d8 100644 --- a/rapid_table/table_structure/__init__.py +++ b/rapid_table/table_structure/__init__.py @@ -1,5 +1,4 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com -from .pp_structure import PPTableStructurer -from .unitable import UniTableStructure + diff --git a/tests/test_main.py b/tests/test_main.py index 3d8f046..2c20463 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -20,6 +20,7 @@ test_file_dir = cur_dir / "test_files" img_path = str(test_file_dir / "table.jpg") +img_url = "https://raw.githubusercontent.com/RapidAI/RapidTable/refs/heads/main/tests/test_files/table.jpg" def test_only_table(): @@ -41,9 +42,12 @@ def test_without_txt_table(): @pytest.mark.parametrize( "command, expected_output", - [(f"{img_path} --model_type slanet_plus", 1274)], + [ + (f"{img_path} --model_type slanet_plus", 1274), + (f"{img_url} --model_type slanet_plus", 1274), + ], ) -def test_main(capsys, command, expected_output): +def test_main_cli(capsys, command, expected_output): main(shlex.split(command)) output = capsys.readouterr().out.rstrip() assert len(output) == expected_output