From beb999a024254c935eea5b6967c5c25d8212469c Mon Sep 17 00:00:00 2001 From: Rangeet Pan Date: Wed, 2 Oct 2024 11:26:13 -0400 Subject: [PATCH 1/2] adding support for code parsing --- cldk/analysis/java/java.py | 12 +++++++++ cldk/analysis/java/treesitter/javasitter.py | 26 ++++++++++++++++++ cldk/analysis/python/python.py | 11 ++++++++ .../python/treesitter/python_sitter.py | 27 +++++++++++++++++++ cldk/analysis/symbol_table.py | 7 +++++ .../python/test_python_tree_sitter.py | 13 +++++++++ 6 files changed, 96 insertions(+) diff --git a/cldk/analysis/java/java.py b/cldk/analysis/java/java.py index f743d6b..a36dcdc 100644 --- a/cldk/analysis/java/java.py +++ b/cldk/analysis/java/java.py @@ -4,6 +4,7 @@ from networkx import DiGraph from cldk.analysis import SymbolTable, CallGraph, AnalysisLevel +from cldk.analysis.java.treesitter import JavaSitter from cldk.models.java import JCallable from cldk.models.java import JApplication from cldk.models.java.models import JCompilationUnit, JMethodDetail, JType, JField @@ -143,6 +144,17 @@ def get_class_hierarchy(self) -> DiGraph: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") raise NotImplementedError("Class hierarchy is not implemented yet.") + def is_parsable(self, source_code: str) -> bool: + """ + Check if the code is parsable + Args: + source_code: source code + + Returns: + True if the code is parsable, False otherwise + """ + return JavaSitter.is_parsable(self, source_code) + def get_call_graph(self) -> DiGraph: """ Returns the call graph of the Java code. diff --git a/cldk/analysis/java/treesitter/javasitter.py b/cldk/analysis/java/treesitter/javasitter.py index 403e304..dbd9355 100644 --- a/cldk/analysis/java/treesitter/javasitter.py +++ b/cldk/analysis/java/treesitter/javasitter.py @@ -33,6 +33,32 @@ def method_is_not_in_class(self, method_name: str, class_body: str) -> bool: return method_name not in {method.node.text.decode() for method in methods_in_class} + def is_parsable(self, code: str) -> bool: + """ + Check if the code is parsable + Args: + code: source code + + Returns: + True if the code is parsable, False otherwise + """ + def syntax_error(node): + if node.type == "ERROR": + return True + try: + for child in node.children: + if syntax_error(child): + return True + except RecursionError as err: + return True + + return False + + tree = self.parser.parse(bytes(code, "utf-8")) + if tree is not None: + return not syntax_error(tree.root_node) + return False + def get_all_imports(self, source_code: str) -> Set[str]: """Get a list of all the imports in a class. diff --git a/cldk/analysis/python/python.py b/cldk/analysis/python/python.py index a38f2b5..7a8160b 100644 --- a/cldk/analysis/python/python.py +++ b/cldk/analysis/python/python.py @@ -67,6 +67,17 @@ def get_method_details(self, method_signature: str) -> PyMethod: """ return self.analysis_backend.get_method_details(self.source_code, method_signature) + def is_parsable(self, source_code: str) -> bool: + """ + Check if the code is parsable + Args: + source_code: source code + + Returns: + True if the code is parsable, False otherwise + """ + return PythonSitter.is_parsable(self, source_code) + def get_imports(self) -> List[PyImport]: """ Given an application or a source code, get all the imports diff --git a/cldk/analysis/python/treesitter/python_sitter.py b/cldk/analysis/python/treesitter/python_sitter.py index df5dcc3..9afd9c6 100644 --- a/cldk/analysis/python/treesitter/python_sitter.py +++ b/cldk/analysis/python/treesitter/python_sitter.py @@ -21,6 +21,33 @@ def __init__(self) -> None: self.parser: Parser = Parser(self.language) self.utils: TreeSitterUtils = TreeSitterUtils() + def is_parsable(self, code: str) -> bool: + """ + Check if the code is parsable + Args: + code: source code + + Returns: + True if the code is parsable, False otherwise + """ + def syntax_error(node): + if node.type == "ERROR": + return True + try: + for child in node.children: + if syntax_error(child): + return True + except RecursionError as err: + print(err) + return True + + return False + + tree = self.parser.parse(bytes(code, "utf-8")) + if tree is not None: + return not syntax_error(tree.root_node) + return False + def get_all_methods(self, module: str) -> List[PyMethod]: """ Get all the methods in the specific module. diff --git a/cldk/analysis/symbol_table.py b/cldk/analysis/symbol_table.py index d26e9cf..b32af3c 100644 --- a/cldk/analysis/symbol_table.py +++ b/cldk/analysis/symbol_table.py @@ -9,6 +9,13 @@ def __init__(self) -> None: Language agnostic functions ''' + @abstractmethod + def is_parsable(self, **kwargs): + """ + Given a full code or a snippet, returns whether code is in right structure or hence parsable + """ + pass + @abstractmethod def get_methods(self, **kwargs): """ diff --git a/tests/tree_sitter/python/test_python_tree_sitter.py b/tests/tree_sitter/python/test_python_tree_sitter.py index 830d06f..5071ab9 100644 --- a/tests/tree_sitter/python/test_python_tree_sitter.py +++ b/tests/tree_sitter/python/test_python_tree_sitter.py @@ -14,6 +14,19 @@ def setUp(self): def tearDown(self): """Runs after each test case""" + def test_is_parasable(self): + module_str = """ + @staticmethod + def foo() -> None: + pass + class Person: + def __init__(self, name: str, age: int): + self.name = name + self.age = age + @staticmethod + def __str__(self):" + """ + self.assertFalse(self.python_tree_sitter.is_parsable(module_str)) def test_get_all_methods(self): module_str = """ From 6d74330bbbab66f8016d5d9e58d4433584c99c70 Mon Sep 17 00:00:00 2001 From: Rangeet Pan Date: Wed, 16 Oct 2024 13:38:51 -0400 Subject: [PATCH 2/2] changes --- cldk/analysis/java/java.py | 21 ++++++-- cldk/analysis/java/treesitter/javasitter.py | 52 ++++++++++++++----- cldk/analysis/python/python.py | 11 ++++ .../python/treesitter/python_sitter.py | 14 ++++- 4 files changed, 77 insertions(+), 21 deletions(-) diff --git a/cldk/analysis/java/java.py b/cldk/analysis/java/java.py index a36dcdc..ec9630a 100644 --- a/cldk/analysis/java/java.py +++ b/cldk/analysis/java/java.py @@ -146,15 +146,26 @@ def get_class_hierarchy(self) -> DiGraph: def is_parsable(self, source_code: str) -> bool: """ - Check if the code is parsable - Args: - source_code: source code + Check if the code is parsable + Args: + source_code: source code - Returns: - True if the code is parsable, False otherwise + Returns: + True if the code is parsable, False otherwise """ return JavaSitter.is_parsable(self, source_code) + def get_raw_ast(self, source_code: str) -> str: + """ + Get the raw AST + Args: + code: source code + + Returns: + Tree: the raw AST + """ + return JavaSitter.get_raw_ast(self, source_code) + def get_call_graph(self) -> DiGraph: """ Returns the call graph of the Java code. diff --git a/cldk/analysis/java/treesitter/javasitter.py b/cldk/analysis/java/treesitter/javasitter.py index dbd9355..a52f735 100644 --- a/cldk/analysis/java/treesitter/javasitter.py +++ b/cldk/analysis/java/treesitter/javasitter.py @@ -1,6 +1,6 @@ from itertools import groupby from typing import List, Set, Dict -from tree_sitter import Language, Node, Parser, Query +from tree_sitter import Language, Node, Parser, Query, Tree import tree_sitter_java as tsjava from cldk.models.treesitter import Captures @@ -29,7 +29,8 @@ def method_is_not_in_class(self, method_name: str, class_body: str) -> bool: bool True if the method is in the class, False otherwise. """ - methods_in_class = self.frame_query_and_capture_output("(method_declaration name: (identifier) @name)", class_body) + methods_in_class = self.frame_query_and_capture_output("(method_declaration name: (identifier) @name)", + class_body) return method_name not in {method.node.text.decode() for method in methods_in_class} @@ -42,6 +43,7 @@ def is_parsable(self, code: str) -> bool: Returns: True if the code is parsable, False otherwise """ + def syntax_error(node): if node.type == "ERROR": return True @@ -59,6 +61,17 @@ def syntax_error(node): return not syntax_error(tree.root_node) return False + def get_raw_ast(self, code: str) -> Tree: + """ + Get the raw AST + Args: + code: source code + + Returns: + Tree: the raw AST + """ + return self.parser.parse(bytes(code, "utf-8")) + def get_all_imports(self, source_code: str) -> Set[str]: """Get a list of all the imports in a class. @@ -68,7 +81,8 @@ def get_all_imports(self, source_code: str) -> Set[str]: Returns: Set[str]: A set of all the imports in the class. """ - import_declerations: Captures = self.frame_query_and_capture_output(query="(import_declaration (scoped_identifier) @name)", code_to_process=source_code) + import_declerations: Captures = self.frame_query_and_capture_output( + query="(import_declaration (scoped_identifier) @name)", code_to_process=source_code) return {capture.node.text.decode() for capture in import_declerations} def get_pacakge_name(self, source_code: str) -> str: @@ -80,7 +94,8 @@ def get_pacakge_name(self, source_code: str) -> str: Returns: str: The package name. """ - package_name: Captures = self.frame_query_and_capture_output(query="((package_declaration) @name)", code_to_process=source_code) + package_name: Captures = self.frame_query_and_capture_output(query="((package_declaration) @name)", + code_to_process=source_code) if package_name: return package_name[0].node.text.decode().replace("package ", "").replace(";", "") return None @@ -106,7 +121,8 @@ def get_superclass(self, source_code: str) -> str: Returns: Set[str]: A set of all the superclasses in the class. """ - superclass: Captures = self.frame_query_and_capture_output(query="(class_declaration (superclass (type_identifier) @superclass))", code_to_process=source_code) + superclass: Captures = self.frame_query_and_capture_output( + query="(class_declaration (superclass (type_identifier) @superclass))", code_to_process=source_code) if len(superclass) == 0: return "" @@ -123,7 +139,9 @@ def get_all_interfaces(self, source_code: str) -> Set[str]: Set[str]: A set of all the interfaces implemented by the class. """ - interfaces = self.frame_query_and_capture_output("(class_declaration (super_interfaces (type_list (type_identifier) @interface)))", code_to_process=source_code) + interfaces = self.frame_query_and_capture_output( + "(class_declaration (super_interfaces (type_list (type_identifier) @interface)))", + code_to_process=source_code) return {interface.node.text.decode() for interface in interfaces} def frame_query_and_capture_output(self, query: str, code_to_process: str) -> Captures: @@ -142,7 +160,8 @@ def frame_query_and_capture_output(self, query: str, code_to_process: str) -> Ca def get_method_name_from_declaration(self, method_name_string: str) -> str: """Get the method name from the method signature.""" - captures: Captures = self.frame_query_and_capture_output("(method_declaration name: (identifier) @method_name)", method_name_string) + captures: Captures = self.frame_query_and_capture_output("(method_declaration name: (identifier) @method_name)", + method_name_string) return captures[0].node.text.decode() @@ -151,7 +170,8 @@ def get_method_name_from_invocation(self, method_invocation: str) -> str: Using the tree-sitter query, extract the method name from the method invocation. """ - captures: Captures = self.frame_query_and_capture_output("(method_invocation object: (identifier) @class_name name: (identifier) @method_name)", method_invocation) + captures: Captures = self.frame_query_and_capture_output( + "(method_invocation object: (identifier) @class_name name: (identifier) @method_name)", method_invocation) return captures[0].node.text.decode() def safe_ascend(self, node: Node, ascend_count: int) -> Node: @@ -356,7 +376,8 @@ def get_method_return_type(self, source_code: str) -> str: The return type of the method. """ - type_references: Captures = self.frame_query_and_capture_output("(method_declaration type: ((type_identifier) @type_id))", source_code) + type_references: Captures = self.frame_query_and_capture_output( + "(method_declaration type: ((type_identifier) @type_id))", source_code) return type_references[0].node.text.decode() @@ -383,9 +404,9 @@ def collect_leaf_token_values(node): if len(node.children) == 0: if filter_by_node_type is not None: if node.type in filter_by_node_type: - lexical_tokens.append(code[node.start_byte : node.end_byte]) + lexical_tokens.append(code[node.start_byte: node.end_byte]) else: - lexical_tokens.append(code[node.start_byte : node.end_byte]) + lexical_tokens.append(code[node.start_byte: node.end_byte]) else: for child in node.children: collect_leaf_token_values(child) @@ -419,9 +440,11 @@ def remove_all_comments(self, source_code: str) -> str: pruned_source_code = self.make_pruned_code_prettier(source_code) # Remove all comment lines: the comment lines start with / (for // and /*) or * (for multiline comments). - comment_blocks: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)", code_to_process=source_code) + comment_blocks: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)", + code_to_process=source_code) - comment_lines: Captures = self.frame_query_and_capture_output(query="((line_comment) @comment_line)", code_to_process=source_code) + comment_lines: Captures = self.frame_query_and_capture_output(query="((line_comment) @comment_line)", + code_to_process=source_code) for capture in comment_blocks: pruned_source_code = pruned_source_code.replace(capture.node.text.decode(), "") @@ -445,7 +468,8 @@ def make_pruned_code_prettier(self, pruned_code: str) -> str: The prettified pruned code. """ # First remove remaining block comments - block_comments: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)", code_to_process=pruned_code) + block_comments: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)", + code_to_process=pruned_code) for capture in block_comments: pruned_code = pruned_code.replace(capture.node.text.decode(), "") diff --git a/cldk/analysis/python/python.py b/cldk/analysis/python/python.py index 7a8160b..f0c05fd 100644 --- a/cldk/analysis/python/python.py +++ b/cldk/analysis/python/python.py @@ -78,6 +78,17 @@ def is_parsable(self, source_code: str) -> bool: """ return PythonSitter.is_parsable(self, source_code) + def get_raw_ast(self, source_code: str) -> str: + """ + Get the raw AST + Args: + code: source code + + Returns: + Tree: the raw AST + """ + return PythonSitter.get_raw_ast(self, source_code) + def get_imports(self) -> List[PyImport]: """ Given an application or a source code, get all the imports diff --git a/cldk/analysis/python/treesitter/python_sitter.py b/cldk/analysis/python/treesitter/python_sitter.py index 9afd9c6..ff156a6 100644 --- a/cldk/analysis/python/treesitter/python_sitter.py +++ b/cldk/analysis/python/treesitter/python_sitter.py @@ -3,9 +3,8 @@ from pathlib import Path from typing import List -from tree_sitter import Language, Parser, Query, Node +from tree_sitter import Language, Parser, Query, Node, Tree import tree_sitter_python as tspython - from cldk.models.python.models import PyMethod, PyClass, PyArg, PyImport, PyModule, PyCallSite from cldk.models.treesitter import Captures from cldk.utils.treesitter.tree_sitter_utils import TreeSitterUtils @@ -48,6 +47,17 @@ def syntax_error(node): return not syntax_error(tree.root_node) return False + def get_raw_ast(self, code: str) -> Tree: + """ + Get the raw AST + Args: + code: source code + + Returns: + Tree: the raw AST + """ + return self.parser.parse(bytes(code, "utf-8")) + def get_all_methods(self, module: str) -> List[PyMethod]: """ Get all the methods in the specific module.