-
Notifications
You must be signed in to change notification settings - Fork 76
/
Copy pathpackage_extractor.py
117 lines (101 loc) · 3.99 KB
/
package_extractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import structlog
import tree_sitter_go as tsgo
import tree_sitter_java as tsjava
import tree_sitter_javascript as tsjavascript
import tree_sitter_python as tspython
import tree_sitter_rust as tsrust
from tree_sitter import Language, Parser
# Module-level structlog logger, obtained with the "codegate" argument.
logger = structlog.get_logger("codegate")
class PackageExtractor:
    """Extract the names of imported packages from a source snippet.

    Grammars, parsers and import queries are built once as class attributes
    and looked up by lower-cased language name. The languages with an entry
    are: javascript, go, python, java and rust.
    """

    # One tree-sitter Language object per supported language name.
    __languages = {
        "javascript": Language(tsjavascript.language()),
        "go": Language(tsgo.language()),
        "python": Language(tspython.language()),
        "java": Language(tsjava.language()),
        "rust": Language(tsrust.language()),
    }

    # One Parser per language, constructed from the Language objects above.
    __parsers = {
        "javascript": Parser(__languages["javascript"]),
        "go": Parser(__languages["go"]),
        "python": Parser(__languages["python"]),
        "java": Parser(__languages["java"]),
        "rust": Parser(__languages["rust"]),
    }

    # Query source per language. Every pattern tags the node that holds the
    # imported name with the @import_name capture; the javascript query also
    # uses a helper capture, @require, for its `#eq?` predicate.
    __queries = {
        "javascript": """
(import_statement
source: (string) @import_name)
(call_expression
function: (identifier) @require
arguments: (arguments (string) @import_name)
(#eq? @require "require")
)
""",
        "go": """
(import_declaration
(import_spec
(interpreted_string_literal) @import_name
)
)
(import_declaration
(import_spec_list
(import_spec
(interpreted_string_literal) @import_name
)
)
)
""",
        "python": """
(import_statement
name: (dotted_name) @import_name)
(import_from_statement
module_name: (dotted_name) @import_name)
(import_statement
(aliased_import (dotted_name) @import_name (identifier)))
""",
        "java": """
(import_declaration
(scoped_identifier) @import_name)
""",
        "rust": """
(use_declaration
(scoped_identifier) @import_name)
(use_declaration
(identifier) @import_name)
(use_declaration
(use_wildcard) @import_name)
(use_declaration
(use_as_clause (scoped_identifier) @import_name))
""",
    }

    @staticmethod
    def extract_packages(code: str, language_name: str) -> list[str]:
        """Return the distinct package names imported by ``code``.

        Args:
            code: Source text to analyse. ``None`` yields an empty list.
            language_name: Language of ``code``; matched case-insensitively
                against the supported names. ``None`` or a name without an
                entry yields an empty list.

        Returns:
            The unique import names captured by the language's query, with
            surrounding single/double quotes stripped. For python only the
            text before the first ``.`` is kept, and for rust only the text
            before the first ``::``. The list is built from a set, so its
            order is not defined.
        """
        if code is None or language_name is None:
            return []

        language_name = language_name.lower()
        if language_name not in PackageExtractor.__languages:
            return []

        language = PackageExtractor.__languages[language_name]
        parser = PackageExtractor.__parsers[language_name]

        # Encode once and keep the buffer: the parser consumes bytes, and the
        # nodes' start_byte/end_byte index into this buffer, not into `code`.
        source = code.encode("utf8")
        tree = parser.parse(source)

        # Build the import query for this language and run it over the tree.
        query = language.query(PackageExtractor.__queries[language_name])
        all_captures = query.captures(tree.root_node)

        imports = set()
        # Only @import_name nodes carry a package name; other captures
        # (e.g. @require in the javascript query) are ignored.
        for node in all_captures.get("import_name", []):
            # Slice the byte buffer, not the str: with non-ASCII text before
            # the import, byte offsets and character indices diverge.
            # errors="replace" keeps this best-effort should a node boundary
            # ever fall inside a multi-byte sequence.
            raw_name = source[node.start_byte : node.end_byte]
            import_lib = raw_name.decode("utf8", errors="replace")

            # Remove quotes from the import string
            import_lib = import_lib.strip("'\"")

            # Get the root library name
            if language_name == "python":
                import_lib = import_lib.split(".")[0]
            if language_name == "rust":
                import_lib = import_lib.split("::")[0]

            imports.add(import_lib)

        return list(imports)