diff --git a/cldk/analysis/java/java.py b/cldk/analysis/java/java.py index 5b1bab1..a258412 100644 --- a/cldk/analysis/java/java.py +++ b/cldk/analysis/java/java.py @@ -24,6 +24,7 @@ from networkx import DiGraph from cldk.analysis import SymbolTable, CallGraph, AnalysisLevel +from cldk.analysis.java.treesitter import JavaSitter from cldk.models.java import JCallable from cldk.models.java import JApplication from cldk.models.java.models import JCompilationUnit, JMethodDetail, JType, JField @@ -175,6 +176,28 @@ def get_class_hierarchy(self) -> DiGraph: raise NotImplementedError(f"Support for this functionality has not been implemented yet.") raise NotImplementedError("Class hierarchy is not implemented yet.") + def is_parsable(self, source_code: str) -> bool: + """ + Check if the code is parsable + Args: + source_code: source code + + Returns: + True if the code is parsable, False otherwise + """ + return JavaSitter.is_parsable(self, source_code) + + def get_raw_ast(self, source_code: str) -> str: + """ + Get the raw AST + Args: + source_code: source code + + Returns: + Tree: the raw AST + """ + return JavaSitter.get_raw_ast(self, source_code) + def get_call_graph(self) -> DiGraph: """Returns the call graph of the Java code. diff --git a/cldk/analysis/java/treesitter/javasitter.py b/cldk/analysis/java/treesitter/javasitter.py index 8206fd8..973db91 100644 --- a/cldk/analysis/java/treesitter/javasitter.py +++ b/cldk/analysis/java/treesitter/javasitter.py @@ -20,7 +20,7 @@ from itertools import groupby from typing import List, Set, Dict - +from tree_sitter import Language, Node, Parser, Query, Tree import tree_sitter_java as tsjava from tree_sitter import Language, Node, Parser, Query @@ -51,10 +51,49 @@ def method_is_not_in_class(self, method_name: str, class_body: str) -> bool: bool True if the method is in the class, False otherwise. 
""" - methods_in_class = self.frame_query_and_capture_output("(method_declaration name: (identifier) @name)", class_body) + methods_in_class = self.frame_query_and_capture_output("(method_declaration name: (identifier) @name)", + class_body) return method_name not in {method.node.text.decode() for method in methods_in_class} + def is_parsable(self, code: str) -> bool: + """ + Check if the code is parsable + Args: + code: source code + + Returns: + True if the code is parsable, False otherwise + """ + + def syntax_error(node): + if node.type == "ERROR": + return True + try: + for child in node.children: + if syntax_error(child): + return True + except RecursionError as err: + return True + + return False + + tree = self.parser.parse(bytes(code, "utf-8")) + if tree is not None: + return not syntax_error(tree.root_node) + return False + + def get_raw_ast(self, code: str) -> Tree: + """ + Get the raw AST + Args: + code: source code + + Returns: + Tree: the raw AST + """ + return self.parser.parse(bytes(code, "utf-8")) + def get_all_imports(self, source_code: str) -> Set[str]: """Get a list of all the imports in a class. @@ -64,7 +103,8 @@ def get_all_imports(self, source_code: str) -> Set[str]: Returns: Set[str]: A set of all the imports in the class. """ - import_declerations: Captures = self.frame_query_and_capture_output(query="(import_declaration (scoped_identifier) @name)", code_to_process=source_code) + import_declerations: Captures = self.frame_query_and_capture_output( + query="(import_declaration (scoped_identifier) @name)", code_to_process=source_code) return {capture.node.text.decode() for capture in import_declerations} def get_pacakge_name(self, source_code: str) -> str: @@ -76,7 +116,8 @@ def get_pacakge_name(self, source_code: str) -> str: Returns: str: The package name. 
""" - package_name: Captures = self.frame_query_and_capture_output(query="((package_declaration) @name)", code_to_process=source_code) + package_name: Captures = self.frame_query_and_capture_output(query="((package_declaration) @name)", + code_to_process=source_code) if package_name: return package_name[0].node.text.decode().replace("package ", "").replace(";", "") return None @@ -102,7 +143,8 @@ def get_superclass(self, source_code: str) -> str: Returns: Set[str]: A set of all the superclasses in the class. """ - superclass: Captures = self.frame_query_and_capture_output(query="(class_declaration (superclass (type_identifier) @superclass))", code_to_process=source_code) + superclass: Captures = self.frame_query_and_capture_output( + query="(class_declaration (superclass (type_identifier) @superclass))", code_to_process=source_code) if len(superclass) == 0: return "" @@ -119,7 +161,9 @@ def get_all_interfaces(self, source_code: str) -> Set[str]: Set[str]: A set of all the interfaces implemented by the class. 
""" - interfaces = self.frame_query_and_capture_output("(class_declaration (super_interfaces (type_list (type_identifier) @interface)))", code_to_process=source_code) + interfaces = self.frame_query_and_capture_output( + "(class_declaration (super_interfaces (type_list (type_identifier) @interface)))", + code_to_process=source_code) return {interface.node.text.decode() for interface in interfaces} def frame_query_and_capture_output(self, query: str, code_to_process: str) -> Captures: @@ -138,7 +182,8 @@ def frame_query_and_capture_output(self, query: str, code_to_process: str) -> Ca def get_method_name_from_declaration(self, method_name_string: str) -> str: """Get the method name from the method signature.""" - captures: Captures = self.frame_query_and_capture_output("(method_declaration name: (identifier) @method_name)", method_name_string) + captures: Captures = self.frame_query_and_capture_output("(method_declaration name: (identifier) @method_name)", + method_name_string) return captures[0].node.text.decode() @@ -147,7 +192,8 @@ def get_method_name_from_invocation(self, method_invocation: str) -> str: Using the tree-sitter query, extract the method name from the method invocation. """ - captures: Captures = self.frame_query_and_capture_output("(method_invocation object: (identifier) @class_name name: (identifier) @method_name)", method_invocation) + captures: Captures = self.frame_query_and_capture_output( + "(method_invocation object: (identifier) @class_name name: (identifier) @method_name)", method_invocation) return captures[0].node.text.decode() def safe_ascend(self, node: Node, ascend_count: int) -> Node: @@ -352,7 +398,8 @@ def get_method_return_type(self, source_code: str) -> str: The return type of the method. 
""" - type_references: Captures = self.frame_query_and_capture_output("(method_declaration type: ((type_identifier) @type_id))", source_code) + type_references: Captures = self.frame_query_and_capture_output( + "(method_declaration type: ((type_identifier) @type_id))", source_code) return type_references[0].node.text.decode() @@ -379,9 +426,9 @@ def collect_leaf_token_values(node): if len(node.children) == 0: if filter_by_node_type is not None: if node.type in filter_by_node_type: - lexical_tokens.append(code[node.start_byte : node.end_byte]) + lexical_tokens.append(code[node.start_byte: node.end_byte]) else: - lexical_tokens.append(code[node.start_byte : node.end_byte]) + lexical_tokens.append(code[node.start_byte: node.end_byte]) else: for child in node.children: collect_leaf_token_values(child) @@ -415,9 +462,11 @@ def remove_all_comments(self, source_code: str) -> str: pruned_source_code = self.make_pruned_code_prettier(source_code) # Remove all comment lines: the comment lines start with / (for // and /*) or * (for multiline comments). - comment_blocks: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)", code_to_process=source_code) + comment_blocks: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)", + code_to_process=source_code) - comment_lines: Captures = self.frame_query_and_capture_output(query="((line_comment) @comment_line)", code_to_process=source_code) + comment_lines: Captures = self.frame_query_and_capture_output(query="((line_comment) @comment_line)", + code_to_process=source_code) for capture in comment_blocks: pruned_source_code = pruned_source_code.replace(capture.node.text.decode(), "") @@ -441,7 +490,8 @@ def make_pruned_code_prettier(self, pruned_code: str) -> str: The prettified pruned code. 
""" # First remove remaining block comments - block_comments: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)", code_to_process=pruned_code) + block_comments: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)", + code_to_process=pruned_code) for capture in block_comments: pruned_code = pruned_code.replace(capture.node.text.decode(), "") diff --git a/cldk/analysis/python/python.py b/cldk/analysis/python/python.py index 5852dae..a0bf36a 100644 --- a/cldk/analysis/python/python.py +++ b/cldk/analysis/python/python.py @@ -87,7 +87,29 @@ def get_method_details(self, method_signature: str) -> PyMethod: """ return self.analysis_backend.get_method_details(self.source_code, method_signature) - def get_imports(self) -> List[PyImport]: + def is_parsable(self, source_code: str) -> bool: + """ + Check if the code is parsable + Args: + source_code: source code + + Returns: + True if the code is parsable, False otherwise + """ + return PythonSitter.is_parsable(self, source_code) + + def get_raw_ast(self, source_code: str) -> str: + """ + Get the raw AST + Args: + code: source code + + Returns: + Tree: the raw AST + """ + return PythonSitter.get_raw_ast(self, source_code) + + def get_imports(self) -> List[PyImport]: """ Given an application or a source code, get all the imports """ diff --git a/cldk/analysis/python/treesitter/python_sitter.py b/cldk/analysis/python/treesitter/python_sitter.py index 62caec1..dad49f0 100644 --- a/cldk/analysis/python/treesitter/python_sitter.py +++ b/cldk/analysis/python/treesitter/python_sitter.py @@ -23,9 +23,8 @@ from pathlib import Path from typing import List -from tree_sitter import Language, Parser, Query, Node +from tree_sitter import Language, Parser, Query, Node, Tree import tree_sitter_python as tspython - from cldk.models.python.models import PyMethod, PyClass, PyArg, PyImport, PyModule, PyCallSite from cldk.models.treesitter import Captures from 
cldk.utils.treesitter.tree_sitter_utils import TreeSitterUtils @@ -41,6 +40,44 @@ def __init__(self) -> None: self.parser: Parser = Parser(self.language) self.utils: TreeSitterUtils = TreeSitterUtils() + def is_parsable(self, code: str) -> bool: + """ + Check if the code is parsable + Args: + code: source code + + Returns: + True if the code is parsable, False otherwise + """ + def syntax_error(node): + if node.type == "ERROR": + return True + try: + for child in node.children: + if syntax_error(child): + return True + except RecursionError as err: + print(err) + return True + + return False + + tree = self.parser.parse(bytes(code, "utf-8")) + if tree is not None: + return not syntax_error(tree.root_node) + return False + + def get_raw_ast(self, code: str) -> Tree: + """ + Get the raw AST + Args: + code: source code + + Returns: + Tree: the raw AST + """ + return self.parser.parse(bytes(code, "utf-8")) + def get_all_methods(self, module: str) -> List[PyMethod]: """ Get all the methods in the specific module. 
diff --git a/cldk/analysis/symbol_table.py b/cldk/analysis/symbol_table.py index 0649251..79104ad 100644 --- a/cldk/analysis/symbol_table.py +++ b/cldk/analysis/symbol_table.py @@ -29,6 +29,13 @@ def __init__(self) -> None: Language agnostic functions """ + @abstractmethod + def is_parsable(self, **kwargs): + """ + Given full source code or a snippet, return whether the code is well-formed and therefore parsable + """ + pass + @abstractmethod def get_methods(self, **kwargs): """ diff --git a/tests/tree_sitter/python/test_python_tree_sitter.py b/tests/tree_sitter/python/test_python_tree_sitter.py index 830d06f..5071ab9 100644 --- a/tests/tree_sitter/python/test_python_tree_sitter.py +++ b/tests/tree_sitter/python/test_python_tree_sitter.py @@ -14,6 +14,19 @@ def setUp(self): def tearDown(self): """Runs after each test case""" + def test_is_parasable(self): + module_str = """ + @staticmethod + def foo() -> None: + pass + class Person: + def __init__(self, name: str, age: int): + self.name = name + self.age = age + @staticmethod + def __str__(self):" + """ + self.assertFalse(self.python_tree_sitter.is_parsable(module_str)) def test_get_all_methods(self): module_str = """