Skip to content

Adding more support for code parsing #53

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions cldk/analysis/java/java.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from networkx import DiGraph

from cldk.analysis import SymbolTable, CallGraph, AnalysisLevel
from cldk.analysis.java.treesitter import JavaSitter
from cldk.models.java import JCallable
from cldk.models.java import JApplication
from cldk.models.java.models import JCompilationUnit, JMethodDetail, JType, JField
Expand Down Expand Up @@ -175,6 +176,28 @@ def get_class_hierarchy(self) -> DiGraph:
raise NotImplementedError(f"Support for this functionality has not been implemented yet.")
raise NotImplementedError("Class hierarchy is not implemented yet.")

def is_parsable(self, source_code: str) -> bool:
"""
Check if the code is parsable
Args:
source_code: source code

Returns:
True if the code is parsable, False otherwise
"""
return JavaSitter.is_parsable(self, source_code)

def get_raw_ast(self, source_code: str) -> str:
"""
Get the raw AST
Args:
code: source code

Returns:
Tree: the raw AST
"""
return JavaSitter.get_raw_ast(self, source_code)

def get_call_graph(self) -> DiGraph:
"""Returns the call graph of the Java code.

Expand Down
78 changes: 64 additions & 14 deletions cldk/analysis/java/treesitter/javasitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from itertools import groupby
from typing import List, Set, Dict

from tree_sitter import Language, Node, Parser, Query, Tree
import tree_sitter_java as tsjava
from tree_sitter import Language, Node, Parser, Query

Expand Down Expand Up @@ -51,10 +51,49 @@ def method_is_not_in_class(self, method_name: str, class_body: str) -> bool:
bool
True if the method is in the class, False otherwise.
"""
methods_in_class = self.frame_query_and_capture_output("(method_declaration name: (identifier) @name)", class_body)
methods_in_class = self.frame_query_and_capture_output("(method_declaration name: (identifier) @name)",
class_body)

return method_name not in {method.node.text.decode() for method in methods_in_class}

def is_parsable(self, code: str) -> bool:
"""
Check if the code is parsable
Args:
code: source code

Returns:
True if the code is parsable, False otherwise
"""

def syntax_error(node):
if node.type == "ERROR":
return True
try:
for child in node.children:
if syntax_error(child):
return True
except RecursionError as err:
return True

return False

tree = self.parser.parse(bytes(code, "utf-8"))
if tree is not None:
return not syntax_error(tree.root_node)
return False

def get_raw_ast(self, code: str) -> Tree:
"""
Get the raw AST
Args:
code: source code

Returns:
Tree: the raw AST
"""
return self.parser.parse(bytes(code, "utf-8"))

def get_all_imports(self, source_code: str) -> Set[str]:
"""Get a list of all the imports in a class.

Expand All @@ -64,7 +103,8 @@ def get_all_imports(self, source_code: str) -> Set[str]:
Returns:
Set[str]: A set of all the imports in the class.
"""
import_declerations: Captures = self.frame_query_and_capture_output(query="(import_declaration (scoped_identifier) @name)", code_to_process=source_code)
import_declerations: Captures = self.frame_query_and_capture_output(
query="(import_declaration (scoped_identifier) @name)", code_to_process=source_code)
return {capture.node.text.decode() for capture in import_declerations}

def get_pacakge_name(self, source_code: str) -> str:
Expand All @@ -76,7 +116,8 @@ def get_pacakge_name(self, source_code: str) -> str:
Returns:
str: The package name.
"""
package_name: Captures = self.frame_query_and_capture_output(query="((package_declaration) @name)", code_to_process=source_code)
package_name: Captures = self.frame_query_and_capture_output(query="((package_declaration) @name)",
code_to_process=source_code)
if package_name:
return package_name[0].node.text.decode().replace("package ", "").replace(";", "")
return None
Expand All @@ -102,7 +143,8 @@ def get_superclass(self, source_code: str) -> str:
Returns:
Set[str]: A set of all the superclasses in the class.
"""
superclass: Captures = self.frame_query_and_capture_output(query="(class_declaration (superclass (type_identifier) @superclass))", code_to_process=source_code)
superclass: Captures = self.frame_query_and_capture_output(
query="(class_declaration (superclass (type_identifier) @superclass))", code_to_process=source_code)

if len(superclass) == 0:
return ""
Expand All @@ -119,7 +161,9 @@ def get_all_interfaces(self, source_code: str) -> Set[str]:
Set[str]: A set of all the interfaces implemented by the class.
"""

interfaces = self.frame_query_and_capture_output("(class_declaration (super_interfaces (type_list (type_identifier) @interface)))", code_to_process=source_code)
interfaces = self.frame_query_and_capture_output(
"(class_declaration (super_interfaces (type_list (type_identifier) @interface)))",
code_to_process=source_code)
return {interface.node.text.decode() for interface in interfaces}

def frame_query_and_capture_output(self, query: str, code_to_process: str) -> Captures:
Expand All @@ -138,7 +182,8 @@ def frame_query_and_capture_output(self, query: str, code_to_process: str) -> Ca

def get_method_name_from_declaration(self, method_name_string: str) -> str:
"""Get the method name from the method signature."""
captures: Captures = self.frame_query_and_capture_output("(method_declaration name: (identifier) @method_name)", method_name_string)
captures: Captures = self.frame_query_and_capture_output("(method_declaration name: (identifier) @method_name)",
method_name_string)

return captures[0].node.text.decode()

Expand All @@ -147,7 +192,8 @@ def get_method_name_from_invocation(self, method_invocation: str) -> str:
Using the tree-sitter query, extract the method name from the method invocation.
"""

captures: Captures = self.frame_query_and_capture_output("(method_invocation object: (identifier) @class_name name: (identifier) @method_name)", method_invocation)
captures: Captures = self.frame_query_and_capture_output(
"(method_invocation object: (identifier) @class_name name: (identifier) @method_name)", method_invocation)
return captures[0].node.text.decode()

def safe_ascend(self, node: Node, ascend_count: int) -> Node:
Expand Down Expand Up @@ -352,7 +398,8 @@ def get_method_return_type(self, source_code: str) -> str:
The return type of the method.
"""

type_references: Captures = self.frame_query_and_capture_output("(method_declaration type: ((type_identifier) @type_id))", source_code)
type_references: Captures = self.frame_query_and_capture_output(
"(method_declaration type: ((type_identifier) @type_id))", source_code)

return type_references[0].node.text.decode()

Expand All @@ -379,9 +426,9 @@ def collect_leaf_token_values(node):
if len(node.children) == 0:
if filter_by_node_type is not None:
if node.type in filter_by_node_type:
lexical_tokens.append(code[node.start_byte : node.end_byte])
lexical_tokens.append(code[node.start_byte: node.end_byte])
else:
lexical_tokens.append(code[node.start_byte : node.end_byte])
lexical_tokens.append(code[node.start_byte: node.end_byte])
else:
for child in node.children:
collect_leaf_token_values(child)
Expand Down Expand Up @@ -415,9 +462,11 @@ def remove_all_comments(self, source_code: str) -> str:
pruned_source_code = self.make_pruned_code_prettier(source_code)

# Remove all comment lines: the comment lines start with / (for // and /*) or * (for multiline comments).
comment_blocks: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)", code_to_process=source_code)
comment_blocks: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)",
code_to_process=source_code)

comment_lines: Captures = self.frame_query_and_capture_output(query="((line_comment) @comment_line)", code_to_process=source_code)
comment_lines: Captures = self.frame_query_and_capture_output(query="((line_comment) @comment_line)",
code_to_process=source_code)

for capture in comment_blocks:
pruned_source_code = pruned_source_code.replace(capture.node.text.decode(), "")
Expand All @@ -441,7 +490,8 @@ def make_pruned_code_prettier(self, pruned_code: str) -> str:
The prettified pruned code.
"""
# First remove remaining block comments
block_comments: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)", code_to_process=pruned_code)
block_comments: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)",
code_to_process=pruned_code)

for capture in block_comments:
pruned_code = pruned_code.replace(capture.node.text.decode(), "")
Expand Down
24 changes: 23 additions & 1 deletion cldk/analysis/python/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,29 @@ def get_method_details(self, method_signature: str) -> PyMethod:
"""
return self.analysis_backend.get_method_details(self.source_code, method_signature)

def get_imports(self) -> List[PyImport]:
def is_parsable(self, source_code: str) -> bool:
"""
Check if the code is parsable
Args:
source_code: source code

Returns:
True if the code is parsable, False otherwise
"""
return PythonSitter.is_parsable(self, source_code)

def get_raw_ast(self, source_code: str) -> str:
"""
Get the raw AST
Args:
code: source code

Returns:
Tree: the raw AST
"""
return PythonSitter.get_raw_ast(self, source_code)

def get_imports(self) -> List[PyImport]:
"""
Given an application or a source code, get all the imports
"""
Expand Down
41 changes: 39 additions & 2 deletions cldk/analysis/python/treesitter/python_sitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@
from pathlib import Path
from typing import List

from tree_sitter import Language, Parser, Query, Node
from tree_sitter import Language, Parser, Query, Node, Tree
import tree_sitter_python as tspython

from cldk.models.python.models import PyMethod, PyClass, PyArg, PyImport, PyModule, PyCallSite
from cldk.models.treesitter import Captures
from cldk.utils.treesitter.tree_sitter_utils import TreeSitterUtils
Expand All @@ -41,6 +40,44 @@ def __init__(self) -> None:
self.parser: Parser = Parser(self.language)
self.utils: TreeSitterUtils = TreeSitterUtils()

def is_parsable(self, code: str) -> bool:
"""
Check if the code is parsable
Args:
code: source code

Returns:
True if the code is parsable, False otherwise
"""
def syntax_error(node):
if node.type == "ERROR":
return True
try:
for child in node.children:
if syntax_error(child):
return True
except RecursionError as err:
print(err)
return True

return False

tree = self.parser.parse(bytes(code, "utf-8"))
if tree is not None:
return not syntax_error(tree.root_node)
return False

def get_raw_ast(self, code: str) -> Tree:
"""
Get the raw AST
Args:
code: source code

Returns:
Tree: the raw AST
"""
return self.parser.parse(bytes(code, "utf-8"))

def get_all_methods(self, module: str) -> List[PyMethod]:
"""
Get all the methods in the specific module.
Expand Down
7 changes: 7 additions & 0 deletions cldk/analysis/symbol_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ def __init__(self) -> None:
Language agnostic functions
"""

@abstractmethod
def is_parsable(self, **kwargs):
"""
Given a full code or a snippet, returns whether code is in right structure or hence parsable
"""
pass

@abstractmethod
def get_methods(self, **kwargs):
"""
Expand Down
13 changes: 13 additions & 0 deletions tests/tree_sitter/python/test_python_tree_sitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@ def setUp(self):

def tearDown(self):
"""Runs after each test case"""
def test_is_parasable(self):
module_str = """
@staticmethod
def foo() -> None:
pass
class Person:
def __init__(self, name: str, age: int):
self.name = name
self.age = age
@staticmethod
def __str__(self):"
"""
self.assertFalse(self.python_tree_sitter.is_parsable(module_str))

def test_get_all_methods(self):
module_str = """
Expand Down