diff --git a/CONTRIBUTE.md b/CONTRIBUTE.md new file mode 100644 index 0000000..b64eede --- /dev/null +++ b/CONTRIBUTE.md @@ -0,0 +1,26 @@ +# Contributing to Codellm-DevKit + +You can report issues or open a pull request (PR) to suggest changes. + +## Reporting an issue + +To report an issue, or to suggest an idea for a change that you haven't +had time to write up yet: +1. [Review existing issues](https://github.com/IBM/codellm-devkit/issues) to see if a similar issue has been opened or discussed. +2. [Open an +issue](https://github.com/IBM/codellm-devkit/issues/new). Be sure to include any helpful information, such as your environment details, error messages, or logs that you might have. + + +## Suggesting a change + +To suggest a change to this repository, [submit a pull request](https://github.com/IBM/codellm-devkit/pulls) with the complete set of changes that you want to suggest. Before creating a PR, make sure that your changes pass all of the tests. + +The test suite can be executed with the following command in the top-level folder: +``` +pytest +``` + +Also, please make sure that your changes pass static checks such as code styles by executing the following command: +``` +pre-commit run --all-files +``` diff --git a/README.md b/README.md index 2cc13f0..a001623 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,69 @@ -# CodeLLM-devkit: A Python library for seamless interaction with CodeLLMs +# CodeLLM-Devkit: A Python library for seamless interaction with CodeLLMs -![image](cldk.png) -codellm-devkit provides unified language to get off-the-shelf static analysis for multiple programming languages and support for applying those analyses for code LLM use cases. 
+![image](./docs/assets/cldk.png) +[![Python 3.11](https://img.shields.io/badge/python-3.11-blue.svg)](https://www.python.org/downloads/release/python-3110/) + +## Prerequisites + +- Python 3.11+ +- Poetry (see [doc](https://python-poetry.org/docs/)) + +## Installation + +Obtain Codellm-DevKit from below: + +```bash +git clone git@github.com:IBM/codellm-devkit.git /path/to/cloned/repo +``` + +Install CodeLLM-Devkit + +```bash +pip install -U /path/to/cloned/repo +``` + +## Usage + +### 1. Obtain sample application to experiment with (we'll use Daytrader 8 as an example) + +```bash +wget https://github.com/OpenLiberty/sample.daytrader8/archive/refs/tags/v1.2.tar.gz +``` + +Extract the archive and navigate to the `daytrader8` directory: + +```bash +tar -xvf v1.2.tar.gz +cd sample.daytrader8-1.2 +``` + +Save the location to where daytrader8 is extracted to, as we will need it later: + +```bash +export DAYTRADER8_DIR=/path/to/sample.daytrader8-1.2 +``` + +Then you can use the following command to run the codeanalyzer: + +```python +import os +from rich import print # Optional, for pretty printing. +from cldk import CLDK +from cldk.models.java.models import * + +# Initialize the Codellm-DevKit object with the project directory, language, and backend. +ns = CLDK( + project_dir=os.getenv("DAYTRADER8_DIR"), # Change this to the path of the project you want to analyze. + language="java", # The language of the project. + backend="codeanalyzer", # The backend to use for the analysis. + analysis_db="/tmp", # A temporary directory to store the analysis results. + sdg=True, # Generate the System Dependence Graph (SDG) for the project. +) + +# Get the java application view for the project. The application view is a representation of the project as a graph with all the classes, methods, and fields. +app: JApplication = ns.preprocessing.get_application_view() + +# Get all the classes in the project. 
+classes_dict = ns.preprocessing.get_all_classes() +print(classes_dict) +``` \ No newline at end of file diff --git a/cldk/__init__.py b/cldk/__init__.py new file mode 100644 index 0000000..4829ad0 --- /dev/null +++ b/cldk/__init__.py @@ -0,0 +1,3 @@ +from .core import CLDK + +__all__ = ["CLDK"] diff --git a/cldk/analysis/__init__.py b/cldk/analysis/__init__.py new file mode 100644 index 0000000..17b84c9 --- /dev/null +++ b/cldk/analysis/__init__.py @@ -0,0 +1,7 @@ + +from .call_graph import CallGraph +from .program_dependence_graph import ProgramDependenceGraph +from .system_dependence_graph import SystemDependenceGraph +from .symbol_table import SymbolTable + +__all__ = ["CallGraph", "ProgramDependenceGraph", "SystemDependenceGraph", "SymbolTable"] diff --git a/cldk/analysis/c/__init__.py b/cldk/analysis/c/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cldk/analysis/c/treesitter/__init__.py b/cldk/analysis/c/treesitter/__init__.py new file mode 100644 index 0000000..036c721 --- /dev/null +++ b/cldk/analysis/c/treesitter/__init__.py @@ -0,0 +1,3 @@ +from cldk.analysis.c.treesitter.c_sitter import CSitter + +__all__ = ["CSitter"] diff --git a/cldk/analysis/c/treesitter/c_sitter.py b/cldk/analysis/c/treesitter/c_sitter.py new file mode 100644 index 0000000..e3954fd --- /dev/null +++ b/cldk/analysis/c/treesitter/c_sitter.py @@ -0,0 +1,510 @@ +from typing import List +from tree_sitter import Language, Parser, Query, Node +import tree_sitter_c as tsc + +from cldk.models.c.models import CFunction, CImport, CParameter, CTranslationUnit, COutput +from cldk.models.treesitter import Captures + + +class CSitter: + """ + Tree sitter for C use cases. + """ + + def __init__(self) -> None: + self.language: Language = Language(tsc.language()) + self.parser: Parser = Parser(self.language) + + def get_all_functions(self, code: str) -> List[CFunction]: + """ + Get all the functions in the provided code. 
+ + Parameters + ---------- + code: the code you want to analyse. + + Returns + ------- + List[CFunction] + returns all the function details within the provided code. + """ + + return [self.__get_function_details(code, capture.node) for capture in self.__get_function_nodes(code)] + + def get_imports(self, code: str) -> List[CImport]: + """ + Get all the imports in the provided code. + + Parameters + ---------- + code: the code you want to analyse. + + Returns + ------- + List[CImport] + returns all the imports within the provided code. + """ + + query = """(preproc_include) @import""" + captures: Captures = self.__frame_query_and_capture_output(query, code) + imports: List[CImport] = [] + for capture in captures: + path_node: Node = capture.node.child_by_field_name("path") + text: str = path_node.text.decode() + if path_node.type == "system_lib_string": + imports.append(CImport(value=text[1 : len(text) - 1], is_system=True)) + elif path_node.type == "string_literal": + imports.append(CImport(value=text[1 : len(text) - 1], is_system=False)) + else: + imports.append(CImport(value=text, is_system=False)) + + return imports + + def get_translation_unit_details(self, code: str) -> CTranslationUnit: + """ + Given the code of a C translation unit, return the details. + + Parameters + ---------- + code : str + The source code of the translation unit. + + Returns + ------- + CTranslationUnit + The details of the given translation unit. + """ + + return CTranslationUnit( + functions=self.get_all_functions(code), + imports=self.get_imports(code), + ) + + def __get_function_details(self, original_code: str, node: Node) -> CFunction: + """ + Extract the details of a function from a tree-sitter node. + + Parameters + ---------- + original_code : str + The original code, used to extract the tree-sitter node. + node : Node + The function tree-sitter node we want to evaluate. + + Returns + ------- + CFunction + The extracted details of the function. 
+ """ + + nb_pointers = self.__count_pointers(node.child_by_field_name("declarator")) + return_type: str = self.__get_function_return_type(node) + if return_type != "function": + return_type = return_type + nb_pointers * "*" + + output: COutput = COutput( + type=return_type, + is_reference=return_type == "function" or nb_pointers > 0, + qualifiers=self.__get_type_qualifiers(node), + ) + + return CFunction( + name=self.__get_function_name(node), + code=node.text.decode(), + start_line=node.start_point[0], + end_line=node.end_point[0], + signature=self.__get_function_signature(original_code, node), + parameters=self.__get_function_parameters(node), + output=output, + comment=self.__get_comment(node), + specifiers=self.__get_storage_class_specifiers(node), + ) + + def __get_function_parameters(self, function_node: Node) -> List[CParameter]: + """ + Extract the parameters of a tree-sitter function node. + + Parameters + ---------- + function_node : Node + The function node whose parameters we want to extract. + + Returns + ------- + List[CParameter] + The parameters of the given function node. + """ + + query = """(function_declarator ((parameter_list) @function.parameters))""" + parameters_list: Captures = self.__query_node_and_capture_output(query, function_node) + + if not parameters_list: + return [] + + params: dict[str, CParameter] = self.__get_parameter_details(parameters_list) + + # for old-style function definition: + # https://www.gnu.org/software/c-intro-and-ref/manual/html_node/Old_002dStyle-Function-Definitions.html + + for child in function_node.children: + if child.type == "declaration": + for tup in self.__extract_parameter_declarations(child): + name, parameter = tup + params[name] = parameter + + # filter out params without type + return [param[1] for param in params.items() if param[1].type] + + def __frame_query_and_capture_output(self, query: str, code_to_process: str) -> Captures: + """Frame a query for the tree-sitter parser. 
+ + Parameters + ---------- + query : str + The query to frame. + code_to_process : str + The code to process. + + Returns + ------- + Captures + The list of tree-sitter captures. + """ + + framed_query: Query = self.language.query(query) + tree = self.parser.parse(bytes(code_to_process, "utf-8")) + return Captures(framed_query.captures(tree.root_node)) + + def __query_node_and_capture_output(self, query: str, node: Node) -> Captures: + """Frame a query for the tree-sitter parser and query the given tree-sitter node. + + Parameters + ---------- + query : str + The query to frame. + node : Node + The root node used for querying. + + Returns + ------- + Captures + The list of tree-sitter captures. + """ + + framed_query: Query = self.language.query(query) + return Captures(framed_query.captures(node)) + + def __get_function_nodes(self, code: str) -> Captures: + """Parse the given code and extract tree-sitter function nodes. + + Parameters + ---------- + code : str + The input code. + + Returns + ------- + Captures + The list of tree-sitter captures. + """ + + query = """((function_definition) @function)""" + return self.__frame_query_and_capture_output(query, code) + + def __get_function_name(self, function_node: Node) -> str: + """ + Extract the function name from a tree-sitter function node. + + Parameters + ---------- + function_node : Node + The function node whose name we want to extract. + + Returns + ------- + str + The name of the function. + """ + + query = """(function_declarator ((identifier) @function.name))""" + function_name_node: Node = self.__query_node_and_capture_output(query, function_node)[0].node + return function_name_node.text.decode() + + def __get_function_return_type(self, function_node: Node) -> str: + """ + Extracts the return type of a tree-sitter function node. + + Parameters + ---------- + function_node : Node + The function node whose return type we want to extract. 
+ + Returns + ------- + str + The return type of a function or function, if the return is a function pointer. + """ + + # TODO: not sure if this is correct + # if there's more than 1 function declaration type, we consider it a function pointer + if self.__count_function_declarations(function_node.child_by_field_name("declarator")) > 1: + return "function" + + type_node = function_node.child_by_field_name("type") + + return type_node.text.decode() if type_node.type != "struct_specifier" else type_node.child_by_field_name("name").text.decode() + + def __get_function_signature(self, code: str, function_node: Node) -> str: + """ + Extracts the function signature from a tree-sitter function node. + + Parameters + ---------- + code : str + The original code that was used to extract the function node. + function_node : Node + The function node whose signature we want to extract. + + Returns + ------- + str + The signature of the function. + """ + + body_node: Node = function_node.child_by_field_name("body") + start_byte = function_node.start_byte + end_byte = body_node.start_byte + code_bytes = bytes(code, "utf-8") + signature = code_bytes[start_byte:end_byte] + + return signature.decode().strip() + + def __get_type_qualifiers(self, node: Node) -> List[str]: + """ + Extract the type qualifiers from a given tree-sitter node. + + Parameters + ---------- + node : Node + The node whose type qualifiers we want to extract. + + Returns + ------- + List[str] + The list of type qualifiers. + """ + + if not node or not node.children: + return [] + + return [child.text.decode() for child in node.children if child.type == "type_qualifier"] + + def __get_storage_class_specifiers(self, node: Node) -> List[str]: + """ + Extract the storage class specifiers from a given tree-sitter node. + + Parameters + ---------- + node : Node + The node whose storage class specifiers we want to extract. + + Returns + ------- + List[str] + The list of storage class specifiers. 
+ """ + + if not node or not node.children: + return [] + + return [child.text.decode() for child in node.children if child.type == "storage_class_specifier"] + + def __count_pointers(self, node: Node) -> int: + """ + Count the number of consecutive pointers for a tree-sitter node. + + Parameters + ---------- + node : Node + The tree-siter node we want to evaluate. + + Returns + ------- + int + The number of consecutive pointers present in the given tree-sitter node. + """ + + count = 0 + curr_node = node + while curr_node and curr_node.type == "pointer_declarator": + count += 1 + curr_node = curr_node.child_by_field_name("declarator") + + return count + + def __count_function_declarations(self, node: Node) -> int: + """ + Counts the number of function declaration nodes for a tree-sitter node. + + Parameters + ---------- + node : Node + The tree-sitter node we want to evaluate. + + Returns + ------- + int + The number of function delacration nodes present in the given tree-sitter node. + """ + + if not node or not node.children: + return 0 + + sum = 1 if node.type == "function_declarator" else 0 + for child in node.children: + sum += self.__count_function_declarations(child) + + return sum + + def __get_parameter_details(self, parameters_list: Captures) -> dict[str, CParameter]: + """ + Extract parameter details from a list of tree-sitter parameters. + + Parameters + ---------- + parameters_list : Captures + The parameter list node captures. + + Returns + ------- + Dict[str, CParameter] + A dictionary of parameter details. 
+ """ + + params: dict[str, CParameter] = {} + + for parameters in parameters_list: + if not parameters or not parameters.node.children: + continue + for param in parameters.node.children: + # old c style + if param.type == "identifier": + name, parameter = self.__extract_simple_parameter(param, "") + params[name] = parameter + elif param.type == "variadic_parameter": + name, parameter = self.__extract_simple_parameter(param, "variadic") + params[name] = parameter + elif param.type == "parameter_declaration": + for tup in self.__extract_parameter_declarations(param): + name, parameter = tup + params[name] = parameter + + return params + + def __extract_simple_parameter(self, node: Node, parameter_type: str) -> tuple[str, CParameter]: + name: str = node.text.decode() + parameter: CParameter = CParameter( + type=parameter_type, + qualifiers=[], + specifiers=[], + is_reference=False, + name=name, + ) + + return (name, parameter) + + def __extract_parameter_declarations(self, node: Node) -> List[tuple[str, CParameter]]: + query = """((identifier) @name)""" + captures: Captures = self.__query_node_and_capture_output(query, node) + + # no name found, skip this node + if len(captures) == 0: + return [] + + parameters: List[tuple[str, CParameter]] = [] + for capture in captures: + parameters.append(self.__extract_parameter_declaration(node, capture.node)) + + return parameters + + def __extract_parameter_declaration(self, parent_node: Node, identifier_node: Node) -> tuple[str, CParameter]: + name = identifier_node.text.decode() + + nb_function_declarations = self.__count_function_declarations(parent_node) + # we have a function pointer + if nb_function_declarations > 0: + parameter: CParameter = CParameter( + type="function", + qualifiers=[], # TODO: not sure if this is correct + specifiers=[], # TODO: not sure if this is correct + is_reference=True, + name=name, + ) + return (name, parameter) + + type_node = parent_node.child_by_field_name("type") + + param_type: str = 
type_node.text.decode() if type_node.type != "struct_specifier" else type_node.child_by_field_name("name").text.decode() + type_augmentor = self.__augment_type(identifier_node, parent_node.type) + + parameter = CParameter( + type=param_type + type_augmentor, + qualifiers=self.__get_type_qualifiers(parent_node), + specifiers=self.__get_storage_class_specifiers(parent_node), + is_reference=type_augmentor.startswith("*"), + name=name, + ) + + return (name, parameter) + + def __augment_type(self, identifier_node: Node, stop_node_type: str) -> str: + """ + Augment types with pointer and array details. + """ + + # not sure about this one + type_augmentor = "" + pointer_augmentor = "" + array_augmentor = "" + curr_node = identifier_node.parent + while curr_node and curr_node.type != stop_node_type: + if curr_node.type == "pointer_declarator": + pointer_augmentor = f"*{pointer_augmentor}" + elif curr_node.type == "array_declarator": + size_node = curr_node.child_by_field_name("size") + size: str = "" + if size_node: + size = size_node.text.decode() + array_augmentor = f"{array_augmentor}[{size}]" + elif curr_node.type == "parenthesized_declarator": + type_augmentor = f"({pointer_augmentor}{type_augmentor}{array_augmentor})" + pointer_augmentor = "" + array_augmentor = "" + + curr_node = curr_node.parent + + return f"{pointer_augmentor}{type_augmentor}{array_augmentor}" + + def __get_comment(self, node: Node) -> str: + """ + Extract the comment associated with a tree-sitter node. + + Parameters + ---------- + node : Node + The tree-sitter node whose + + Returns + ------- + str + The comment associeted with the given node. 
+ """ + + docs = [] + curr_node = node + while curr_node.prev_named_sibling and curr_node.prev_named_sibling.type == "comment": + curr_node = curr_node.prev_named_sibling + text = curr_node.text.decode() + docs.append(text) + + return "\n".join(reversed(docs)) diff --git a/cldk/analysis/call_graph.py b/cldk/analysis/call_graph.py new file mode 100644 index 0000000..2901a09 --- /dev/null +++ b/cldk/analysis/call_graph.py @@ -0,0 +1,69 @@ +from abc import ABC, abstractmethod + + +class CallGraph(ABC): + def __init__(self) -> None: + super().__init__() + + @abstractmethod + def get_callees(self, **kwargs): + """ + Given a source code, get all the callees + """ + pass + + @abstractmethod + def get_callers(self, **kwargs): + """ + Given a source code, get all the callers + """ + pass + + @abstractmethod + def get_call_graph(self, **kwargs): + """ + Given an application, get the call graph + """ + pass + + @abstractmethod + def get_call_graph_json(self, **kwargs): + """ + Given an application, get call graph in JSON format + """ + pass + + @abstractmethod + def get_class_call_graph(self, **kwargs): + """ + Given an application and a class, get call graph + """ + pass + + @abstractmethod + def get_entry_point_classes(self, **kwargs): + """ + Given an application, get all the entry point classes + """ + pass + + @abstractmethod + def get_entry_point_methods(self, **kwargs): + """ + Given an application, get all the entry point methods + """ + pass + + @abstractmethod + def get_service_entry_point_classes(self, **kwargs): + """ + Given an application, get all the service entry point classes + """ + pass + + @abstractmethod + def get_service_entry_point_methods(self, **kwargs): + """ + Given an application, get all the service entry point methods + """ + pass diff --git a/cldk/analysis/go/__init__.py b/cldk/analysis/go/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cldk/analysis/go/treesitter/__init__.py b/cldk/analysis/go/treesitter/__init__.py new 
file mode 100644 index 0000000..9b2531d --- /dev/null +++ b/cldk/analysis/go/treesitter/__init__.py @@ -0,0 +1,3 @@ +from cldk.analysis.go.treesitter.go_sitter import GoSitter + +__all__ = ["GoSitter"] diff --git a/cldk/analysis/go/treesitter/go_sitter.py b/cldk/analysis/go/treesitter/go_sitter.py new file mode 100644 index 0000000..48e1f58 --- /dev/null +++ b/cldk/analysis/go/treesitter/go_sitter.py @@ -0,0 +1,451 @@ +from typing import List +from tree_sitter import Language, Parser, Query, Node +import tree_sitter_go as tsgo + +from cldk.models.go.models import GoCallable, GoImport, GoParameter, GoSourceFile +from cldk.models.treesitter import Captures + + +class GoSitter: + """ + Tree sitter for Go use cases. + """ + + def __init__(self) -> None: + self.language: Language = Language(tsgo.language()) + self.parser: Parser = Parser(self.language) + + def get_all_functions(self, code: str) -> List[GoCallable]: + """ + Get all the functions and methods in the provided code. + + Parameters + ---------- + code : str + The code you want to analyse. + + Returns + ------- + List[GoCallable] + All the function and method details within the provided code. + """ + + query = """ + ((function_declaration) @function) + ((method_declaration) @method) + """ + + callables: List[GoCallable] = [] + captures: Captures = self.__frame_query_and_capture_output(query, code) + for capture in captures: + if capture.name == "function": + callables.append(self.__get_function_details(capture.node)) + elif capture.name == "method": + callables.append(self.__get_method_details(capture.node)) + + return callables + + def get_imports(self, code: str) -> List[GoImport]: + """ + Get all the imports in the provided code. + + Parameters + ---------- + code : str + The code you want to analyse. + + Returns + ------- + List[GoImport] + All the imports within the provided code. 
+ """ + + query = """ + (import_declaration + (import_spec) @import + ) + + (import_declaration + (import_spec_list + (import_spec) @import + ) + ) + """ + + return [self.__extract_import_details(capture.node) for capture in self.__frame_query_and_capture_output(query, code)] + + def get_source_file_details(self, source_file: str) -> GoSourceFile: + """ + Get the details of the provided source file. + + Parameters + ---------- + source_file : str + The source file code you want to analyse. + + Returns + ------- + GoSourceFile + The details of the provided source file code. + """ + + return GoSourceFile( + imports=self.get_imports(source_file), + callables=self.get_all_functions(source_file), + ) + + def __frame_query_and_capture_output(self, query: str, code_to_process: str) -> Captures: + """Frame a query for the tree-sitter parser. + + Parameters + ---------- + query : str + The query to frame. + code_to_process : str + The code to process. + + Returns + ------- + Captures + The list of tree-sitter captures. + """ + + framed_query: Query = self.language.query(query) + tree = self.parser.parse(bytes(code_to_process, "utf-8")) + return Captures(framed_query.captures(tree.root_node)) + + def __query_node_and_capture_output(self, query: str, node: Node) -> Captures: + """Frame a query for the tree-sitter parser and query the given tree-sitter node. + + Parameters + ---------- + query : str + The query to frame. + node : Node + The root node used for querying. + + Returns + ------- + Captures + The list of tree-sitter captures. + """ + + framed_query: Query = self.language.query(query) + return Captures(framed_query.captures(node)) + + def __get_function_details(self, node: Node) -> GoCallable: + """ + Extract the function details from a tree-sitter function node. + + Parameters + ---------- + node : Node + The tree-sitter function node whose details we want. + + Returns + ------- + GoCallable + The function details. 
+ """ + + name: str = self.__get_name(node) + return GoCallable( + name=name, + signature=self.__get_signature(node), + code=node.text.decode(), + start_line=node.start_point[0], + end_line=node.end_point[0], + modifiers=["public"] if name[0].isupper() else ["private"], + comment=self.__get_comment(node), + parameters=self.__get_parameters(node), + return_types=self.__get_return_types(node), + ) + + def __get_method_details(self, node: Node) -> GoCallable: + """ + Extract the method details from a tree-sitter method node. + + Parameters + ---------- + node : Node + The tree-sitter method node whose details we want. + + Returns + ------- + GoCallable + The method details. + """ + + name: str = self.__get_name(node) + return GoCallable( + name=name, + signature=self.__get_signature(node), + code=node.text.decode(), + start_line=node.start_point[0], + end_line=node.end_point[0], + modifiers=["public"] if name[0].isupper() else ["private"], + comment=self.__get_comment(node), + parameters=self.__get_parameters(node), + return_types=self.__get_return_types(node), + receiver=self.__get_receiver(node), + ) + + def __get_name(self, node: Node) -> str: + """ + Extract the name of a tree-sitter function or method node. + + Parameters + ---------- + node : Node + The tree-sitter node whose name we want. + + Returns + ------- + str + The method/function name. + """ + + return node.child_by_field_name("name").text.decode() + + def __get_signature(self, node: Node) -> str: + """ + Extract the signature of a tree-sitter function or method node. + + Parameters + ---------- + node : Node + The tree-sitter node whose signature we want. + + Returns + ------- + str + The method/function signature. 
+ """ + + signature = "" + # only methods have a receiver + receiver_node: Node = node.child_by_field_name("receiver") + if receiver_node: + signature += receiver_node.text.decode() + + if signature: + signature += " " + + name = self.__get_name(node) + signature += name + + # generics type, optional field available only for functions + type_params_node: Node = node.child_by_field_name("type_parameters") + if type_params_node: + signature += type_params_node.text.decode() + + signature += node.child_by_field_name("parameters").text.decode() + + result_node: Node = node.child_by_field_name("result") + if result_node: + signature += " " + result_node.text.decode() + + return signature + + def __get_comment(self, node: Node) -> str: + """ + Extract the comment associated with a tree-sitter node. + + Parameters + ---------- + node : Node + The tree-sitter node whose docs we want. + + Returns + ------- + str + The comment associated with the given node. + """ + + docs = [] + curr_node = node + while curr_node.prev_named_sibling and curr_node.prev_named_sibling.type == "comment": + curr_node = curr_node.prev_named_sibling + text = curr_node.text.decode() + docs.append(text) + + return "\n".join(reversed(docs)) + + def __get_parameters(self, node: Node) -> List[GoParameter]: + """ + Extract the parameters from a tree-sitter function/method node. + + Parameters + ---------- + node : Node + The tree-sitter node whose parameters we want. + + Returns + ------- + List[GoParameter] + The list of parameter details. 
+ """ + + parameters_node: Node = node.child_by_field_name("parameters") + if not parameters_node or not parameters_node.children: + return [] + + parameters: List[GoParameter] = [] + for child in parameters_node.children: + if child.type == "parameter_declaration": + parameters.extend(self.__extract_parameters(child)) + elif child.type == "variadic_parameter_declaration": + parameters.append(self.__extract_variadic_parameter(child)) + + return parameters + + def __get_receiver(self, node: Node) -> GoParameter: + """ + Extract the receiver from a tree-sitter method node. + + Parameters + ---------- + node : Node + The tree-sitter node whose receiver we want. + + Returns + ------- + GoParameter + The receiver details. + """ + + receiver_node: Node = node.child_by_field_name("receiver") + + # it must have 1 non-variadic child + for child in receiver_node.children: + if child.type == "parameter_declaration": + return self.__extract_parameters(child)[0] + + def __get_return_types(self, node: Node) -> List[str]: + """ + Extract the return types from a callable tree-sitter node. + + Parameters + ---------- + node : Node + The tree-sitter node whose types we want. + + Returns + ------- + List[str] + The list of types returned by the callable. Empty list, if it does not return. 
+ """ + + result_node: Node = node.child_by_field_name("result") + if not result_node: + return [] + + if result_node.type == "parameter_list": + return_types: List[str] = [] + for child in result_node.children: + if child.type == "parameter_declaration": + return_types.extend([param.type for param in self.__extract_parameters(child)]) + elif child.type == "variadic_parameter_declaration": + return_types.append(self.__extract_variadic_parameter(child).type) + + return return_types + else: + # TODO: might need to be more specific + return [result_node.text.decode()] + + def __extract_parameters(self, parameter_declaration_node: Node) -> List[GoParameter]: + """ + Extract parameter details from a tree-sitter parameter declaration node. + + Parameters + ---------- + parameter_declaration_node : Node + The tree-sitter node whose details we want. + + Returns + ------- + List[GoParameter] + The list of parameter details. + """ + + type_node: Node = parameter_declaration_node.child_by_field_name("type") + param_type = type_node.text.decode() + is_reference_type = param_type.startswith("*") + + query = """((identifier) @name)""" + captures: Captures = self.__query_node_and_capture_output(query, parameter_declaration_node) + + # name is optional + if len(captures) == 0: + return [ + GoParameter( + type=param_type, + is_reference=is_reference_type, + is_variadic=False, + ) + ] + + return [ + GoParameter( + name=capture.node.text.decode(), + type=param_type, + is_reference=is_reference_type, + is_variadic=False, + ) + for capture in captures + ] + + def __extract_variadic_parameter(self, variadic_parameter_node: Node) -> GoParameter: + """ + Extract parameter details from a tree-sitter variadic declaration node. + + Parameters + ---------- + variadic_parameter_node : Node + The tree-sitter node whose details we want. + + Returns + ------- + GoParameter + The details of the variadic parameter. 
+ """ + + name: str = None + # name is not mandatory + name_node: Node = variadic_parameter_node.child_by_field_name("name") + if name_node: + name = name_node.text.decode() + + type_node: Node = variadic_parameter_node.child_by_field_name("type") + param_type = type_node.text.decode() + + return GoParameter( + name=name, + type="..." + param_type, + is_reference=param_type.startswith("*"), + is_variadic=True, + ) + + def __extract_import_details(self, node: Node) -> GoImport: + """ + Extract the import details from a tree-sitter import spec node. + + Parameters + ---------- + node : Node + The import spec node tree-sitter node whose details we want. + + Returns + ------- + GoImport + The import details. + """ + + name_node: Node = node.child_by_field_name("name") + path_node: Node = node.child_by_field_name("path") + path = path_node.text.decode() + + return GoImport( + name=name_node.text.decode() if name_node else None, + path=path[1 : len(path) - 1], + ) diff --git a/cldk/analysis/java/__init__.py b/cldk/analysis/java/__init__.py new file mode 100644 index 0000000..5378268 --- /dev/null +++ b/cldk/analysis/java/__init__.py @@ -0,0 +1,3 @@ +from .java import JavaAnalysis + +__all__ = ["JavaAnalysis"] diff --git a/cldk/analysis/java/codeanalyzer/__init__.py b/cldk/analysis/java/codeanalyzer/__init__.py new file mode 100644 index 0000000..a3773e2 --- /dev/null +++ b/cldk/analysis/java/codeanalyzer/__init__.py @@ -0,0 +1,9 @@ +from .codeanalyzer import JCodeanalyzer + + +""" +Download the codeanalyzer.jar file from the latest release on the codeanalyzer repository. 
+""" + + +__all__ = ["JCodeanalyzer"] diff --git a/cldk/analysis/java/codeanalyzer/bin/.gitignore b/cldk/analysis/java/codeanalyzer/bin/.gitignore new file mode 100644 index 0000000..2eaf286 --- /dev/null +++ b/cldk/analysis/java/codeanalyzer/bin/.gitignore @@ -0,0 +1 @@ +codeanalyzer \ No newline at end of file diff --git a/cldk/analysis/java/codeanalyzer/bin/__init__.py b/cldk/analysis/java/codeanalyzer/bin/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cldk/analysis/java/codeanalyzer/codeanalyzer.py b/cldk/analysis/java/codeanalyzer/codeanalyzer.py new file mode 100644 index 0000000..b9bb4da --- /dev/null +++ b/cldk/analysis/java/codeanalyzer/codeanalyzer.py @@ -0,0 +1,789 @@ +import re +import json +import shlex +import requests +import networkx as nx +from pathlib import Path +import subprocess +from subprocess import CompletedProcess +from urllib.request import urlretrieve +from datetime import datetime +from importlib import resources + +from networkx import DiGraph + +from cldk.analysis.java.treesitter import JavaSitter +from cldk.models.java import JGraphEdges +from cldk.models.java.models import JApplication, JCallable, JField, JMethodDetail, JType, JCompilationUnit +from typing import Dict, List, Tuple +from typing import Union + +from cldk.utils.exceptions.exceptions import CodeanalyzerExecutionException + +import logging + +logger = logging.getLogger(__name__) + + +class JCodeanalyzer: + """A class for building the application view of a Java application using Codeanalyzer. + + Parameters + ---------- + project_dir : str or Path + The path to the root of the Java project. + analysis_json_path : str or Path or None + The path to save the intermediate codeanalysis ouputs. If None, we'll read from the pipe. + sdg : bool + If True, the system dependency graph will be generated with a more indepth analysis. Default is False. 
+ eager_analysis : bool + If True, the analysis will be done eagerly, i.e., a new analysis will be done every time the object is created. + Default is False. + use_graalvm_binary : bool + If True, the codeanalyzer binary from GraalVM will be used. By default, the codeanalyzer jar from the latest + release on GitHub will be used. + Methods + ------- + _init_codeanalyzer(project_dir, analysis_json_path) + Initializes the codeanalyzer database. + """ + + def __init__( + self, + project_dir: Union[str, Path], + source_code: str | None, + analysis_backend_path: Union[str, Path, None], + analysis_json_path: Union[str, Path, None], + analysis_level: str, + use_graalvm_binary: bool, + eager_analysis: bool, + ) -> None: + self.project_dir = project_dir + self.source_code = source_code + self.analysis_backend_path = analysis_backend_path + self.analysis_json_path = analysis_json_path + self.use_graalvm_binary = use_graalvm_binary + self.eager_analysis = eager_analysis + self.analysis_level = analysis_level + # Attributes related the Java code analysis... + self.call_graph: DiGraph | None = None + self.application = None + + @staticmethod + def _download_or_update_code_analyzer(filepath: Path) -> str: + """ + Downloads the codeanalyzer jar from the latest release on GitHub. + + Parameters + ---------- + filepath : str + The path to save the codeanalyzer jar. + + Returns + ------- + str + The path to the downloaded codeanalyzer jar file. 
+ """ + url = "https://api.github.com/repos/IBM/codenet-minerva-code-analyzer/releases/latest" + response = requests.get(url) + date_format = "%Y%m%dT%H%M%S" + if response.status_code == 200: + for asset in response.json().get("assets", []): + if asset["name"] == "codeanalyzer.jar": + download_url = asset["browser_download_url"] + pattern = r"(\d{8}T\d{6})" + match = re.search(pattern, download_url) + if match: + datetime_str = match.group(0) + else: + raise Exception(f"Release URL {download_url} does not contain a datetime pattern.") + + # Look for codeanalyzer.YYYYMMDDTHHMMSS.jar in the filepath + current_codeanalyzer_jars = [jarfile for jarfile in filepath.glob("*.jar")] + if not any(current_codeanalyzer_jars): + logger.info(f"Codeanalzyer jar is not found. Downloading the latest version.") + filename = filepath / f"codeanalyzer.{datetime_str}.jar" + urlretrieve(download_url, filename) + return filename.__str__() + + current_codeanalyzer_jar_name = current_codeanalyzer_jars[0] + match = re.search(pattern, current_codeanalyzer_jar_name.__str__()) + if match: + current_datetime_str = match.group(0) + + if datetime.strptime(datetime_str, date_format) > datetime.strptime(current_datetime_str, date_format): + logger.info(f"Codeanalzyer jar is outdated. Downloading the latest version.") + # Remove the older codeanalyzer jar + for jarfile in current_codeanalyzer_jars: + jarfile.unlink() + # Download the newer codeanalyzer jar + filename = filepath / f"codeanalyzer.{datetime_str}.jar" + urlretrieve(download_url, filename) + else: + filename = current_codeanalyzer_jar_name + logger.info(f"Codeanalzyer jar is already at the latest version.") + else: + filename = current_codeanalyzer_jar_name + + return filename.__str__() + else: + raise Exception(f"Failed to fetch release warn: {response.status_code} {response.text}") + + def _get_application(self) -> JApplication: + """ + Returns the application view of the Java code. 
+
+        Returns
+        -------
+        JApplication
+            The application view of the Java code.
+        """
+        if self.application is None:
+            self.application = self._init_codeanalyzer()
+        return self.application
+
+    def _get_codeanalyzer_exec(self) -> List[str]:
+        """
+        Returns the executable command for codeanalyzer.
+
+        Returns
+        -------
+        List[str]
+            The executable command for codeanalyzer.
+
+        Notes
+        -----
+        Some selection criteria for the codeanalyzer analysis_backend:
+        1. If the use_graalvm_binary flag is set, the codeanalyzer binary from GraalVM will be used.
+        2. If the analysis_backend_path is provided by the user, the codeanalyzer.jar from the analysis_backend_path will be used.
+        3. If the analysis_backend_path is not provided, the latest codeanalyzer.jar from the GitHub release will be first downloaded.
+        """
+
+        if self.use_graalvm_binary:
+            with resources.as_file(resources.files("cldk.analysis.java.codeanalyzer.bin") / "codeanalyzer") as codeanalyzer_bin_path:
+                codeanalyzer_exec = shlex.split(codeanalyzer_bin_path.__str__())
+        else:
+            if self.analysis_backend_path:
+                # Read the user-supplied path from the instance attribute. The bare name
+                # `analysis_backend_path` was previously unbound on the right-hand side
+                # here and raised NameError whenever a backend path was provided.
+                analysis_backend_path = Path(self.analysis_backend_path)
+                logger.info(f"Using codeanalyzer.jar from {analysis_backend_path}")
+                codeanalyzer_exec = shlex.split(f"java -jar {analysis_backend_path / 'codeanalyzer.jar'}")
+            else:
+                # Since the path to codeanalyzer.jar was not provided, we'll download the latest version from GitHub.
+                with resources.as_file(resources.files("cldk.analysis.java.codeanalyzer.jar")) as codeanalyzer_jar_path:
+                    # Download the codeanalyzer jar if it doesn't exist, update if it's outdated,
+                    # do nothing if it's up-to-date.
+                    codeanalyzer_jar_file = self._download_or_update_code_analyzer(codeanalyzer_jar_path)
+                    codeanalyzer_exec = shlex.split(f"java -jar {codeanalyzer_jar_file}")
+        return codeanalyzer_exec
+
+    def _init_codeanalyzer(self, analysis_level=1) -> JApplication:
+        """Initializes the Codeanalyzer.
+ Returns + ------- + JApplication + The application view of the Java code with the analysis results. + Raises + ------ + CodeanalyzerExecutionException + If there is an error running Codeanalyzer. + """ + + codeanalyzer_exec = self._get_codeanalyzer_exec() + + if self.analysis_json_path is None: + logger.info("Reading analysis from the pipe.") + codeanalyzer_args = codeanalyzer_exec + shlex.split(f"-i {Path(self.project_dir)} --analysis-level={analysis_level}") + try: + logger.info(f"Running codeanalyzer: {' '.join(codeanalyzer_args)}") + console_out: CompletedProcess[str] = subprocess.run( + codeanalyzer_args, + capture_output=True, + text=True, + check=True, + ) + return JApplication(**json.loads(console_out.stdout)) + except Exception as e: + raise CodeanalyzerExecutionException(str(e)) from e + + else: + analysis_json_path_file = Path(self.analysis_json_path).joinpath("analysis.json") + if not analysis_json_path_file.exists() or self.eager_analysis: + # If the analysis file does not exist, we'll run the analysis. Alternately, if the eager_analysis + # flag is set, we'll run the analysis every time the object is created. This will happen regradless + # of the existence of the analysis file. + # Create the executable command for codeanalyzer. 
+ codeanalyzer_args = codeanalyzer_exec + shlex.split(f"-i {Path(self.project_dir)} --analysis-level={analysis_level} -o {self.analysis_json_path}") + + try: + logger.info(f"Running codeanalyzer subprocess with args {codeanalyzer_args}") + subprocess.run( + codeanalyzer_args, + capture_output=True, + text=True, + check=True, + ) + if not analysis_json_path_file.exists(): + raise CodeanalyzerExecutionException("Codeanalyzer did not generate the analysis file.") + + except Exception as e: + raise CodeanalyzerExecutionException(str(e)) from e + + with open(analysis_json_path_file) as f: + data = json.load(f) + return JApplication(**data) + + def _codeanalyzer_single_file(self): + """ + Invokes codeanalyzer in a single file mode. + + Returns + ------- + JApplication + The application view of the Java code with the analysis results. + """ + # self.source_code: str = re.sub(r"[\r\n\t\f\v]+", lambda x: " " if x.group() in "\t\f\v" else " ", self.source_code) + codeanalyzer_exec = self._get_codeanalyzer_exec() + codeanalyzer_args = ["--source-analysis", self.source_code] + codeanalyzer_cmd = codeanalyzer_exec + codeanalyzer_args + try: + print(f"Running {' '.join(codeanalyzer_cmd)}") + logger.info(f"Running {' '.join(codeanalyzer_cmd)}") + console_out: CompletedProcess[str] = subprocess.run(codeanalyzer_cmd, capture_output=True, text=True, check=True) + if console_out.returncode != 0: + raise CodeanalyzerExecutionException(console_out.stderr) + return JApplication(**json.loads(console_out.stdout)) + except Exception as e: + raise CodeanalyzerExecutionException(str(e)) from e + + def get_symbol_table(self) -> Dict[str, JCompilationUnit]: + """ + Returns the symbol table of the Java code. + + Returns + ------- + Dict[str, JCompilationUnit] + The symbol table of the Java code. 
+ """ + if self.application is None: + self.application = self._init_codeanalyzer() + return self.application.symbol_table + + def get_application_view(self) -> JApplication: + """ + Returns the application view of the Java code. + + Returns: + -------- + JApplication + The application view of the Java code. + """ + if self.source_code: + # This branch is triggered when a single file is being analyzed. + self.application = self._codeanalyzer_single_file() + return self.application + else: + if self.application is None: + self.application = self._init_codeanalyzer() + return self.application + + def get_system_dependency_graph(self) -> list[JGraphEdges]: + """ + Run the codeanalyzer to get the system dependency graph. + + Returns + ------- + list[JGraphEdges] + The system dependency graph. + """ + if self.application.system_dependency_graph is None: + self.application = self._init_codeanalyzer(analysis_level=2) + + return self.application.system_dependency_graph + + def _generate_call_graph(self, using_symbol_table) -> DiGraph: + """ + Generates the call graph of the Java code. + + Returns: + -------- + DiGraph + The call graph of the Java code. 
+ """ + cg = nx.DiGraph() + if using_symbol_table: + NotImplementedError("Call graph generation using symbol table is not implemented yet.") + else: + sdg = self.get_system_dependency_graph() + tsu = JavaSitter() + edge_list = [ + ( + (jge.source.method.signature, jge.source.klass), + (jge.target.method.signature, jge.target.klass), + { + "type": jge.type, + "weight": jge.weight, + "calling_lines": tsu.get_calling_lines(jge.source.method.code, jge.target.method.signature), + }, + ) + for jge in sdg + if jge.type == "CONTROL_DEP" or jge.type == "CALL_DEP" + ] + for jge in sdg: + cg.add_node( + (jge.source.method.signature, jge.source.klass), + method_detail=jge.source, + ) + cg.add_node( + (jge.target.method.signature, jge.target.klass), + method_detail=jge.target, + ) + cg.add_edges_from(edge_list) + return cg + + def get_class_hierarchy(self) -> DiGraph: + """ + Returns the class hierarchy of the Java code. + + Returns: + -------- + DiGraph + The class hierarchy of the Java code. + """ + + def get_call_graph(self) -> DiGraph: + """ + Get call graph of the Java code. + + Returns: + -------- + DiGraph + The call graph of the Java code. 
+ """ + if self.analysis_level == "symbol_table": + self.call_graph = self._generate_call_graph(using_symbol_table=True) + if self.call_graph is None: + self.call_graph = self._generate_call_graph(using_symbol_table=False) + return self.call_graph + + def get_call_graph_json(self) -> str: + """ + serialize callgraph to json + """ + callgraph_list = [] + edges = list(self.call_graph.edges.data("calling_lines")) + for edge in edges: + callgraph_dict = {} + callgraph_dict["source_method_signature"] = edge[0][0] + callgraph_dict["source_method_body"] = self.call_graph.nodes[edge[0]]["method_detail"].method.code + callgraph_dict["source_class"] = edge[0][1] + callgraph_dict["target_method_signature"] = edge[1][0] + callgraph_dict["target_method_body"] = self.call_graph.nodes[edge[1]]["method_detail"].method.code + callgraph_dict["target_class"] = edge[1][1] + callgraph_dict["calling_lines"] = edge[2] + callgraph_list.append(callgraph_dict) + return json.dumps(callgraph_list) + + def get_all_callers(self, target_class_name: str, target_method_signature: str) -> Dict: + """ + Get all the caller details for a given java method. + + Returns: + -------- + Dict + Caller details in a dictionary. 
+ """ + + caller_detail_dict = {} + if (target_method_signature, target_class_name) not in self.call_graph.nodes(): + return caller_detail_dict + + in_edge_view = self.call_graph.in_edges( + nbunch=( + target_method_signature, + target_class_name, + ), + data=True, + ) + caller_detail_dict["caller_details"] = [] + caller_detail_dict["target_method"] = self.call_graph.nodes[(target_method_signature, target_class_name)]["method_detail"] + + for source, target, data in in_edge_view: + cm = {"caller_method": self.call_graph.nodes[source]["method_detail"], "calling_lines": data["calling_lines"]} + caller_detail_dict["caller_details"].append(cm) + return caller_detail_dict + + def get_all_callees(self, source_class_name: str, source_method_signature: str) -> Dict: + """ + Get all the callee details for a given java method. + + Returns: + -------- + Dict + Callee details in a dictionary. + """ + callee_detail_dict = {} + if (source_method_signature, source_class_name) not in self.call_graph.nodes(): + return callee_detail_dict + + out_edge_view = self.call_graph.out_edges(nbunch=(source_method_signature, source_class_name), data=True) + + callee_detail_dict["callee_details"] = [] + callee_detail_dict["source_method"] = self.call_graph.nodes[(source_method_signature, source_class_name)]["method_detail"] + for source, target, data in out_edge_view: + cm = {"callee_method": self.call_graph.nodes[target]["method_detail"]} + cm["calling_lines"] = data["calling_lines"] + callee_detail_dict["callee_details"].append(cm) + return callee_detail_dict + + def get_all_methods_in_application(self) -> Dict[str, Dict[str, JCallable]]: + """ + Returns a dictionary of all methods in the Java code with + qualified class name as key and dictionary of methods in that class + as value + + Returns: + -------- + Dict[str, Dict[str, JCallable]]: + A dictionary of dictionaries of all methods in the Java code. 
+ """ + + class_method_dict = {} + class_dict = self.get_all_classes() + for k, v in class_dict.items(): + class_method_dict[k] = v.callable_declarations + return class_method_dict + + def get_all_classes(self) -> Dict[str, JType]: + """ + Returns a dictionary of all classes in the Java code. + + Returns: + -------- + Dict[str, JType] + A dict of all classes in the Java code, with qualified class names as keys + """ + + class_dict = {} + symtab = self.get_symbol_table() + for v in symtab.values(): + class_dict.update(v.type_declarations) + return class_dict + + def get_class(self, qualified_class_name) -> JType: + """ + Returns a class given qualified class name. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + JClassOrInterface + A class for the given qualified class name. + """ + symtab = self.get_symbol_table() + for _, v in symtab.items(): + if qualified_class_name in v.type_declarations.keys(): + return v.type_declarations.get(qualified_class_name) + + def get_method(self, qualified_class_name, method_signature) -> JCallable: + """ + Returns a method given qualified method name. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + method_signature : str + The signature of the method. + + Returns: + -------- + JCallable + A method for the given qualified method name. + """ + symtab = self.get_symbol_table() + for v in symtab.values(): + if qualified_class_name in v.type_declarations.keys(): + ci = v.type_declarations[qualified_class_name] + for cd in ci.callable_declarations.keys(): + if cd == method_signature: + return ci.callable_declarations[cd] + + def get_java_file(self, qualified_class_name) -> str: + """ + Returns a class given qualified class name. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + str + Java file name containing the given qualified class. 
+ """ + symtab = self.get_symbol_table() + for k, v in symtab.items(): + if (qualified_class_name) in v.type_declarations.keys(): + return k + + def get_java_compilation_unit(self, file_path: str) -> JCompilationUnit: + """ + Given the path of a Java source file, returns the compilation unit object from the symbol table. + + Parameters + ---------- + file_path : str + Absolute path to Java source file + + Returns + ------- + JCompilationUnit + Compilation unit object for Java source file + """ + + if self.application is None: + self.application = self._init_codeanalyzer() + return self.application[file_path] + + def get_all_methods_in_class(self, qualified_class_name) -> Dict[str, JCallable]: + """ + Returns a dictionary of all methods in the given class. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + Dict[str, JCallable] + A dictionary of all methods in the given class. + """ + ci = self.get_class(qualified_class_name) + if ci is None: + return {} + methods = {k: v for (k, v) in ci.callable_declarations.items() if v.is_constructor is False} + return methods + + def get_all_constructors(self, qualified_class_name) -> Dict[str, JCallable]: + """ + Returns a dictionary of all constructors of the given class. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + Dict[str, JCallable] + A dictionary of all constructors of the given class. 
+ """ + ci = self.get_class(qualified_class_name) + if ci is None: + return {} + constructors = {k: v for (k, v) in ci.callable_declarations.items() if v.is_constructor is True} + return constructors + + def get_all_sub_classes(self, qualified_class_name) -> Dict[str, JType]: + """ + Returns a dictionary of all sub-classes of the given class + Parameters + ---------- + qualified_class_name + + Returns + ------- + Dict[str, JType]: A dictionary of all sub-classes of the given class, and class details + """ + all_classes = self.get_all_classes() + sub_classes = {} + for cls in all_classes: + if qualified_class_name in all_classes[cls].implements_list or qualified_class_name in all_classes[cls].extends_list: + sub_classes[cls] = all_classes[cls] + return sub_classes + + def get_all_fields(self, qualified_class_name) -> List[JField]: + """ + Returns a list of all fields of the given class. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + List[JField] + A list of all fields of the given class. + """ + ci = self.get_class(qualified_class_name) + if ci is None: + logging.warning(f"Class {qualified_class_name} not found in the application view.") + return list() + return ci.field_declarations + + def get_all_nested_classes(self, qualified_class_name) -> List[JType]: + """ + Returns a list of all nested classes for the given class. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + List[JType] + A list of nested classes for the given class. 
+ """ + ci = self.get_class(qualified_class_name) + if ci is None: + logging.warning(f"Class {qualified_class_name} not found in the application view.") + return list() + nested_classes = ci.nested_type_declerations + return [self.get_class(c) for c in nested_classes] # Assuming qualified nested class names + + def get_extended_classes(self, qualified_class_name) -> List[str]: + """ + Returns a list of all extended classes for the given class. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + List[str] + A list of extended classes for the given class. + """ + ci = self.get_class(qualified_class_name) + if ci is None: + logging.warning(f"Class {qualified_class_name} not found in the application view.") + return list() + return ci.extends_list + + def get_implemented_interfaces(self, qualified_class_name) -> List[str]: + """ + Returns a list of all implemented interfaces for the given class. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + List[JType] + A list of implemented interfaces for the given class. + """ + ci = self.get_class(qualified_class_name) + if ci is None: + logging.warning(f"Class {qualified_class_name} not found in the application view.") + return list() + return ci.implements_list + + def get_class_call_graph(self, qualified_class_name: str, method_name: str | None = None) -> List[Tuple[JMethodDetail, JMethodDetail]]: + """ + A call graph for a given class and (optionally) filtered by a given method. + + Parameters + ---------- + qualified_class_name : str + The qualified name of the class. + method_name : str, optional + The name of the method in the class. + + Returns + ------- + List[Tuple[JMethodDetail, JMethodDetail]] + An edge list of the call graph for the given class and method. + + Notes + ----- + The class name must be fully qualified, e.g., "org.example.MyClass" and not "MyClass". 
+
+        Likewise, the method name (if provided) must belong to the given class.
+        """
+        # If the method name is not provided, we'll get the call graph for the entire class.
+
+        # TODO: Implement class call graph generation @rahlk
+        # NOTE(review): the previous draft built an edge list from `sdg` and `tsu`,
+        # neither of which is defined in this scope (it raised NameError on every
+        # call), and it instantiated NotImplementedError without raising it. Until
+        # the real implementation lands, fail fast and explicitly instead.
+        raise NotImplementedError("Class call graph generation is not implemented yet.")
+
+    def get_all_entry_point_methods(self) -> Dict[str, Dict[str, JCallable]]:
+        """
+        Returns a dictionary of all entry point methods in the Java code with
+        qualified class name as key and dictionary of methods in that class
+        as value
+
+        Returns:
+        --------
+        Dict[str, Dict[str, JCallable]]:
+            A dictionary of dictionaries of entry point methods in the Java code.
+ """ + + class_method_dict = {} + class_dict = self.get_all_classes() + for k, v in class_dict.items(): + entry_point_methods = {method_name: callable_decl for (method_name, callable_decl) in v.callable_declarations.items() if callable_decl.is_entry_point is True} + class_method_dict[k] = entry_point_methods + return class_method_dict + + def get_all_entry_point_classes(self) -> Dict[str, JType]: + """ + Returns a dictionary of all entry point classes in the Java code. + + Returns: + -------- + Dict[str, JType] + A dict of all entry point classes in the Java code, with qualified class names as keys + """ + + class_dict = {} + symtab = self.get_symbol_table() + for val in symtab.values(): + class_dict.update((k, v) for k, v in val.type_declarations.items() if v.is_entry_point is True) + return class_dict diff --git a/cldk/analysis/java/codeanalyzer/jar/.gitignore b/cldk/analysis/java/codeanalyzer/jar/.gitignore new file mode 100644 index 0000000..d392f0e --- /dev/null +++ b/cldk/analysis/java/codeanalyzer/jar/.gitignore @@ -0,0 +1 @@ +*.jar diff --git a/cldk/analysis/java/codeanalyzer/jar/__init__.py b/cldk/analysis/java/codeanalyzer/jar/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cldk/analysis/java/codeql/__init__.py b/cldk/analysis/java/codeql/__init__.py new file mode 100644 index 0000000..e8096d0 --- /dev/null +++ b/cldk/analysis/java/codeql/__init__.py @@ -0,0 +1,3 @@ +from .codeql import JCodeQL + +__all__ = ["JCodeQL"] diff --git a/cldk/analysis/java/codeql/backend.py b/cldk/analysis/java/codeql/backend.py new file mode 100644 index 0000000..23f0d54 --- /dev/null +++ b/cldk/analysis/java/codeql/backend.py @@ -0,0 +1,148 @@ +import subprocess +import tempfile +from pathlib import Path +import shlex +from typing import List +import pandas as pd +from pandas import DataFrame + +from cldk.utils.exceptions import CodeQLQueryExecutionException + + +class CodeQLQueryRunner: + """ + A class for executing CodeQL queries against a CodeQL 
database. + + Parameters + ---------- + database_path : str + The path to the CodeQL database. + + Attributes + ---------- + database_path : Path + The path to the CodeQL database. + temp_file_path : Path + The path to the temporary query file. + csv_output_file : Path + The path to the CSV output file. + temp_bqrs_file_path : Path + The path to the temporary bqrs file. + temp_qlpack_file : Path + The path to the temporary qlpack file. + + Methods + ------- + __enter__() + Context entry that creates temporary files to execute a CodeQL query. + execute(query_string, column_names) + Writes the query to the temporary file and executes it against the specified CodeQL database. + __exit__(exc_type, exc_val, exc_tb) + Clean up resources used by the CodeQL analysis. + + Raises + ------ + CodeQLQueryExecutionException + If there is an error executing the query. + """ + + def __init__(self, database_path: str): + self.database_path: Path = Path(database_path) + self.temp_file_path: Path = None + + def __enter__(self): + """ + Context entry that creates temporary files to execute a CodeQL query. + + Returns + ------- + instance : object + The instance of the class. + + Notes + ----- + This method creates temporary files to hold the query and store their paths. + """ + + # Create a temporary file to hold the query and store its path + temp_file = tempfile.NamedTemporaryFile("w", delete=False, suffix=".ql") + csv_file = tempfile.NamedTemporaryFile("w", delete=False, suffix=".csv") + bqrs_file = tempfile.NamedTemporaryFile("w", delete=False, suffix=".bqrs") + self.temp_file_path = Path(temp_file.name) + self.csv_output_file = Path(csv_file.name) + self.temp_bqrs_file_path = Path(bqrs_file.name) + + # Let's close the files, we'll reopen them by path when needed. 
+ temp_file.close() + bqrs_file.close() + csv_file.close() + + # Create a temporary qlpack.yml file + self.temp_qlpack_file = self.temp_file_path.parent / "qlpack.yml" + with self.temp_qlpack_file.open("w") as f: + f.write("name: temp\n") + f.write("version: 1.0.0\n") + f.write("libraryPathDependencies: codeql/java-all\n") + + return self + + def execute(self, query_string: str, column_names: List[str]) -> DataFrame: + """Writes the query to the temporary file and executes it against the specified CodeQL database. + + Args: + query_string (str): The CodeQL query string to be executed. + column_names (List[str]): The list of column names for the CSV the CodeQL produces when we execute the query. + + Returns: + dict: A dictionary containing the resulting DataFrame. + + Raises: + RuntimeError: If the context manager is not entered using the 'with' statement. + CodeQLQueryExecutionException: If there is an error executing the query. + """ + if not self.temp_file_path: + raise RuntimeError("Context manager not entered. Use 'with' statement.") + + # Write the query to the temp file so we can execute it. + self.temp_file_path.write_text(query_string) + + # Construct and execute the CodeQL CLI command asking for a JSON output. 
+        codeql_query_cmd = shlex.split(f"codeql query run {self.temp_file_path} --database={self.database_path} --output={self.temp_bqrs_file_path}",posix=False)
+
+        # Capture stderr so a failure can be reported. The previous code passed
+        # stderr=None (inherit) and then read `err.stderr`, which raised
+        # AttributeError: Popen.communicate() already returns the stderr payload
+        # itself (bytes when piped, None when inherited).
+        call = subprocess.Popen(codeql_query_cmd, stdout=None, stderr=subprocess.PIPE)
+        _, err = call.communicate()
+        if call.returncode != 0:
+            raise CodeQLQueryExecutionException(f"Error executing query: {err.decode()}")
+
+        # Convert the bqrs file to a CSV file
+        bqrs2csv_command = shlex.split(f"codeql bqrs decode --format=csv --output={self.csv_output_file} {self.temp_bqrs_file_path}",posix=False)
+
+        # Read the CSV file content and cast it to a DataFrame
+
+        call = subprocess.Popen(bqrs2csv_command, stdout=None, stderr=subprocess.PIPE)
+        _, err = call.communicate()
+        if call.returncode != 0:
+            raise CodeQLQueryExecutionException(f"Error executing query: {err.decode()}")
+        else:
+            return pd.read_csv(
+                self.csv_output_file,
+                header=None,
+                names=column_names,
+                skiprows=[0],
+            )
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """
+        Clean up resources used by the CodeQL analysis.
+
+        Deletes the temporary files created during the analysis, including the temporary file path,
+        the CSV output file, and the temporary QL pack file.
+ """ + if self.temp_file_path and self.temp_file_path.exists(): + self.temp_file_path.unlink() + + if self.csv_output_file and self.csv_output_file.exists(): + self.csv_output_file.unlink() + + if self.temp_qlpack_file and self.temp_qlpack_file.exists(): + self.temp_qlpack_file.unlink() diff --git a/cldk/analysis/java/codeql/codeql.py b/cldk/analysis/java/codeql/codeql.py new file mode 100644 index 0000000..32693d7 --- /dev/null +++ b/cldk/analysis/java/codeql/codeql.py @@ -0,0 +1,238 @@ +from pathlib import Path +import shlex +import subprocess +from networkx import DiGraph +from pandas import DataFrame +from cldk.models.java import JApplication +from cldk.analysis.java.codeql.backend import CodeQLQueryRunner +from tempfile import TemporaryDirectory +import atexit +import signal + +from cldk.utils.exceptions import CodeQLDatabaseBuildException +import networkx as nx +from typing import Union + + +class JCodeQL: + """A class for building the application view of a Java application using CodeQL. + + Parameters + ---------- + project_dir : str or Path + The path to the root of the Java project. + codeql_db : str or Path or None + The path to the CodeQL database. If None, a temporary directory is created to store the database. + + Attributes + ---------- + db_path : Path + The path to the CodeQL database. + + Methods + ------- + _init_codeql_db(project_dir, codeql_db) + Initializes the CodeQL database. + _build_application_view() + Builds the application view of the java application. + _build_call_graph() + Builds the call graph of the application. + get_application_view() + Returns the application view of the java application. + get_class_hierarchy() + Returns the class hierarchy of the java application. + get_call_graph() + Returns the call graph of the java application. + get_all_methods() + Returns all the methods of the java application. + get_all_classes() + Returns all the classes of the java application. 
+ """ + + def __init__(self, project_dir: Union[str, Path], codeql_db: Union[str, Path, None]) -> None: + self.db_path = self._init_codeql_db(project_dir, codeql_db) + + @staticmethod + def _init_codeql_db(project_dir: Union[str, Path], codeql_db: Union[str, Path, None]) -> Path: + """Initializes the CodeQL database. + + Parameters + ---------- + project_dir : str or Path + The path to the root of the Java project. + codeql_db : str or Path or None + The path to the CodeQL database. If None, a temporary directory is created to store the database. + + Returns + ------- + Path + The path to the CodeQL database. + + Raises + ------ + CodeQLDatabaseBuildException + If there is an error building the CodeQL database. + """ + + # Cast to Path if the project_dir is a string. + project_dir = Path(project_dir) if isinstance(project_dir, str) else project_dir + + # Create a codeql database. Use a temporary directory if the user doesn't specify + if codeql_db is None: + db_path: TemporaryDirectory = TemporaryDirectory(delete=False, ignore_cleanup_errors=True) + codeql_db = db_path.name + # Since the user is not providing the codeql database path, we'll destroy the database at exit. + # TODO: this may be a potential gotcha. Is there a better solution here? + # TODO (BACKWARD COMPATIBILITY ISSUE): Only works on 3.12. + # If necessary, use shutil to handle this differently in 3.11 and below. 
+ atexit.register(lambda: db_path.cleanup()) + # Also register the cleanup function for SIGINT and SIGTERM + signal.signal(signal.SIGINT, lambda *args, **kwargs: db_path.cleanup()) + signal.signal(signal.SIGTERM, lambda *args, **kwargs: db_path.cleanup()) + + codeql_db_create_cmd = shlex.split(f"codeql database create {codeql_db} --source-root={project_dir} --language=java",posix=False) + call = subprocess.Popen( + codeql_db_create_cmd, + stdout=subprocess.DEVNULL, + stderr=subprocess.PIPE, + ) + _, error = call.communicate() + if call.returncode != 0: + raise CodeQLDatabaseBuildException(f"Error building CodeQL database: {error.decode()}") + return Path(codeql_db) + + def _build_application_view(self) -> JApplication: + """ + Builds the application view of the java application. + + Returns + ------- + JApplication + The JApplication object representing the application view. + """ + application: JApplication = JApplication() + + # Lets build the class hierarchy tree first and store that information in the application object. + query = [] + + # Add import + query += ["import java"] + + # List classes and their superclasses (ignoring non-application classes and anonymous classes) + query += [ + "from Class cls", + "where cls.fromSource() and not cls.isAnonymous()", + "select cls, cls.getASupertype().getQualifiedName()", + ] + + # Execute the query using the CodeQLQueryRunner context manager + with CodeQLQueryRunner(self.db_path) as codeql_query: + class_superclass_pairs: DataFrame = codeql_query.execute( + query_string="\n".join(query), + column_names=["class", "superclass"], + ) + + application.cha = self.__process_class_hierarchy_pairs_to_tree(class_superclass_pairs) + return application + + @staticmethod + def __process_class_hierarchy_pairs_to_tree( + query_result: DataFrame, + ) -> DiGraph: + """ + Processes the query result into a directed graph representing the class hierarchy of the application. 
+ + Parameters + ---------- + query_result : DataFrame + The result of the class hierarchy query. + + Returns + ------- + DiGraph + A directed graph representing the class hierarchy of the application. + """ + return nx.from_pandas_edgelist(query_result, "class", "superclass", create_using=nx.DiGraph()) + + def _build_call_graph(self) -> DiGraph: + """Builds the call graph of the application. + + Returns + ------- + DiGraph + A directed graph representing the call graph of the application. + """ + query = [] + + # Add import + query += ["import java"] + + # Add Call edges between caller and callee and filter to only capture application methods. + query += [ + "from Method caller, Method callee", + "where", + "caller.fromSource() and", + "callee.fromSource() and", + "caller.calls(callee)", + "select", + ] + + # Caller metadata + query += [ + "caller.getFile().getAbsolutePath(),", + '"[" + caller.getBody().getLocation().getStartLine() + ", " + caller.getBody().getLocation().getEndLine() + "]", //Caller body slice indices', + "caller.getQualifiedName(), // Caller's fullsignature", + "caller.getAModifier(), // caller's method modifier", + "caller.paramsString(), // caller's method parameter types", + "caller.getReturnType().toString(), // Caller's return type", + "caller.getDeclaringType().getQualifiedName(), // Caller's class", + "caller.getDeclaringType().getAModifier(), // Caller's class modifier", + ] + + # Callee metadata + query += [ + "callee.getFile().getAbsolutePath(),", + '"[" + callee.getBody().getLocation().getStartLine() + ", " + callee.getBody().getLocation().getEndLine() + "]", //Caller body slice indices', + "callee.getQualifiedName(), // Caller's fullsignature", + "callee.getAModifier(), // callee's method modifier", + "callee.paramsString(), // callee's method parameter types", + "callee.getReturnType().toString(), // Caller's return type", + "callee.getDeclaringType().getQualifiedName(), // Caller's class", + "callee.getDeclaringType().getAModifier() 
 // Caller's class modifier", + ] + + query_string = "\n".join(query) + + # Execute the query using the CodeQLQueryRunner context manager + with CodeQLQueryRunner(self.db_path) as query: + query_result: DataFrame = query.execute( + query_string, + column_names=[ + # Caller Columns + "caller_file", + "caller_body_slice_index", + "caller_signature", + "caller_modifier", + "caller_params", + "caller_return_type", + "caller_class_signature", + "caller_class_modifier", + # Callee Columns + "callee_file", + "callee_body_slice_index", + "callee_signature", + "callee_modifier", + "callee_params", + "callee_return_type", + "callee_class_signature", + "callee_class_modifier", + ], + ) + + # Process the query results into JMethod instances + callgraph: DiGraph = self.__process_call_edges_to_callgraph(query_result) + return callgraph + + @staticmethod + def __process_call_edges_to_callgraph(query_result: DataFrame) -> DiGraph: + pass diff --git a/cldk/analysis/java/java.py b/cldk/analysis/java/java.py new file mode 100644 index 0000000..9403705 --- /dev/null +++ b/cldk/analysis/java/java.py @@ -0,0 +1,587 @@ +from pathlib import Path + +from typing import Dict, List, Tuple, Set +from networkx import DiGraph + +from cldk.analysis import SymbolTable, CallGraph +from cldk.models.java import JCallable +from cldk.models.java import JApplication +from cldk.models.java.models import JCompilationUnit, JMethodDetail, JType, JField +from cldk.analysis.java.codeanalyzer import JCodeanalyzer +from cldk.analysis.java.codeql import JCodeQL +from cldk.utils.analysis_engine import AnalysisEngine + + +class JavaAnalysis(SymbolTable, CallGraph): + + def __init__( + self, + project_dir: str | Path | None, + source_code: str | None, + analysis_backend: str, + analysis_backend_path: str | None, + analysis_json_path: str | Path | None, + analysis_level: str, + use_graalvm_binary: bool, + eager_analysis: bool, + ) -> None: + """ + Parameters + ---------- + project_dir : str + The directory path of the 
project. + analysis_backend : str, optional + The analysis_backend used for analysis, defaults to "codeql". + analysis_backend_path : str, optional + The path to the analysis_backend, defaults to None and in the case of codeql, it is assumed that the cli is installed and + available in the PATH. In the case of codeanalyzer the codeanalyzer.jar is downloaded from the latest release. + analysis_json_path : str or Path, optional + The path to save the analysis database (analysis.json), defaults to None. If None, the analysis database is + not persisted. + use_graalvm_binary : bool, optional + A flag indicating whether to use the GraalVM binary for SDG analysis, defaults to False. If False, the default + Java binary is used and one needs to have Java 17 or higher installed. + eager_analysis : bool, optional + A flag indicating whether to perform eager analysis, defaults to False. If True, the analysis is performed + eagerly. That is, the analysis.json file is created during analysis every time even if it already exists. + + Attributes + ---------- + analysis_backend : JCodeQL | JCodeanalyzer + The analysis_backend used for analysis. + application : JApplication + The application view of the Java code. 
+ """ + self.project_dir = project_dir + self.source_code = source_code + self.analysis_level = analysis_level + self.analysis_json_path = analysis_json_path + self.analysis_backend_path = analysis_backend_path + self.eager_analysis = eager_analysis + self.use_graalvm_binary = use_graalvm_binary + + # Initialize the analysis analysis_backend + if analysis_backend.lower() == "codeql": + self.analysis_backend: JCodeQL = JCodeQL(self.project_dir, self.analysis_json_path) + elif analysis_backend.lower() == "codeanalyzer": + self.analysis_backend: JCodeanalyzer = JCodeanalyzer( + project_dir=self.project_dir, + source_code=self.source_code, + eager_analysis=self.eager_analysis, + analysis_level=self.analysis_level, + analysis_json_path=self.analysis_json_path, + use_graalvm_binary=self.use_graalvm_binary, + analysis_backend_path=self.analysis_backend_path, + ) + else: + raise NotImplementedError(f"Support for {analysis_backend} has not been implemented yet.") + + def get_imports(self) -> List[str]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + + def get_variables(self, **kwargs): + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + + def get_service_entry_point_classes(self, **kwargs): + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + + def get_service_entry_point_methods(self, **kwargs): + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + + def get_application_view(self) -> JApplication: + """ + Returns the application view of the Java code. + + Returns: + -------- + JApplication + The application view of the Java code. 
+ """ + if self.source_code: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_application_view() + + def get_symbol_table(self) -> Dict[str, JCompilationUnit]: + """ + Returns the symbol table of the Java code. + + Returns: + -------- + Dict[str, JCompilationUnit] + The application view of the Java code. + """ + return self.analysis_backend.get_symbol_table() + + def get_compilation_units(self) -> List[JCompilationUnit]: + """ + Returns the compilation units of the Java code. + + Returns + ------- + Dict[str, JCompilationUnit] + The compilation units of the Java code. + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_compilation_units() + + def get_class_hierarchy(self) -> DiGraph: + """ + Returns the class hierarchy of the Java code. + + Returns: + -------- + DiGraph + The class hierarchy of the Java code. + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + raise NotImplementedError("Class hierarchy is not implemented yet.") + + def get_call_graph(self) -> DiGraph: + """ + Returns the call graph of the Java code. + + Returns: + -------- + DiGraph + The call graph of the Java code. + """ + return self.analysis_backend.get_call_graph() + + def get_call_graph_json(self) -> str: + """ + serialize callgraph to json + """ + if self.source_code: + raise NotImplementedError("Producing a call graph over a single file is not implemented yet.") + return self.analysis_backend.get_call_graph_json() + + def get_callers(self, target_class_name: str, target_method_declaration: str): + """ + Get all the caller details for a given java method. + + Returns: + -------- + Dict + Caller details in a dictionary. 
+ """ + if self.source_code: + raise NotImplementedError("Generating all callers over a single file is not implemented yet.") + return self.analysis_backend.get_all_callers(target_class_name, target_method_declaration) + + def get_callees(self, source_class_name: str, source_method_declaration: str): + """ + Get all the callee details for a given java method. + + Returns: + -------- + Dict + Callee details in a dictionary. + """ + if self.source_code: + raise NotImplementedError("Generating all callees over a single file is not implemented yet.") + return self.analysis_backend.get_all_callees(source_class_name, source_method_declaration) + + def get_methods(self) -> Dict[str, Dict[str, JCallable]]: + """ + Returns a dictionary of all methods in the Java code with + qualified class name as key and dictionary of methods in that class + as value + + Returns: + -------- + Dict[str, Dict[str, JCallable]]: + A dictionary of dictionaries of all methods in the Java code. + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_all_methods_in_application() + + def get_classes(self) -> Dict[str, JType]: + """ + Returns a dictionary of all classes in the Java code. + + Returns: + -------- + Dict[str, JType] + A dict of all classes in the Java code, with qualified class names as keys + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_all_classes() + + def get_classes_by_criteria(self, inclusions=None, exclusions=None) -> Dict[str, JType]: + """ + Returns a dictionary of all classes in the Java code. + + #TODO: Document the parameters inclusions and exclusions. 
+ + Returns: + -------- + Dict[str, JType] + A dict of all classes in the Java code, with qualified class names as keys + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + + if exclusions is None: + exclusions = [] + if inclusions is None: + inclusions = [] + class_dict: Dict[str, JType] = {} + all_classes = self.get_all_classes() + for application_class in all_classes: + is_selected = False + for inclusion in inclusions: + if inclusion in application_class: + is_selected = True + + for exclusion in exclusions: + if exclusion in application_class: + is_selected = False + if is_selected: + class_dict[application_class] = all_classes[application_class] + return class_dict + + def get_class(self, qualified_class_name) -> JType: + """ + Returns a class given qualified class name. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + JType + A class for the given qualified class name. + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_class(qualified_class_name) + + def get_method(self, qualified_class_name, qualified_method_name) -> JCallable: + """ + Returns a method given qualified method name. + + Parameters: + ----------- + qualified_method_name : str + The qualified name of the method. + + Returns: + -------- + JCallable + A method for the given qualified method name. 
+ """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_method(qualified_class_name, qualified_method_name) + + def get_java_file(self, qualified_class_name) -> str: + """ + Returns a class given qualified class name. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + str + Java file name containing the given qualified class. + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_java_file(qualified_class_name) + + def get_java_compilation_unit(self, file_path: str) -> JCompilationUnit: + """ + Given the path of a Java source file, returns the compilation unit object from the symbol table. + + Parameters + ---------- + file_path : str + Absolute path to Java source file + + Returns + ------- + JCompilationUnit + Compilation unit object for Java source file + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_java_compilation_unit(file_path) + + def get_methods_in_class(self, qualified_class_name) -> Dict[str, JCallable]: + """ + Returns a dictionary of all methods in the given class. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + Dict[str, JCallable] + A dictionary of all methods in the given class. 
+ """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_all_methods_in_class(qualified_class_name) + + def get_constructors(self, qualified_class_name) -> Dict[str, JCallable]: + """ + Returns a dictionary of all constructors of the given class. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + Dict[str, JCallable] + A dictionary of all constructors of the given class. + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_all_constructors(qualified_class_name) + + def get_fields(self, qualified_class_name) -> List[JField]: + """ + Returns a list of all fields of the given class. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + List[JField] + A list of all fields of the given class. + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_all_fields(qualified_class_name) + + def get_nested_classes(self, qualified_class_name) -> List[JType]: + """ + Returns a list of all nested classes for the given class. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + List[JType] + A list of nested classes for the given class. 
+ """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_all_nested_classes(qualified_class_name) + + def get_sub_classes(self, qualified_class_name) -> Dict[str, JType]: + """ + Returns a dictionary of all sub-classes of the given class + Parameters + ---------- + qualified_class_name + + Returns + ------- + Dict[str, JType]: A dictionary of all sub-classes of the given class, and class details + """ + return self.analysis_backend.get_all_sub_classes(qualified_class_name=qualified_class_name) + + def get_extended_classes(self, qualified_class_name) -> List[str]: + """ + Returns a list of all extended classes for the given class. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + List[JType] + A list of extended classes for the given class. + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_extended_classes(qualified_class_name) + + def get_implemented_interfaces(self, qualified_class_name) -> List[str]: + """ + Returns a list of all implemented interfaces for the given class. + + Parameters: + ----------- + qualified_class_name : str + The qualified name of the class. + + Returns: + -------- + List[JType] + A list of implemented interfaces for the given class. 
+ """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_implemented_interfaces(qualified_class_name) + + def get_class_call_graph(self, qualified_class_name: str, method_name: str | None = None) -> (List)[Tuple[JMethodDetail, JMethodDetail]]: + """ + A call graph for a given class and (optionally) a given method. + + Parameters + ---------- + qualified_class_name : str + The qualified name of the class. + method_name : str, optional + The name of the method in the class. + + Returns + ------- + List[Tuple[JMethodDetail, JMethodDetail]] + An edge list of the call graph for the given class and method. + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_class_call_graph(qualified_class_name, method_name) + + def get_entry_point_classes(self) -> Dict[str, JType]: + """ + Returns a dictionary of all entry point classes in the Java code. + + Returns: + -------- + Dict[str, JType] + A dict of all entry point classes in the Java code, with qualified class names as keys + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_all_entry_point_classes() + + def get_entry_point_methods(self) -> Dict[str, Dict[str, JCallable]]: + """ + Returns a dictionary of all entry point methods in the Java code with + qualified class name as key and dictionary of methods in that class + as value + + Returns: + -------- + Dict[str, Dict[str, JCallable]]: + A dictionary of dictionaries of entry point methods in the Java code. 
+ """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_all_entry_point_methods() + + def remove_all_comments(self) -> str: + """ + Remove all comments from the source code. + + Parameters + ---------- + source_code : str + The source code to process. + + Returns + ------- + str + The source code with all comments removed. + """ + + # Remove any prefix comments/content before the package declaration + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.CODEANALYZER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.remove_all_comments(self.source_code) + + def get_methods_with_annotations(self, annotations: List[str]) -> Dict[str, List[Dict]]: + """ + Returns a dictionary of method names and method bodies. + + Parameters: + ----------- + source_class_code : str + String containing code for a java class. + annotations : List[str] + List of annotation strings. + Returns: + -------- + Dict[str,List[Dict]] + Dictionary with annotations as keys and + a list of dictionaries containing method names and bodies, as values. + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.CODEANALYZER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_methods_with_annotations(self.source_code, annotations) + + def get_test_methods(self, source_class_code: str) -> Dict[str, str]: + """ + Returns a dictionary of method names and method bodies. + + Parameters: + ----------- + source_class_code : str + String containing code for a java class. + + Returns: + -------- + Dict[str,str] + Dictionary of method names and method bodies. 
+ """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.CODEANALYZER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_test_methods(self.source_code) + + def get_calling_lines(self, target_method_name: str) -> List[int]: + """ + Returns a list of line numbers in source method where target method is called. + + Parameters: + ----------- + source_method_code : str + source method code + + target_method_code : str + target method code + + Returns: + -------- + List[int] + List of line numbers within in source method code block. + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_calling_lines(self.source_code, target_method_name) + + def get_call_targets(self, declared_methods: dict) -> Set[str]: + """Generate a list of call targets from the method body. + + Uses simple name resolution for finding the call targets. Nothing sophiscticed here. Just a simple search + over the AST. + + Parameters + ---------- + method_body : Node + The method body. + declared_methods : dict + A dictionary of all declared methods in the class. + + Returns + ------- + List[str] + A list of call targets (methods). 
 + """ + if self.analysis_backend in [AnalysisEngine.CODEQL, AnalysisEngine.TREESITTER]: + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + return self.analysis_backend.get_call_targets(self.source_code, declared_methods) diff --git a/cldk/analysis/java/treesitter/__init__.py b/cldk/analysis/java/treesitter/__init__.py new file mode 100644 index 0000000..19d070b --- /dev/null +++ b/cldk/analysis/java/treesitter/__init__.py @@ -0,0 +1,2 @@ +from cldk.analysis.java.treesitter.javasitter import JavaSitter +__all__ = ["JavaSitter"] \ No newline at end of file diff --git a/cldk/analysis/java/treesitter/javasitter.py b/cldk/analysis/java/treesitter/javasitter.py new file mode 100644 index 0000000..55fa3d8 --- /dev/null +++ b/cldk/analysis/java/treesitter/javasitter.py @@ -0,0 +1,359 @@ +from itertools import groupby +from typing import List, Set, Dict +from tree_sitter import Language, Node, Parser, Query +import tree_sitter_java as tsjava +from cldk.models.treesitter import Captures + + +class JavaSitter: + """ + Treesitter for Java use cases. + """ + + def __init__(self) -> None: + self.language: Language = Language(tsjava.language()) + self.parser: Parser = Parser(self.language) + + def method_is_not_in_class(self, method_name: str, class_body: str) -> bool: + """Check if a method is absent from a class. + + Parameters + ---------- + method_name : str + The name of the method to check. + class_body : str + The body of the class to check. + + Returns + ------- + bool + True if the method is not in the class, False otherwise. + """ + methods_in_class = self.frame_query_and_capture_output("(method_declaration name: (identifier) @name)", class_body) + + return method_name not in {method.node.text.decode() for method in methods_in_class} + + def frame_query_and_capture_output(self, query: str, code_to_process: str) -> Captures: + """Frame a query for the tree-sitter parser. + + Parameters + ---------- + query : str + The query to frame. 
+ code_to_process : str + The code to process. + """ + framed_query: Query = self.language.query(query) + tree = self.parser.parse(bytes(code_to_process, "utf-8")) + return Captures(framed_query.captures(tree.root_node)) + + def get_method_name_from_declaration(self, method_name_string: str) -> str: + """Get the method name from the method signature.""" + captures: Captures = self.frame_query_and_capture_output("(method_declaration name: (identifier) @method_name)", method_name_string) + + return captures[0].node.text.decode() + + def get_method_name_from_invocation(self, method_invocation: str) -> str: + """ + Using the tree-sitter query, extract the method name from the method invocation. + """ + + captures: Captures = self.frame_query_and_capture_output("(method_invocation object: (identifier) @class_name name: (identifier) @method_name)", method_invocation) + return captures[0].node.text.decode() + + def safe_ascend(self, node: Node, ascend_count: int) -> Node: + """Safely ascend the tree. If the node does not exist or if it has no parent, raise an error. + + Parameters + ---------- + node : Node + The node to ascend from. + ascend_count : int + The number of times to ascend the tree. + + Returns + ------- + Node + The node at the specified level of the tree. + + Raises + ------ + ValueError + If the node has no parent. + """ + if node is None: + raise ValueError("Node does not exist.") + if node.parent is None: + raise ValueError("Node has no parent.") + if ascend_count == 0: + return node + else: + return self.safe_ascend(node.parent, ascend_count - 1) + + def get_call_targets(self, method_body: str, declared_methods: dict) -> Set[str]: + """Generate a list of call targets from the method body. + + Uses simple name resolution for finding the call targets. Nothing sophiscticed here. Just a simple search + over the AST. + + Parameters + ---------- + method_body : Node + The method body. 
+ declared_methods : dict + A dictionary of all declared methods in the class. + + Returns + ------- + List[str] + A list of call targets (methods). + """ + + select_test_method_query = "(method_invocation name: (identifier) @method)" + captures: Captures = self.frame_query_and_capture_output(select_test_method_query, method_body) + + call_targets = set( + map( + # x is a capture, x.node is the node, x.node.text is the text of the node (in this case, the method + # name) + lambda x: x.node.text.decode(), + filter( # Filter out the calls to methods that are not declared in the class + lambda capture: capture.node.text.decode() in declared_methods, + captures, + ), + ) + ) + return call_targets + + def get_calling_lines(self, source_method_code: str, target_method_name: str) -> List[int]: + """ + Returns a list of line numbers in source method where target method is called. + + Parameters: + ----------- + source_method_code : str + source method code + + target_method_code : str + target method code + + Returns: + -------- + List[int] + List of line numbers within in source method code block. + """ + query = "(method_invocation name: (identifier) @method_name)" + # if target_method_name is a method signature, get the method name + # if it is not a signature, we will just keep the passed method name + try: + target_method_name = self.get_method_name_from_declaration(target_method_name) + except Exception: + pass + + captures: Captures = self.frame_query_and_capture_output(query, source_method_code) + # Find the line numbers where target method calls happen in source method + target_call_lines = [] + for c in captures: + method_name = c.node.text.decode() + if method_name == target_method_name: + target_call_lines.append(c.node.start_point[0]) + return target_call_lines + + def get_test_methods(self, source_class_code: str) -> Dict[str, str]: + """ + Returns a dictionary of method names and method bodies. 
+
+        Parameters:
+        -----------
+        source_class_code : str
+            String containing code for a java class.
+
+        Returns:
+        --------
+        Dict[str,str]
+            Dictionary of method names and method bodies.
+        """
+        # Captures only the no-argument (marker) annotation form, e.g. `@Test`;
+        # annotations with arguments use a different tree-sitter node type and are not matched.
+        query = """
+            (method_declaration
+                (modifiers
+                    (marker_annotation
+                        name: (identifier) @annotation)
+                )
+            )
+            """
+
+        captures: Captures = self.frame_query_and_capture_output(query, source_class_code)
+        test_method_dict = {}
+        for capture in captures:
+            if capture.name == "annotation":
+                if capture.node.text.decode() == "Test":
+                    # Ascend 3 levels: identifier -> marker_annotation -> modifiers -> method_declaration.
+                    # NOTE(review): assumes this exact tree shape — confirm against the Java grammar version in use.
+                    method_node = self.safe_ascend(capture.node, 3)
+                    # NOTE(review): assumes children[2] of a method_declaration is the method name — verify.
+                    method_name = method_node.children[2].text.decode()
+                    test_method_dict[method_name] = method_node.text.decode()
+        return test_method_dict
+
+    def get_methods_with_annotations(self, source_class_code: str, annotations: List[str]) -> Dict[str, List[Dict]]:
+        """
+        Returns a dictionary mapping each requested annotation to its annotated methods.
+
+        Parameters:
+        -----------
+        source_class_code : str
+            String containing code for a java class.
+        annotations : List[str]
+            List of annotation strings.
+        Returns:
+        --------
+        Dict[str,List[Dict]]
+            Dictionary with annotations as keys and
+            a list of dictionaries containing method names and bodies, as values.
+        """
+        # Same marker-annotation-only query as get_test_methods above.
+        query = """
+            (method_declaration
+                (modifiers
+                    (marker_annotation
+                        name: (identifier) @annotation)
+                )
+            )
+            """
+        captures: Captures = self.frame_query_and_capture_output(query, source_class_code)
+        annotation_method_dict = {}
+        for capture in captures:
+            if capture.name == "annotation":
+                annotation = capture.node.text.decode()
+                if annotation in annotations:
+                    method = {}
+                    # Ascend to the enclosing method_declaration (see note in get_test_methods).
+                    method_node = self.safe_ascend(capture.node, 3)
+                    method["body"] = method_node.text.decode()
+                    method["method_name"] = method_node.children[2].text.decode()
+                    if annotation in annotation_method_dict.keys():
+                        annotation_method_dict[annotation].append(method)
+                    else:
+                        annotation_method_dict[annotation] = [method]
+        return annotation_method_dict
+
+    def get_all_type_invocations(self, source_code: str) -> Set[str]:
+        """
+        Given the source code, get all the type invocations in the source code.
+
+        Parameters
+        ----------
+        source_code : str
+            The source code to process.
+
+        Returns
+        -------
+        Set[str]
+            A set of all the type invocations in the source code.
+ """ + type_references: Captures = self.frame_query_and_capture_output("(type_identifier) @type_id", source_code) + return {type_id.node.text.decode() for type_id in type_references} + + def get_lexical_tokens(self, code: str, filter_by_node_type: List[str] | None = None) -> List[str]: + """ + Get the lexical tokens given the code + + Parameters + ---------- + filter_by_node_type: If needed, filter the type of the node + code: Java code + + Returns + ------- + List[str] + List of lexical tokens + + """ + tree = self.parser.parse(bytes(code, "utf-8")) + root_node = tree.root_node + lexical_tokens = [] + + def collect_leaf_token_values(node): + if len(node.children) == 0: + if filter_by_node_type is not None: + if node.type in filter_by_node_type: + lexical_tokens.append(code[node.start_byte : node.end_byte]) + else: + lexical_tokens.append(code[node.start_byte : node.end_byte]) + else: + for child in node.children: + collect_leaf_token_values(child) + + collect_leaf_token_values(root_node) + return lexical_tokens + + def remove_all_comments(self, source_code: str) -> str: + """ + Remove all comments from the source code. + + Parameters + ---------- + source_code : str + The source code to process. + + Returns + ------- + str + The source code with all comments removed. + """ + + # Remove any prefix comments/content before the package declaration + lines_of_code = source_code.split("\n") + for i, line in enumerate(lines_of_code): + if line.strip().startswith("package"): + break + + source_code = "\n".join(lines_of_code[i:]) + + pruned_source_code = self.make_pruned_code_prettier(source_code) + + # Remove all comment lines: the comment lines start with / (for // and /*) or * (for multiline comments). 
+        # Comments are located in the ORIGINAL source but removed from the prettified copy below.
+        comment_blocks: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)", code_to_process=source_code)
+
+        comment_lines: Captures = self.frame_query_and_capture_output(query="((line_comment) @comment_line)", code_to_process=source_code)
+
+        # NOTE(review): str.replace removes EVERY occurrence of the comment text — identical
+        # text embedded in a string literal would be removed too, and comment text altered by
+        # the prettifying pass (trailing-whitespace strip) will not match. Confirm acceptable.
+        for capture in comment_blocks:
+            pruned_source_code = pruned_source_code.replace(capture.node.text.decode(), "")
+
+        for capture in comment_lines:
+            pruned_source_code = pruned_source_code.replace(capture.node.text.decode(), "")
+
+        return self.make_pruned_code_prettier(pruned_source_code)
+
+    def make_pruned_code_prettier(self, pruned_code: str) -> str:
+        """Make the pruned code prettier.
+
+        Parameters
+        ----------
+        pruned_code : str
+            The pruned code.
+
+        Returns
+        -------
+        str
+            The prettified pruned code.
+        """
+        # First remove remaining block comments
+        block_comments: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)", code_to_process=pruned_code)
+
+        for capture in block_comments:
+            pruned_code = pruned_code.replace(capture.node.text.decode(), "")
+
+        # Split the source code into lines and remove trailing whitespaces. rstrip() removes the trailing whitespaces.
+        new_source_code_as_list = list(map(lambda x: x.rstrip(), pruned_code.split("\n")))
+
+        # Remove all comment lines. In java the comment lines start with / (for // and /*) or * (for multiline
+        # comments).
+        new_source_code_as_list = [line for line in new_source_code_as_list if not line.lstrip().startswith(("/", "*"))]
+
+        # Remove multiple contiguous empty lines. This is done using the groupby function from itertools.
+        # groupby returns a list of tuples where the first element is the key and the second is an iterator over the
+        # group. We only need the key, so we take the first element of each tuple. The iterator is essentially a
+        # generator that contains the elements of the group. We don't need it, so we discard it. The key is the line
+        # itself, so we can use it to remove contiguous empty lines.
+        # NOTE(review): groupby collapses ANY run of identical adjacent lines (e.g. repeated "}"),
+        # not only blank ones — confirm that is intended.
+        new_source_code_as_list = [key for key, _ in groupby(new_source_code_as_list)]
+
+        # Join the lines back together
+        prettified_pruned_code = "\n".join(new_source_code_as_list)
+
+        return prettified_pruned_code.strip()
diff --git a/cldk/analysis/javascript/__init__.py b/cldk/analysis/javascript/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/cldk/analysis/javascript/treesitter/__init__.py b/cldk/analysis/javascript/treesitter/__init__.py
new file mode 100644
index 0000000..365fb7f
--- /dev/null
+++ b/cldk/analysis/javascript/treesitter/__init__.py
@@ -0,0 +1,3 @@
+from cldk.analysis.javascript.treesitter.javascript_sitter import JavascriptSitter
+
+__all__ = ["JavascriptSitter"]
diff --git a/cldk/analysis/javascript/treesitter/javascript_sitter.py b/cldk/analysis/javascript/treesitter/javascript_sitter.py
new file mode 100644
index 0000000..65d0671
--- /dev/null
+++ b/cldk/analysis/javascript/treesitter/javascript_sitter.py
@@ -0,0 +1,457 @@
+from typing import List, Optional
+from tree_sitter import Language, Parser, Query, Node
+import tree_sitter_javascript as tsjavascript
+
+from cldk.models.javascript.models import JsCallable, JsClass, JsParameter, JsProgram
+from cldk.models.treesitter import Captures
+
+
+class JavascriptSitter:
+    """
+    Tree sitter for Javascript use cases.
+    """
+
+    def __init__(self) -> None:
+        # One Language/Parser pair per instance; reused by every query below.
+        self.language: Language = Language(tsjavascript.language())
+        self.parser: Parser = Parser(self.language)
+
+    def get_all_functions(self, code: str) -> List[JsCallable]:
+        """
+        Get all the functions and methods in the provided code.
+
+        Parameters
+        ----------
+        code : str
+            The code you want to analyse.
+
+        Returns
+        -------
+        List[JsCallable]
+            All the function and method details within the provided code.
+ """ + + callables: List[JsCallable] = [] + classes: List[JsClass] = self.get_classes(code) + for clazz in classes: + callables.extend(clazz.methods) + + callables.extend(self.__get_top_level_functions(code)) + callables.extend(self.__get_top_level_generators(code)) + callables.extend(self.__get_top_level_arrow_functions(code)) + callables.extend(self.__get_top_level_function_expressions(code)) + callables.extend(self.__get_top_level_generator_expressions(code)) + + return callables + + def get_classes(self, code: str) -> List[JsClass]: + """ + Get the classes in the provided code. + + Parameters + ---------- + code : str + The code you want to analyse. + + Returns + ------- + List[JsClass] + All class details within the provided code. + """ + + query = """(class_declaration) @class""" + + return [self.__get_class_details(capture.node) for capture in self.__frame_query_and_capture_output(query, code)] + + def get_program_details(self, source_file: str) -> JsProgram: + """ + Get the details of the provided code file. + + Parameters + ---------- + source_file : str + The code we want to analyse. + + Returns + ------- + The details of the provided file. + """ + + return JsProgram( + classes=self.get_classes(source_file), + callables=self.get_all_functions(source_file), + ) + + def __frame_query_and_capture_output(self, query: str, code_to_process: str) -> Captures: + """Frame a query for the tree-sitter parser. + + Parameters + ---------- + query : str + The query to frame. + code_to_process : str + The code to process. + + Returns + ------- + Captures + The list of tree-sitter captures. + """ + + framed_query: Query = self.language.query(query) + tree = self.parser.parse(bytes(code_to_process, "utf-8")) + return Captures(framed_query.captures(tree.root_node)) + + def __get_class_details(self, node: Node) -> JsClass: + """ + Get the classe details for a provided tree-sitter class node. 
+
+        Parameters
+        ----------
+        node : Node
+            The tree-sitter class node whose details we want.
+
+        Returns
+        -------
+        JsClass
+            The class details of the provided node.
+        """
+
+        parent_node: Node = self.__get_class_parent_node(node)
+
+        return JsClass(
+            name=node.child_by_field_name("name").text.decode(),
+            methods=self.__get_methods(node),
+            # tree-sitter rows are 0-based.
+            start_line=node.start_point[0],
+            end_line=node.end_point[0],
+            # TODO: needs more refinement since you can have more than an identifier
+            parent=parent_node.named_children[0].text.decode() if parent_node else None,
+        )
+
+    def __get_methods(self, class_node: Node) -> List[JsCallable]:
+        """
+        Get the methods for a provided tree-sitter class node.
+
+        Parameters
+        ----------
+        class_node : Node
+            The tree-sitter class node whose methods we want.
+
+        Returns
+        -------
+        List[JsCallable]
+            The method details of the provided class node.
+        """
+
+        class_body_node = class_node.child_by_field_name("body")
+
+        return [self.__extract_function_details(child) for child in class_body_node.children if child.type == "method_definition"]
+
+    def __get_top_level_functions(self, code: str) -> List[JsCallable]:
+        """
+        Get the exportable functions from the provided code.
+        There is no guarantee that the functions are exported.
+
+        Parameters
+        ----------
+        code : str
+            The code you want to analyse.
+
+        Returns
+        -------
+        List[JsCallable]
+            The function details within the provided code.
+        """
+
+        # Both plain declarations and `export function ...` forms.
+        query = """
+            (program [
+                (function_declaration) @function
+                (export_statement (function_declaration) @function.export)
+            ])
+            """
+        captures: Captures = self.__frame_query_and_capture_output(query, code)
+
+        return [self.__extract_function_details(capture.node) for capture in captures]
+
+    def __get_top_level_generators(self, code: str) -> List[JsCallable]:
+        """
+        Get the exportable generator functions from the provided code.
+        There is no guarantee that the functions are exported.
+
+        Parameters
+        ----------
+        code : str
+            The code you want to analyse.
+
+        Returns
+        -------
+        List[JsCallable]
+            The generator function details within the provided code.
+        """
+
+        query = """
+            (program [
+                (generator_function_declaration) @generator
+                (export_statement (generator_function_declaration) @generator.export)
+            ])
+            """
+        captures: Captures = self.__frame_query_and_capture_output(query, code)
+
+        return [self.__extract_function_details(capture.node) for capture in captures]
+
+    def __get_top_level_arrow_functions(self, code: str) -> List[JsCallable]:
+        """
+        Get the exportable arrow functions from the provided code.
+        There is no guarantee that the functions are exported.
+
+        Parameters
+        ----------
+        code : str
+            The code you want to analyse.
+
+        Returns
+        -------
+        List[JsCallable]
+            The arrow function details within the provided code.
+        """
+
+        # get arrow functions that can be called from an external file.
+        query = """
+            (program [
+                (expression_statement (assignment_expression (arrow_function) @arrow.assignment))
+                (expression_statement (sequence_expression (assignment_expression (arrow_function) @arrow.assignment)))
+                (lexical_declaration (variable_declarator (arrow_function) @arrow.variable))
+                (export_statement (arrow_function) @arrow.export)
+                (export_statement (lexical_declaration (variable_declarator (arrow_function) @arrow.export.variable)))
+            ])
+            """
+
+        captures: Captures = self.__frame_query_and_capture_output(query, code)
+        callables: List[JsCallable] = [self.__extract_arrow_function_details(capture.node, capture.name) for capture in captures]
+
+        # Anonymous arrow functions (no derivable name) are dropped.
+        return [callable for callable in callables if callable.name]
+
+    def __get_top_level_function_expressions(self, code: str) -> List[JsCallable]:
+        """
+        Get the exportable function expressions from the provided code.
+        There is no guarantee that the functions are exported.
+
+        Parameters
+        ----------
+        code : str
+            The code you want to analyse.
+
+        Returns
+        -------
+        List[JsCallable]
+            The function expression details within the provided code.
+        """
+
+        # get function expressions that can be called from an external file.
+        # TODO: function node changed to function_expression in newer tree-sitter versions
+        query = """
+            (program [
+                (expression_statement (assignment_expression (function) @function.assignment))
+                (expression_statement (sequence_expression (assignment_expression (function) @function.assignment)))
+                (lexical_declaration (variable_declarator (function) @function.variable))
+                (export_statement (function) @function.export)
+                (export_statement (lexical_declaration (variable_declarator (function) @function.export.variable)))
+            ])
+            """
+
+        captures: Captures = self.__frame_query_and_capture_output(query, code)
+
+        return [self.__extract_function_expression_details(capture.node, capture.name) for capture in captures]
+
+    def __get_top_level_generator_expressions(self, code: str) -> List[JsCallable]:
+        """
+        Get the exportable generator expressions from the provided code.
+        There is no guarantee that the functions are exported.
+
+        Parameters
+        ----------
+        code : str
+            The code you want to analyse.
+
+        Returns
+        -------
+        List[JsCallable]
+            The generator expression details within the provided code.
+        """
+
+        # get generator expressions that can be called from an external file.
+ query = """ + (program [ + (expression_statement (assignment_expression (generator_function) @function.assignment)) + (expression_statement (sequence_expression (assignment_expression (generator_function) @function.assignment))) + (lexical_declaration (variable_declarator (generator_function) @function.variable)) + (export_statement (generator_function) @function.export) + (export_statement (lexical_declaration (variable_declarator (generator_function) @function.export.variable))) + ]) + """ + + captures: Captures = self.__frame_query_and_capture_output(query, code) + + return [self.__extract_function_expression_details(capture.node, capture.name) for capture in captures] + + def __extract_function_details(self, function_node: Node) -> JsCallable: + """ + Extract the details from a function/method tree-sitter node. + + Parameters + ---------- + node : Node + The tree-sitter function node whose details we want. + capture_name : str + The identifier used to extract the node. + + Returns + ------- + JsCallable + The function details. + """ + + name: str = function_node.child_by_field_name("name").text.decode() + parameters_node: Node = function_node.child_by_field_name("parameters") + + return JsCallable( + name=name, + code=function_node.text.decode(), + paremeters=self.__extract_parameters_details(parameters_node), + signature=name + parameters_node.text.decode(), + start_line=function_node.start_point[0], + end_line=function_node.end_point[0], + is_constructor=function_node.type == "method_definition" and name == "constructor", + ) + + def __extract_arrow_function_details(self, node: Node, capture_name: str) -> JsCallable: + """ + Extract the details from an arrow function tree-sitter node. + + Parameters + ---------- + node : Node + The tree-sitter arrow function node whose details we want. + capture_name : str + The identifier used to extract the node. + + Returns + ------- + JsCallable + The function details. 
+ """ + + name: str = None + if capture_name == "arrow.assignment": + left_node = node.parent.child_by_field_name("left") + if left_node.type == "identifier": + name = left_node.text.decode() + elif capture_name == "arrow.export": + name = "default" + else: + name_node = node.parent.child_by_field_name("name") + name = name_node.text.decode() + + parameter_node: Node = node.child_by_field_name("parameter") + parameters_node: Node = node.child_by_field_name("parameters") + + # TODO: not sure about this + parameters_text = f"({parameter_node.text.decode()})" if parameter_node else parameters_node.text.decode() + signature: str = (name if name else "") + parameters_text + + return JsCallable( + name=name, + code=node.text.decode(), + paremeters=[JsParameter(name=parameter_node.text.decode())] if parameter_node else self.__extract_parameters_details(parameters_node), + signature=signature, + start_line=node.start_point[0], + end_line=node.end_point[0], + ) + + def __extract_function_expression_details(self, node: Node, capture_name: str) -> JsCallable: + """ + Extract the details from a function expression tree-sitter node. + + Parameters + ---------- + node : Node + The tree-sitter function node whose details we want. + capture_name : str + The identifier used to extract the node. + + Returns + ------- + JsCallable + The function details. 
+ """ + + name: str = None + if capture_name == "function.assignment": + left_node = node.parent.child_by_field_name("left") + if left_node.type == "identifier": + name = left_node.text.decode() + elif capture_name == "function.export": + name = "default" + else: + name_node = node.parent.child_by_field_name("name") + name = name_node.text.decode() + + parameters_node: Node = node.child_by_field_name("parameters") + + # TODO: not sure about this + signature: str = (name if name else "") + parameters_node.text.decode() + + return JsCallable( + name=name, + code=node.text.decode(), + paremeters=self.__extract_parameters_details(parameters_node), + signature=signature, + start_line=node.start_point[0], + end_line=node.end_point[0], + ) + + def __extract_parameters_details(self, parameters_node: Node) -> List[JsParameter]: + """ + Extract the parameter details from a given tree-sitter parameters node. + + Parameters + ---------- + parameters_node : Node + The tree-sitter parameters node whose details we want. + + Returns + ------- + List[JsParameter] + The list of parameter details. + """ + + if not parameters_node or not parameters_node.children: + return [] + + parameters: List[JsParameter] = [] + for child in parameters_node.children: + # TODO incomplete, needs a recursive way of finding the parameters + if child.type in ["identifier", "undefined"]: + parameters.append(JsParameter(name=child.text.decode())) + + return parameters + + def __get_class_parent_node(self, class_node: Node) -> Optional[Node]: + """ + Extracts the tree-sitter heritage node, if it exists from a class node. + + Parameters + ---------- + class_node : Node + The tree-sitter class node we want to process. + + Returns + ------- + Optional[Node] + The tree-sitter node that has the heritage data for the provided class node. 
+ """ + + for child in class_node.children: + if child.type == "class_heritage": + return child + + return None diff --git a/cldk/analysis/program_dependence_graph.py b/cldk/analysis/program_dependence_graph.py new file mode 100644 index 0000000..855e9f5 --- /dev/null +++ b/cldk/analysis/program_dependence_graph.py @@ -0,0 +1,6 @@ +from abc import ABC, abstractmethod + + +class ProgramDependenceGraph(ABC): + def __init__(self) -> None: + super().__init__() diff --git a/cldk/analysis/python/__init__.py b/cldk/analysis/python/__init__.py new file mode 100644 index 0000000..32097e0 --- /dev/null +++ b/cldk/analysis/python/__init__.py @@ -0,0 +1,3 @@ +from .python import PythonAnalysis + +__all__ = ["PythonAnalysis"] \ No newline at end of file diff --git a/cldk/analysis/python/python.py b/cldk/analysis/python/python.py new file mode 100644 index 0000000..a38f2b5 --- /dev/null +++ b/cldk/analysis/python/python.py @@ -0,0 +1,122 @@ +from abc import ABC +from pathlib import Path +from typing import Dict, List +from pandas import DataFrame + +from cldk.analysis import SymbolTable +from cldk.analysis.python.treesitter import PythonSitter +from cldk.models.python.models import PyMethod, PyImport, PyModule, PyClass + + +class PythonAnalysis(SymbolTable): + def __init__( + self, + analysis_backend: str, + eager_analysis: bool, + project_dir: str | Path | None, + source_code: str | None, + analysis_backend_path: str | None, + analysis_json_path: str | Path | None, + use_graalvm_binary: bool = None, + ) -> None: + self.project_dir = project_dir + self.source_code = source_code + self.analysis_json_path = analysis_json_path + self.analysis_backend_path = analysis_backend_path + self.eager_analysis = eager_analysis + self.use_graalvm_binary = use_graalvm_binary + + # Initialize the analysis analysis_backend + if analysis_backend.lower() == "codeql": + raise NotImplementedError(f"Support for {analysis_backend} has not been implemented yet.") + elif analysis_backend.lower() == 
"codeanalyzer": + raise NotImplementedError(f"Support for {analysis_backend} has not been implemented yet.") + elif analysis_backend.lower() == "treesitter": + self.analysis_backend: PythonSitter = PythonSitter() + else: + raise NotImplementedError(f"Support for {analysis_backend} has not been implemented yet.") + + def get_methods(self) -> List[PyMethod]: + """ + Given an application or a source code, get all the methods + """ + return self.analysis_backend.get_all_methods(self.source_code) + + def get_functions(self) -> List[PyMethod]: + """ + Given an application or a source code, get all the methods + """ + return self.analysis_backend.get_all_functions(self.source_code) + + def get_modules(self) -> List[PyModule]: + """ + Given the project directory, get all the modules + """ + return self.analysis_backend.get_all_modules(self.project_dir) + + def get_method_details(self, method_signature: str) -> PyMethod: + """ + Given the code body and the method signature, returns the method details related to that method + Parameters + ---------- + method_signature: method signature + + Returns + ------- + PyMethod: Returns the method details related to that method + """ + return self.analysis_backend.get_method_details(self.source_code, method_signature) + + def get_imports(self) -> List[PyImport]: + """ + Given an application or a source code, get all the imports + """ + return self.analysis_backend.get_all_imports_details(self.source_code) + + def get_variables(self, **kwargs): + """ + Given an application or a source code, get all the variables + """ + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + + def get_classes(self) -> List[PyClass]: + """ + Given an application or a source code, get all the classes + """ + return self.analysis_backend.get_all_classes(self.source_code) + + def get_classes_by_criteria(self, **kwargs): + """ + Given an application or a source code, get all the classes given the inclusion and exclution 
criteria + """ + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + + def get_sub_classes(self, **kwargs): + """ + Given an application or a source code, get all the sub-classes + """ + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + + def get_nested_classes(self, **kwargs): + """ + Given an application or a source code, get all the nested classes + """ + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + + def get_constructors(self, **kwargs): + """ + Given an application or a source code, get all the constructors + """ + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + + def get_methods_in_class(self, **kwargs): + """ + Given an application or a source code, get all the methods within the given class + """ + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") + + def get_fields(self, **kwargs): + """ + Given an application or a source code, get all the fields + """ + raise NotImplementedError(f"Support for this functionality has not been implemented yet.") diff --git a/cldk/analysis/python/treesitter/__init__.py b/cldk/analysis/python/treesitter/__init__.py new file mode 100644 index 0000000..44b9d0c --- /dev/null +++ b/cldk/analysis/python/treesitter/__init__.py @@ -0,0 +1,3 @@ +from cldk.analysis.python.treesitter.python_sitter import PythonSitter + +__all__ = ["PythonSitter"] diff --git a/cldk/analysis/python/treesitter/python_sitter.py b/cldk/analysis/python/treesitter/python_sitter.py new file mode 100644 index 0000000..3460987 --- /dev/null +++ b/cldk/analysis/python/treesitter/python_sitter.py @@ -0,0 +1,350 @@ +import glob +import os +from pathlib import Path +from typing import List + +from sphinx.domains.python import PyField +from tree_sitter import Language, Parser, Query, Node +import tree_sitter_python as tspython + +from 
cldk.models.python.models import PyMethod, PyClass, PyArg, PyImport, PyModule, PyCallSite +from cldk.models.treesitter import Captures + + +class PythonSitter: + """ + Tree sitter for Python use cases. + """ + + def __init__(self) -> None: + self.language: Language = Language(tspython.language()) + self.parser: Parser = Parser(self.language) + + def get_all_methods(self, module: str) -> List[PyMethod]: + """ + Get all the methods in the specific module. + Parameters + ---------- + module: code body of the module + + Returns + ------- + List[PyMethod]: returns all the method details within the module + + """ + methods: List[PyMethod] = [] + method_signatures: dict[str, List[int]] = {} + # Get the methods declared under class + all_class_details: List[PyClass] = self.get_all_classes(module=module) + for class_details in all_class_details: + for method in class_details.methods: + method_signatures[method.full_signature] = [method.start_line, method.end_line] + methods.extend(class_details.methods) + return methods + + def get_all_functions(self, module: str) -> List[PyMethod]: + """ + Get all the methods in the specific module. 
+ Parameters + ---------- + module: code body of the module + + Returns + ------- + List[PyMethod]: returns all the method details within the module + + """ + methods: List[PyMethod] = [] + functions: List[PyMethod] = [] + method_signatures: dict[str, List[int]] = {} + # Get the methods declared under class + all_class_details: List[PyClass] = self.get_all_classes(module=module) + # Filter all method nodes + all_method_nodes: Captures = self.__get_method_nodes(module=module) + for class_details in all_class_details: + for method in class_details.methods: + method_signatures[method.full_signature] = [method.start_line, method.end_line] + methods.extend(class_details.methods) + for method_node in all_method_nodes: + method_details = self.__get_function_details(node=method_node.node) + if method_details.full_signature not in method_signatures: + functions.append(method_details) + elif method_signatures[method_details.full_signature][0] != method_details.start_line \ + and method_signatures[method_details.full_signature][1] != method_details.end_line: + functions.append(method_details) + return functions + + def get_method_details(self, module: str, method_signature: str) -> PyMethod: + """ + Given the code body and the method signature, returns the method details related to that method + Parameters + ---------- + module: code body + method_signature: method signature + + Returns + ------- + PyMethod: Returns the method details related to that method + """ + all_method_details = self.get_all_methods(module=module) + for method in all_method_details: + if method.full_signature == method_signature: + return method + return None + + def get_all_imports(self, module: str) -> List[str]: + """ + Given the code body, returns the imports in that module + Parameters + ---------- + module: code body + + Returns + ------- + List[str]: List of imports + """ + import_list = [] + captures_from_import: Captures = self.__frame_query_and_capture_output("(((import_from_statement) 
@imports))", + module) + captures_import: Captures = self.__frame_query_and_capture_output("(((import_statement) @imports))", module) + for capture in captures_import: + import_list.append(capture.node.text.decode()) + for capture in captures_from_import: + import_list.append(capture.node.text.decode()) + return import_list + + def get_module_details(self, module: str) -> PyModule: + return PyModule(functions=self.get_all_functions(module=module), + classes=self.get_all_classes(module=module), + imports=self.get_all_imports_details(module=module), + qualified_name='') + + def get_all_imports_details(self, module: str) -> List[PyImport]: + """ + Given the code body, returns the imports in that module + Parameters + ---------- + module: code body + + Returns + ------- + List[PyImport]: List of imports + """ + import_list = [] + captures_from_import: Captures = self.__frame_query_and_capture_output("(((import_from_statement) @imports))", + module) + captures_import: Captures = self.__frame_query_and_capture_output("(((import_statement) @imports))", module) + for capture in captures_import: + imports = [] + for import_name in capture.node.children: + if import_name.type == "dotted_name": + imports.append(import_name.text.decode()) + if import_name.type == "wildcard_import": + imports.append("ALL") + import_list.append(PyImport(from_statement="", imports=imports)) + for capture in captures_from_import: + imports = [] + for i in range(2, capture.node.child_count): + if capture.node.children[i].type == "dotted_name": + imports.append(capture.node.children[i].text.decode()) + if capture.node.children[i].type == "wildcard_import": + imports.append("ALL") + import_list.append(PyImport(from_statement=capture.node.children[1].text.decode(), imports=imports)) + return import_list + + def get_all_fields(self, module: str) -> List[PyField]: + pass + + def get_all_classes(self, module: str) -> List[PyClass]: + """ + Given the code body of the module, returns details of all classes 
in it + Parameters + ---------- + module: code body + + Returns + ------- + List[PyClass]: returns details of all classes in it + """ + classes: List[PyClass] = [] + all_class_details: Captures = self.__frame_query_and_capture_output("(((class_definition) @class_name))", + module) + for class_name in all_class_details: + code_body = class_name.node.text.decode() + class_full_signature = "" # TODO: what to fill here + klass_name = class_name.node.children[1].text.decode() + methods: List[PyMethod] = [] + super_classes: List[str] = [] + is_test_class = False + for child in class_name.node.children: + if child.type == "argument_list": + for arg in child.children: + if 'unittest' in arg.text.decode() or 'TestCase' in arg.text.decode(): + is_test_class = True + super_classes.append(arg.text.decode()) + if child.type == "block": + for block in child.children: + if block.type == "function_definition": + method = self.__get_function_details(node=block, klass_name=klass_name) + methods.append(method) + if block.type == "decorated_definition": + for decorated_def in block.children: + if decorated_def.type == "function_definition": + method = self.__get_function_details(node=decorated_def, klass_name=klass_name) + methods.append(method) + classes.append(PyClass(code_body=code_body, + full_signature=class_full_signature, + methods=methods, + super_classes=super_classes, + is_test_class=is_test_class)) + return classes + + def get_all_modules(self, application_dir: Path) -> List[PyModule]: + """ + Given an application directory, returns a list of modules + Parameters + ---------- + application_dir (Path): Location of the application directory + + Returns + ------- + List[PyModule]: returns a list of modules + """ + modules: List[PyModule] = [] + path_list = [os.path.join(dirpath, filename) for dirpath, _, filenames in os.walk(application_dir) for filename in filenames + if filename.endswith('.py')] + for p in path_list: + modules.append(self.__get_module(path=p)) + return 
modules + + def __get_module(self, path: Path): + module_qualified_path = os.path.join(path) + module_qualified_name = str(module_qualified_path).replace(os.sep, '.') + with open(module_qualified_path, 'r') as file: + py_module = self.get_module_details(module=file.read()) + qualified_name: str + methods: List[PyMethod] + functions: List[PyMethod] + classes: List[PyClass] + imports: List[PyImport] + return PyModule(qualified_name=module_qualified_name, + imports=py_module.imports, + functions=py_module.functions, + classes=py_module.classes) + return None + + @staticmethod + def __get_call_site_details(call_node: Node) -> PyCallSite: + """ + Get details about the call site information given a call node + Parameters + ---------- + call_node + + Returns + ------- + PyCallSite: Call site information + """ + start_line = call_node.start_point[0] + start_column = call_node.start_point[1] + end_line = call_node.end_point[0] + end_column = call_node.end_point[1] + try: + method_name = call_node.children[0].children[2].text.decode() + declaring_object = call_node.children[0].children[0].text.decode() + arguments: List[str] = [] + for arg in call_node.children[1].children: + if arg.type not in ['(', ')', ',']: + arguments.append(arg.text.decode()) + except Exception: + method_name = '' + declaring_object = '' + arguments = [] + return PyCallSite(method_name=method_name, + declaring_object=declaring_object, + arguments=arguments, + start_line=start_line, + start_column=start_column, + end_line=end_line, + end_column=end_column) + + def __get_function_details(self, node: Node, klass_name: str = "") -> PyMethod: + code_body: str = node.text.decode() + start_line: int = node.start_point[0] + end_line: int = node.end_point[0] + method_name = "" + modifier: str = "" + formal_params: List[PyArg] = [] + return_type: str = "" + class_signature: str = klass_name + is_constructor = False + is_static = False + call_sites: List[PyCallSite] = [] + call_nodes: Captures = 
self.__frame_query_and_capture_output("(((call) @call_name))", node.text.decode()) + for call_node in call_nodes: + call_sites.append(self.__get_call_site_details(call_node.node)) + for function_detail in node.children: + try: + annotation_node = self.__safe_ascend(function_detail, 2) + except Exception: + annotation_node = None + if annotation_node is not None: + if annotation_node.type == "decorated_definition": + for child in annotation_node.children: + if child.type == "decorator": + if "staticmethod" in child.text.decode(): + is_static = True + if function_detail.type == "identifier": + method_name = function_detail.text.decode() + if "__init__" in method_name: + is_constructor = True + if method_name.startswith("__"): + modifier = "private" + elif method_name.startswith("_"): + modifier = "protected" + else: + modifier = "public" + if function_detail.type == "return_type": + return_type = function_detail.text.decode() + if function_detail.type == "parameters": + parameters = function_detail.text.decode() + + for parameter in function_detail.children: + formal_param: PyArg = None + if parameter.type == "identifier": + formal_param = PyArg(arg_name=parameter.text.decode(), arg_type="") + elif parameter.type == "typed_parameter": + formal_param = PyArg(arg_name=parameter.children[0].text.decode(), + arg_type=parameter.children[2].text.decode()) + elif parameter.type == "dictionary_splat_pattern": + formal_param = PyArg(arg_name=parameter.text.decode(), arg_type="") + if formal_param is not None: + formal_params.append(formal_param) + num_params = len(formal_params) + full_signature = method_name + parameters + function: PyMethod = PyMethod( + method_name=method_name, + code_body=code_body, + full_signature=full_signature, + num_params=num_params, + modifier=modifier, + formal_params=formal_params, + return_type=return_type, + class_signature=class_signature, + start_line=start_line, + end_line=end_line, + is_static=is_static, + is_constructor=is_constructor, + 
call_sites=call_sites + ) + return function + + + + def __get_class_nodes(self, module: str) -> Captures: + captures: Captures = self.__frame_query_and_capture_output("(((class_definition) @class_name))", module) + return captures + + def __get_method_nodes(self, module: str) -> Captures: + captures: Captures = self.__frame_query_and_capture_output("(((function_definition) @function_name))", module) + return captures diff --git a/cldk/analysis/symbol_table.py b/cldk/analysis/symbol_table.py new file mode 100644 index 0000000..d26e9cf --- /dev/null +++ b/cldk/analysis/symbol_table.py @@ -0,0 +1,84 @@ +from abc import ABC, abstractmethod + + +class SymbolTable(ABC): + def __init__(self) -> None: + super().__init__() + + ''' + Language agnostic functions + ''' + + @abstractmethod + def get_methods(self, **kwargs): + """ + Given an application or a source code, get all the methods + """ + pass + + @abstractmethod + def get_imports(self, **kwargs): + """ + Given an application or a source code, get all the imports + """ + pass + + @abstractmethod + def get_variables(self, **kwargs): + """ + Given an application or a source code, get all the variables + """ + pass + + ''' + OOP-specific functions + ''' + + @abstractmethod + def get_classes(self, **kwargs): + """ + Given an application or a source code, get all the classes + """ + pass + + @abstractmethod + def get_classes_by_criteria(self, **kwargs): + """ + Given an application or a source code, get all the classes given the inclusion and exclution criteria + """ + pass + + @abstractmethod + def get_sub_classes(self, **kwargs): + """ + Given an application or a source code, get all the sub-classes + """ + pass + + @abstractmethod + def get_nested_classes(self, **kwargs): + """ + Given an application or a source code, get all the nested classes + """ + pass + + @abstractmethod + def get_constructors(self, **kwargs): + """ + Given an application or a source code, get all the constructors + """ + pass + + @abstractmethod + 
def get_methods_in_class(self, **kwargs): + """ + Given an application or a source code, get all the methods within the given class + """ + pass + + @abstractmethod + def get_fields(self, **kwargs): + """ + Given an application or a source code, get all the fields + """ + pass diff --git a/cldk/analysis/system_dependence_graph.py b/cldk/analysis/system_dependence_graph.py new file mode 100644 index 0000000..986194c --- /dev/null +++ b/cldk/analysis/system_dependence_graph.py @@ -0,0 +1,6 @@ +from abc import abstractmethod, ABC + + +class SystemDependenceGraph(ABC): + def __init__(self) -> None: + super().__init__() \ No newline at end of file diff --git a/cldk/core.py b/cldk/core.py new file mode 100644 index 0000000..e7aea42 --- /dev/null +++ b/cldk/core.py @@ -0,0 +1,136 @@ +from pathlib import Path + + +import logging + +from cldk.analysis.java import JavaAnalysis +from cldk.analysis.java.treesitter import JavaSitter +from cldk.utils.exceptions import CldkInitializationException +from cldk.utils.sanitization.java.TreesitterSanitizer import TreesitterSanitizer + +logger = logging.getLogger(__name__) + + +class CLDK: + """ + The CLDK base class. + + Parameters + ---------- + language : str + The programming language of the project. + + Attributes + ---------- + language : str + The programming language of the project. + """ + + def __init__(self, language: str): + self.language: str = language + + def analysis( + self, + project_path: str | Path | None = None, + source_code: str | None = None, + eager: bool = False, + analysis_backend: str | None = "codeanalyzer", + analysis_level: str = "symbol_table", + analysis_backend_path: str | None = None, + analysis_json_path: str | Path | None = None, + use_graalvm_binary: bool = False, + ) -> JavaAnalysis: + """ + Initialize the preprocessor based on the specified language and analysis_backend. + + Parameters + ---------- + project_path : str or Path + The directory path of the project. 
+ source_code : str, optional + The source code of the project, defaults to None. If None, it is assumed that the whole project is being + analyzed. + analysis_backend : str, optional + The analysis_backend used for analysis, defaults to "codeanalyzer". + analysis_backend_path : str, optional + The path to the analysis_backend, defaults to None and in the case of codeql, it is assumed that the cli is + installed and available in the PATH. In the case of codeanalyzer the codeanalyzer.jar is downloaded from the + latest release. + analysis_json_path : str or Path, optional + The path to save the analysis database (analysis.json), defaults to None. If None, the analysis database + is not persisted. + use_graalvm_binary : bool, optional + A flag indicating whether to use the GraalVM binary for SDG analysis, defaults to False. If False, + the default Java binary is used and one needs to have Java 17 or higher installed. + eager : bool, optional + A flag indicating whether to perform eager analysis, defaults to False. If True, the analysis is performed + eagerly. That is, the analysis.json file is created during analysis every time even if it already exists. + + Returns + ------- + JavaAnalysis + The initialized JavaAnalysis object. + + Raises + ------ + CldkInitializationException + If neither project_path nor source_code is provided. + NotImplementedError + If the specified language is not implemented yet. + """ + + if project_path is None and source_code is None: + raise CldkInitializationException("Either project_path or source_code must be provided.") + + if project_path is not None and source_code is not None: + raise CldkInitializationException("Both project_path and source_code are provided. 
Please provide " "only one.") + + if self.language == "java": + return JavaAnalysis( + project_dir=project_path, + source_code=source_code, + analysis_backend=analysis_backend, + analysis_level=analysis_level, + analysis_backend_path=analysis_backend_path, + analysis_json_path=analysis_json_path, + use_graalvm_binary=use_graalvm_binary, + eager_analysis=eager, + ) + else: + raise NotImplementedError(f"Analysis support for {self.language} is not implemented yet.") + + def treesitter_parser(self): + """ + Parse the project using treesitter. + + Returns + ------- + List + A list of treesitter nodes. + + """ + if self.language == "java": + return JavaSitter() + else: + raise NotImplementedError(f"Treesitter parser for {self.language} is not implemented yet.") + + def tree_sitter_utils(self, source_code: str): + """ + Parse the project using treesitter. + + Parameters + ---------- + source_code : str, optional + The source code of the project, defaults to None. If None, it is assumed that the whole project is being + analyzed. + + Returns + ------- + List + A list of treesitter nodes. + + """ + if self.language == "java": + return TreesitterSanitizer(source_code=source_code) + else: + raise NotImplementedError(f"Treesitter parser for {self.language} is not implemented yet.") diff --git a/cldk/models/__init__.py b/cldk/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cldk/models/c/__init__.py b/cldk/models/c/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cldk/models/c/models.py b/cldk/models/c/models.py new file mode 100644 index 0000000..10f3d82 --- /dev/null +++ b/cldk/models/c/models.py @@ -0,0 +1,111 @@ +from typing import List, Optional +from pydantic import BaseModel + + +class COutput(BaseModel): + """ + Represents the output of a C function. + + Parameters + ---------- + type : str + The type of the output. + qualifiers : List[str] + The list of type qualifiers. E.g.: const, volatile, restrict, etc. 
+ is_reference : bool + A flag indicating whether the output is a pointer. + """ + + type: str + # Type qualifiers: const, volatile, restrict, etc. + qualifiers: List[str] + is_reference: bool + + +class CParameter(COutput): + """ + Represents a parameter in a C function. + + Parameters + ---------- + name : str + The name of the parameter. + specifiers : List[str] + The list of storage class specifiers. E.g.: auto, register, static, extern, etc. + """ + + name: str + # Storage-class specifiers: auto, register, static, extern, etc. + specifiers: List[str] + + +class CFunction(BaseModel): + """ + Represents a function in C. + + Parameters + ---------- + name : str + The name of the function. + signature : str + The signature of the function. + comment : Optional[str] + The comment associated with the function. + parameters : List[CParameter] + The parameters of the function. + output : COutput + The return of the function. + code : str + The code block of the callable. + start_line : int + The starting line number of the callable in the source file. + end_line : int + The ending line number of the callable in the source file. + specifiers : List[str] + The list of storage class specifiers. E.g.: auto, register, static, extern, etc. + """ + + name: str + signature: str + code: str + start_line: int + end_line: int + parameters: List[CParameter] + output: COutput + comment: Optional[str] + # Storage-class specifiers: auto, register, static, extern, etc. + specifiers: List[str] + + +class CImport(BaseModel): + """ + Represents a C import. + + Parameters + ---------- + value : str + The name or path of file being imported. + is_system : bool + A flag indicating whether the import is a system one. + """ + + value: str + is_system: bool + + +class CTranslationUnit(BaseModel): + """ + Represents the content of a C file. + + Parameters + ---------- + imports : List[CImport] + The list of imports present inside the file. 
+ functions : List[CFunction] + The functions defined inside the file. + """ + + imports: List[CImport] + functions: List[CFunction] + + # TODO: type definitions, structs diff --git a/cldk/models/go/__init__.py b/cldk/models/go/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cldk/models/go/models.py b/cldk/models/go/models.py new file mode 100644 index 0000000..c741cb4 --- /dev/null +++ b/cldk/models/go/models.py @@ -0,0 +1,101 @@ +from typing import List, Optional +from pydantic import BaseModel + + +class GoParameter(BaseModel): + """ + Represents a function/method parameter in Go. + + Parameters + ---------- + name : Optional[str] + The name of the parameter. If the parameter is not referenced, the name can be omitted. + type : str + The type of the parameter. + is_reference : bool + A flag indicating whether the type is a pointer. + is_variadic : bool + A flag indicating whether the parameter is variadic. + """ + + # Go allows parameters without name, they can't be used + name: Optional[str] = None + type: str + is_reference: bool + is_variadic: bool + + +class GoImport(BaseModel): + """ + Represents an import in Go. + + Parameters + ---------- + path : str + The path to the imported package. + name : Optional[str] + The alias of the imported package. Useful when the implicit name is + too long or to avoid package name clashes. + """ + + path: str + name: Optional[str] = None + + +class GoCallable(BaseModel): + """ + Represents a callable entity such as a method or function in Go. + + Parameters + ---------- + name : str + The name of the callable. + signature : str + The signature of the callable. + comment : str + The comment associated with the callable. + modifiers : List[str] + The modifiers applied to the callable (e.g., public, static). + parameters : List[GoParameter] + The parameters of the callable. + return_types : List[str] + The list of return type of the callable. Empty list, if the callable does not return a value. 
+ receiver : Optional[GoParameter] + The callable's associated type. Only applicable for methods. + code : str + The code block of the callable. + start_line : int + The starting line number of the callable in the source file. + end_line : int + The ending line number of the callable in the source file. + """ + + name: str + signature: str + comment: str + modifiers: List[str] + parameters: List[GoParameter] + # only methods have a receiver + receiver: Optional[GoParameter] = None + return_types: List[str] + code: str + start_line: int + end_line: int + + +class GoSourceFile(BaseModel): + """ + Represents a source file in Go. + + Parameters + ---------- + imports : List[GoImport] + The list of imports present in the source file. + callables : List[GoCallable] + The list of callable entities present in the source file. + """ + + imports: List[GoImport] + callables: List[GoCallable] + + # TODO: types diff --git a/cldk/models/java/__init__.py b/cldk/models/java/__init__.py new file mode 100644 index 0000000..9904d84 --- /dev/null +++ b/cldk/models/java/__init__.py @@ -0,0 +1,18 @@ +from .models import ( + JApplication, + JCallable, + JType, + JCompilationUnit, + JGraphEdges, +) + +from .constants_namespace import ConstantsNamespace + +__all__ = [ + JApplication, + JCallable, + JType, + JCompilationUnit, + JGraphEdges, + ConstantsNamespace +] diff --git a/cldk/models/java/constants_namespace.py b/cldk/models/java/constants_namespace.py new file mode 100644 index 0000000..1f96afc --- /dev/null +++ b/cldk/models/java/constants_namespace.py @@ -0,0 +1,16 @@ +class ConstantsNamespace: + @property + def ENTRY_POINT_SERVLET_CLASSES(self): + return ("javax.servlet.GenericServlet","javax.servlet.Filter","javax.servlet.http.HttpServlet") + @property + def ENTRY_POINT_METHOD_SERVLET_PARAM_TYPES(self): + return ("javax.servlet.ServletRequest","javax.servlet.ServletResponse","javax.servlet.http.HttpServletRequest","javax.servlet.http.HttpServletResponse") + @property + def 
ENTRY_POINT_METHOD_JAVAX_WS_ANNOTATIONS(self): + return ("javax.ws.rs.POST", "javax.ws.rs.PUT", "javax.ws.rs.GET", "javax.ws.rs.HEAD", "javax.ws.rs.DELETE") + @property + def ENTRY_POINT_CLASS_SPRING_ANNOTATIONS(self): + return ("Controller","RestController") + @property + def ENTRY_POINT_METHOD_SPRING_ANNOTATIONS(self): + return ("GetMapping","PathMapping","PostMapping","PutMapping","RequestMapping","DeleteMapping") \ No newline at end of file diff --git a/cldk/models/java/models.py b/cldk/models/java/models.py new file mode 100644 index 0000000..9463551 --- /dev/null +++ b/cldk/models/java/models.py @@ -0,0 +1,392 @@ +import json +import re +from typing import Dict, List, Optional +from pydantic import BaseModel, field_validator, model_validator +from contextvars import ContextVar +from .constants_namespace import ConstantsNamespace + +constants = ConstantsNamespace() +context_concrete_class = ContextVar("context_concrete_class") # context var to store class concreteness + + +class JField(BaseModel): + """ + Represents a field in a Java class or interface. + + Parameters + ---------- + comment : str + The comment associated with the field. + name : str + The name of the field. + type : str + The type of the field. + start_line : int + The starting line number of the field in the source file. + end_line : int + The ending line number of the field in the source file. + variables : List[str] + The variables declared in the field. + modifiers : List[str] + The modifiers applied to the field (e.g., public, static). + annotations : List[str] + The annotations applied to the field. + """ + + comment: str + type: str + start_line: int + end_line: int + variables: List[str] + modifiers: List[str] + annotations: List[str] + + +class JCallableParameter(BaseModel): + """ + Represents a parameter of a Java callable. + + Parameters + ---------- + name : str + The name of the parameter. + type : str + The type of the parameter. 
+ """ + + name: str + type: str + annotations: List[str] + modifiers: List[str] + + +class JEnumConstant(BaseModel): + name: str + arguments: List[str] + + +class JCallSite(BaseModel): + """ + Represents a call site. + + Parameters + ---------- + method_name : str + The name of the method called at the call site. + receiver_expr : str + Expression for the receiver of the method call. + receiver_type : str + Name of type declaring the called method. + argument_types : List[str] + Types of actual parameters for the call. + return_type : str + Return type of the method call (resolved type of the method call expression; empty string if expression is + unresolved). + is_static_call: bool + Flag indicating whether the call is a static call + is_constructor_call: bool + Flag indicating whether the call is a constructor call + start_line : int + The starting line number of the call site. + start_column : int + The starting column of the call site. + end_line : int + The ending line number of the call site. + end_column : int + The ending column of the call site. + """ + + method_name: str + receiver_expr: str = "" + receiver_type: str + argument_types: List[str] + return_type: str = "" + is_static_call: bool + is_private: bool + is_public: bool + is_protected: bool + is_unspecified: bool + is_constructor_call: bool + start_line: int + start_column: int + end_line: int + end_column: int + + +class JVariableDeclaration(BaseModel): + """ + Represents a variable declaration. + + Parameters + ---------- + name : str + The name of the variable. + type : str + The type of the variable. + initializer : str + The initialization expression (if persent) for the variable declaration. + start_line : int + The starting line number of the declaration. + start_column : int + The starting column of the declaration. + end_line : int + The ending line number of the declaration. + end_column : int + The ending column of the declaration. 
+ """ + + name: str + type: str + initializer: str + start_line: int + start_column: int + end_line: int + end_column: int + + +class JCallable(BaseModel): + """ + Represents a callable entity such as a method or constructor in Java. + + Parameters + ---------- + signature : str + The signature of the callable. + is_implicit : bool + A flag indicating whether the callable is implicit (e.g., a default constructor). + is_constructor : bool + A flag indicating whether the callable is a constructor. + comment : str + The comment associated with the callable. + annotations : List[str] + The annotations applied to the callable. + modifiers : List[str] + The modifiers applied to the callable (e.g., public, static). + thrown_exceptions : List[str] + Exceptions declared via "throws". + declaration : str + The declaration of the callable. + parameters : List[ParameterInCallable] + The parameters of the callable. + return_type : Optional[str] + The return type of the callable. None if the callable does not return a value (e.g., a constructor). + code : str + The code block of the callable. + start_line : int + The starting line number of the callable in the source file. + end_line : int + The ending line number of the callable in the source file. + referenced_types : List[str] + The types referenced within the callable. + accessed_fields : List[str] + Fields accessed in the callable. + call_sites : List[JCallSite] + Call sites in the callable. + variable_declarations : List[JVariableDeclaration] + Local variable declarations in the callable. + cyclomatic_complexity : int + Cyclomatic complexity of the callable. 
+ """ + + signature: str + is_implicit: bool + is_constructor: bool + is_entry_point: bool = False + comment: str + annotations: List[str] + modifiers: List[str] + thrown_exceptions: List[str] = [] + declaration: str + parameters: List[JCallableParameter] + return_type: Optional[str] = None # Pythonic way to denote a nullable field + code: str + start_line: int + end_line: int + referenced_types: List[str] + accessed_fields: List[str] + call_sites: List[JCallSite] + variable_declarations: List[JVariableDeclaration] + cyclomatic_complexity: int | None + + def __hash__(self): + return hash(self.declaration) + + @model_validator(mode="after") + def detect_entrypoint_method(self): + # check first if the class in which this method exists is concrete or not, by looking at the context var + if context_concrete_class.get(): + # convert annotations to the form GET, POST even if they are @GET or @GET('/ID') etc. + annotations_cleaned = [match for annotation in self.annotations for match in + re.findall(r'@(.*?)(?:\(|$)', annotation)] + + param_type_list = [val.type for val in self.parameters] + # check the param types against known servlet param types + if any(substring in string for substring in param_type_list for string in + constants.ENTRY_POINT_METHOD_SERVLET_PARAM_TYPES): + # check if this method is over-riding (only methods that override doGet / doPost etc. 
will be flagged as first level entry points) + if 'Override' in annotations_cleaned: + self.is_entry_point = True + return self + + # now check the cleaned annotations against known javax ws annotations + if any(substring in string for substring in annotations_cleaned for string in + constants.ENTRY_POINT_METHOD_JAVAX_WS_ANNOTATIONS): + self.is_entry_point = True + return self + + # check the cleaned annotations against known spring rest method annotations + if any(substring in string for substring in annotations_cleaned for string in + constants.ENTRY_POINT_METHOD_SPRING_ANNOTATIONS): + self.is_entry_point = True + return self + return self + + +class JType(BaseModel): + """ + Represents a Java class or interface. + + Parameters + ---------- + name : str + The name of the class or interface. + is_interface : bool + A flag indicating whether the object is an interface. + is_inner_class : bool + A flag indicating whether the object is an inner class. + is_local_class : bool + A flag indicating whether the object is a local class. + is_nested_type: bool + A flag indicating whether the object is a nested type. + comment : str + The comment of the class or interface. + extends_list : List[str] + The list of classes or interfaces that the object extends. + implements_list : List[str] + The list of interfaces that the object implements. + modifiers : List[str] + The list of modifiers of the object. + annotations : List[str] + The list of annotations of the object. + parent_class: + The name of the parent class (if it exists) + nested_class_declerations: List[str] + All the class declerations nested under this class. + constructor_declarations : List[JCallable] + The list of constructors of the object. + method_declarations : List[JCallable] + The list of methods of the object. + field_declarations : List[JField] + The list of fields of the object. 
+ """ + + is_interface: bool = False + is_inner_class: bool = False + is_local_class: bool = False + is_nested_type: bool = False + is_class_or_interface_declaration: bool = False + is_enum_declaration: bool = False + is_annotation_declaration: bool = False + is_record_declaration: bool = False + is_concrete_class: bool = False + is_entry_point: bool = False + comment: str + extends_list: List[str] = [] + implements_list: List[str] = [] + modifiers: List[str] = [] + annotations: List[str] = [] + parent_type: str + nested_type_declerations: List[str] = [] + callable_declarations: Dict[str, JCallable] = {} + field_declarations: List[JField] = [] + enum_JavaEEEntryPoints: List[JEnumConstant] = [] + + # first get the data in raw form and check if the class is concrete or not, before any model validation is done + # for this we assume if a class is not an interface or abstract it is concrete + # for abstract classes we will check the modifiers + @model_validator(mode="before") + def check_concrete_class(cls, values): + values["is_concrete_class"] = False + if values.get("is_class_or_interface_declaration") and not values.get("is_interface"): + if "abstract" not in values.get("modifiers"): + values["is_concrete_class"] = True + # since the methods in this class need access to the concrete class flag, + # we will store this in a context var - this is a hack + token = context_concrete_class.set(values["is_concrete_class"]) + return values + + # after model validation is done we populate the is_entry_point flag by checking + # if the class extends or implements known servlet classes + @model_validator(mode="after") + def check_concrete_entry_point(self): + if self.is_concrete_class: + if any(substring in string for substring in (self.extends_list + self.implements_list) + for string in constants.ENTRY_POINT_SERVLET_CLASSES): + self.is_entry_point = True + return self + # Handle spring classes + # clean annotations - take out @ and any paranehesis along with info in them. 
+ annotations_cleaned = [match for annotation in self.annotations for match in + re.findall(r'@(.*?)(?:\(|$)', annotation)] + if any(substring in string for substring in annotations_cleaned + for string in constants.ENTRY_POINT_CLASS_SPRING_ANNOTATIONS): + self.is_entry_point = True + return self + # context_concrete.reset() + return self + + +class JCompilationUnit(BaseModel): + comment: str + imports: List[str] + type_declarations: Dict[str, JType] + + +class JMethodDetail(BaseModel): + method_declaration: str + # class is a reserved keyword in python. we'll use klass. + klass: str + method: JCallable + + def __repr__(self): + return f"JMethodDetail({self.method_declaration})" + + def __hash__(self): + return hash(tuple(self)) + + +class JGraphEdges(BaseModel): + source: JMethodDetail + target: JMethodDetail + type: str + weight: str + source_kind: str | None = None + destination_kind: str | None = None + + @field_validator("source", "target", mode="before") + @classmethod + def validate_source(cls, value) -> JMethodDetail: + callable_dict = json.loads(value) + j_callable = JCallable(**json.loads(callable_dict["callable"])) # parse the value which is a quoted string + class_name = callable_dict["class_interface_declarations"] + method_decl = j_callable.declaration + mc = JMethodDetail(method_declaration=method_decl, klass=class_name, method=j_callable) + return mc + + def __hash__(self): + return hash(tuple(self)) + + +class JApplication(BaseModel): + """ + Represents a Java application. + + Parameters + ---------- + symbol_table : List[JCompilationUnit] + The symbol table representation + system_dependency : List[JGraphEdges] + The edges of the system dependency graph. Default None. 
+ """ + symbol_table: Dict[str, JCompilationUnit] + system_dependency_graph: List[JGraphEdges] = None diff --git a/cldk/models/javascript/__init__.py b/cldk/models/javascript/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cldk/models/javascript/models.py b/cldk/models/javascript/models.py new file mode 100644 index 0000000..fe8e75a --- /dev/null +++ b/cldk/models/javascript/models.py @@ -0,0 +1,95 @@ +from typing import List, Optional +from pydantic import BaseModel + + +class JsParameter(BaseModel): + """ + Represents a function/method parameter in Javascript. + + Parameters + ---------- + name : str + The name of the parameter. + default_value : Optional[str] + The default value of the parameter, used when a value is not provided. + is_rest : bool + A flag indicating whether the parameter is a rest parameter. + """ + + name: str + default_value: Optional[str] = None + is_rest: bool = False + # type: Optional[str] - might be able to extract from JsDoc + + +class JsCallable(BaseModel): + """ + Represents a callable entity such as a method or function in Javascript. + + Parameters + ---------- + name : str + The name of the callable. + signature : str + The signature of the callable. + parameters : List[JsParameter] + The parameters of the callable. + code : str + The code block of the callable. + start_line : int + The starting line number of the callable in the source file. + end_line : int + The ending line number of the callable in the source file. + is_constructor : bool + A flag indicating whether the callable is a constructor. + """ + + name: str + code: str + signature: str + paremeters: List[JsParameter] + start_line: int + end_line: int + is_constructor: bool = False + + +class JsClass(BaseModel): + """ + Represents a class in Javascript. + + Parameters + ---------- + name : str + The name of the class. + methods : List[JsCallable] + The methods of the class. 
+ start_line : int + The starting line number of the class in the source file. + end_line : int + The ending line number of the class in the source file. + parent : Optional[str] + The name of the parent class. + """ + + name: str + methods: List[JsCallable] + parent: Optional[str] = None + start_line: int + end_line: int + + +class JsProgram(BaseModel): + """ + Represents a source file in Javascript. + + Parameters + ---------- + classes : List[JsClass] + The list of classes present in the source file. + callables : List[JsCallable] + The list of callable entities present in the source file. + """ + + classes: List[JsClass] + callables: List[JsCallable] + # TODO: imports diff --git a/cldk/models/python/__init__.py b/cldk/models/python/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cldk/models/python/models.py b/cldk/models/python/models.py new file mode 100644 index 0000000..ab08b94 --- /dev/null +++ b/cldk/models/python/models.py @@ -0,0 +1,64 @@ +from typing import List +from pydantic import BaseModel + + +class PyArg(BaseModel): + arg_name: str + arg_type: str + + +class PyImport(BaseModel): + from_statement: str + imports: List[str] + + +class PyCallSite(BaseModel): + method_name: str + declaring_object: str + arguments: List[str] + start_line: int + start_column: int + end_line: int + end_column: int + + +class PyMethod(BaseModel): + code_body: str + method_name: str + full_signature: str + num_params: int + modifier: str + is_constructor: bool + is_static: bool + formal_params: List[PyArg] + call_sites: List[PyCallSite] + return_type: str + class_signature: str + start_line: int + end_line: int + # incoming_calls: Optional[List["PyMethod"]] = None + # outgoing_calls: Optional[List["PyMethod"]] = None + + +class PyClass(BaseModel): + code_body: str + full_signature: str + super_classes: List[str] + is_test_class: bool + methods: List[PyMethod] + + +class PyModule(BaseModel): + qualified_name: str + functions: List[PyMethod] + classes: 
List[PyClass] + imports: List[PyImport] + #expressions: str + + +class PyBuildAttributes(BaseModel): + """Handles all the project build tool (requirements.txt/poetry/setup.py) attributes""" + + +class PyConfig(BaseModel): + """Application configuration information""" diff --git a/cldk/models/treesitter/__init__.py b/cldk/models/treesitter/__init__.py new file mode 100644 index 0000000..50dae79 --- /dev/null +++ b/cldk/models/treesitter/__init__.py @@ -0,0 +1,3 @@ +from .models import Captures + +__all__ = ["Captures"] diff --git a/cldk/models/treesitter/models.py b/cldk/models/treesitter/models.py new file mode 100644 index 0000000..2ef865a --- /dev/null +++ b/cldk/models/treesitter/models.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import List, Tuple + +from tree_sitter import Node + + +@dataclass +class Captures: + """This class is a dataclass that represents the captures from a tree-sitter query. + Attributes + ---------- + captures : List[Capture] + A list of captures from the tree-sitter query. + """ + + @dataclass + class Capture: + """This class is a dataclass that represents a single capture from a tree-sitter query. + Attributes + ---------- + node : Node + The node that was captured. + name : str + The name of the capture. + """ + + node: Node + name: str + + def __init__(self, captures: List[Tuple[Node, str]]): + self.captures = [self.Capture(node=node, name=text) for node, text in captures] + + def __getitem__(self, index: int) -> Capture: + """Get the capture at the specified index. + Parameters: + ----------- + index : int + The index of the capture to get. + Returns + ------- + Capture + The capture at the specified index. 
+ """ + return self.captures[index] + + def __iter__(self): + """Return an iterator over the captures.""" + return iter(self.captures) + + def __len__(self) -> int: + """Return the number of captures.""" + return len(self.captures) diff --git a/cldk/utils/__init__.py b/cldk/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cldk/utils/analysis_engine.py b/cldk/utils/analysis_engine.py new file mode 100644 index 0000000..ab169e1 --- /dev/null +++ b/cldk/utils/analysis_engine.py @@ -0,0 +1,4 @@ +class AnalysisEngine: + TREESITTER: str = "treesitter" + CODEQL: str = "codeql" + CODEANALYZER: str = "codeanalyzer" \ No newline at end of file diff --git a/cldk/utils/exceptions/__init__.py b/cldk/utils/exceptions/__init__.py new file mode 100644 index 0000000..b48e77e --- /dev/null +++ b/cldk/utils/exceptions/__init__.py @@ -0,0 +1,13 @@ +from .exceptions import ( + CldkInitializationException, + CodeanalyzerExecutionException, + CodeQLDatabaseBuildException, + CodeQLQueryExecutionException, +) + +__all__ = [ + "CodeQLDatabaseBuildException", + "CodeQLQueryExecutionException", + "CodeanalyzerExecutionException", + "CldkInitializationException", +] diff --git a/cldk/utils/exceptions/exceptions.py b/cldk/utils/exceptions/exceptions.py new file mode 100644 index 0000000..8646228 --- /dev/null +++ b/cldk/utils/exceptions/exceptions.py @@ -0,0 +1,40 @@ +class CldkInitializationException(Exception): + """Custom exception for errors during CLDK initialization.""" + + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +class CodeanalyzerExecutionException(Exception): + """Exception raised for errors that occur during the execution of Codeanalyzer.""" + + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +class CodeQLDatabaseBuildException(Exception): + """Exception raised for errors that occur during the building of a CodeQL database.""" + + def __init__(self, message): + 
self.message = message + super().__init__(self.message) + + +class CodeQLQueryExecutionException(Exception): + """Exception raised for errors that occur during the execution of a CodeQL query.""" + + def __init__(self, message): + self.message = message + super().__init__(self.message) + + +class CodeanalyzerUsageException(Exception): + """ + Exception is raised when the usage of codeanalyzer is incorrect. + """ + + def __init__(self, message): + self.message = message + super().__init__(self.message) diff --git a/cldk/utils/logging.py b/cldk/utils/logging.py new file mode 100644 index 0000000..e69de29 diff --git a/cldk/utils/sanitization/__init__.py b/cldk/utils/sanitization/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cldk/utils/sanitization/java/TreesitterSanitizer.py b/cldk/utils/sanitization/java/TreesitterSanitizer.py new file mode 100644 index 0000000..c4fd41e --- /dev/null +++ b/cldk/utils/sanitization/java/TreesitterSanitizer.py @@ -0,0 +1,323 @@ +from copy import deepcopy +from typing import Dict, List, Set + +from cldk.analysis.java.treesitter import JavaSitter +from cldk.models.treesitter import Captures +import logging + +log = logging.getLogger(__name__) + + +class TreesitterSanitizer: + + def __init__(self, source_code): + self.source_code = source_code + self.sanitized_code = deepcopy(self.source_code) + self.__javasitter = JavaSitter() + + def keep_only_focal_method_and_its_callees(self, focal_method: str) -> str: + """Remove all methods except the focal method and its callees. + + Parameters + ---------- + focal_method : str + The of the focal method. + source_code : str + The source code to process. + + Returns + ------- + str + The pruned source code. 
+ """ + method_declaration: Captures = self.__javasitter.frame_query_and_capture_output(query="((method_declaration) " "@method_declaration)", code_to_process=self.sanitized_code) + declared_methods = {self.__javasitter.get_method_name_from_declaration(capture.node.text.decode()): capture.node.text.decode() for capture in method_declaration} + unused_methods: Dict = self._unused_methods(focal_method, declared_methods) + for _, method_body in unused_methods.items(): + self.sanitized_code = self.sanitized_code.replace(method_body, "") + return self.__javasitter.make_pruned_code_prettier(self.sanitized_code) + + def remove_unused_imports(self, sanitized_code: str) -> str: + """ + Remove all unused imports from the source code. Assuming you have removed all unused fields in a class and + you are given a pruned source code, visit every child and look for all type identifiers and other identifiers + and compare them with every element in the used imports list. + + Parameters + ---------- + source_code : str + The source code to process. + + sanitized_code : str + The source code after having performed certain operations like removing unused methods, etc. + + Returns + ------- + str + The pruned source code with all the unused imports removed. + + Steps + ----- + + To compare, split string by '.' and compare if the last element is in the set of identifiers or type identifiers. + + Create a set called ids_and_typeids to capture all the indentified identifiers and type identifiers in the source code. + + Next, create a set called unused_imports. + + If import ends_with("*"), ignore it since we'll keep all wildcard imports. + + For every other import statements, compare if the last element in the import string is in the set of ids_and_typeids. If not, add it to the unused_imports set. + + Finally, remove all the unused imports from the source code and prettify it. 
+ """ + pruned_source_code: str = deepcopy(sanitized_code) + import_declerations: Captures = self.__javasitter.frame_query_and_capture_output(query="((import_declaration) @imports)", code_to_process=self.source_code) + + unused_imports: Set = set() + ids_and_typeids: Set = set() + class_bodies: Captures = self.__javasitter.frame_query_and_capture_output(query="((class_declaration) @class_decleration)", code_to_process=self.source_code) + for class_body in class_bodies: + all_type_identifiers_in_class: Captures = self.__javasitter.frame_query_and_capture_output( + query="((type_identifier) @type_id)", + code_to_process=class_body.node.text.decode(), + ) + all_other_identifiers_in_class: Captures = self.__javasitter.frame_query_and_capture_output( + query="((identifier) @other_id)", + code_to_process=class_body.node.text.decode(), + ) + ids_and_typeids.update({type_id.node.text.decode() for type_id in all_type_identifiers_in_class}) + ids_and_typeids.update({other_id.node.text.decode() for other_id in all_other_identifiers_in_class}) + + for import_decleration in import_declerations: + wildcard_import: Captures = self.__javasitter.frame_query_and_capture_output(query="((asterisk) @wildcard)", code_to_process=import_decleration.node.text.decode()) + if len(wildcard_import) > 0: + continue + + import_statement: Captures = self.__javasitter.frame_query_and_capture_output( + query="((scoped_identifier) @scoped_identifier)", code_to_process=import_decleration.node.text.decode() + ) + import_str = import_statement.captures[0].node.text.decode() + if not import_str.split(".")[-1] in ids_and_typeids: + unused_imports.add(import_decleration.node.text.decode()) + + for unused_import in unused_imports: + pruned_source_code = pruned_source_code.replace(unused_import, "") + + return self.__javasitter.make_pruned_code_prettier(pruned_source_code) + + def remove_unused_fields(self, sanitized_code: str) -> str: + """ + Take the pruned source code and remove all the unused fields. 
+ + Parameters + ---------- + source_code : str + Source code after having removed all unused methods. + + Implementation details + ---------------------- + + Get all the field declarations in the file -> used_fields + + Get all the identifiers inside every declared method and constructor in the file. + + Then, loop over every used_fields, get the field name, and add it to unused if the identifier doesn't match any known identifier + in the previous step. + """ + pruned_source_code: str = deepcopy(sanitized_code) + unused_fields: List[Captures.Capture] = list() + field_declarations: Captures = self.__javasitter.frame_query_and_capture_output(query="((field_declaration) @field_declaration)", code_to_process=pruned_source_code) + method_declaration: Captures = self.__javasitter.frame_query_and_capture_output(query="((method_declaration) @method_declaration)", code_to_process=pruned_source_code) + constructor_declaration: Captures = self.__javasitter.frame_query_and_capture_output( + query="((constructor_declaration) @constructor_declaration)", code_to_process=pruned_source_code + ) + all_used_identifiers = set() + for method in method_declaration: + all_used_identifiers.update( + { + capture.node.text.decode() + for capture in self.__javasitter.frame_query_and_capture_output(query="((identifier) @identifier)", code_to_process=method.node.text.decode()) + } + ) + + for constructor in constructor_declaration: + all_used_identifiers.update( + { + capture.node.text.decode() + for capture in self.__javasitter.frame_query_and_capture_output(query="((identifier) @identifier)", code_to_process=constructor.node.text.decode()) + } + ) + + used_fields = [capture for capture in field_declarations] + + for field in used_fields: + field_identifiers = { + capture.node.text.decode() + for capture in self.__javasitter.frame_query_and_capture_output(query="((identifier) @identifier)", code_to_process=field.node.text.decode()) + } + if not 
field_identifiers.intersection(all_used_identifiers): + unused_fields.append(field) + + for unused_field in unused_fields: + pruned_source_code = pruned_source_code.replace(unused_field.node.text.decode(), "") + + return self.__javasitter.make_pruned_code_prettier(pruned_source_code) + + def remove_unused_classes(self, sanitized_code: str) -> str: + """ + Remove inner classes that are no longer used. + + Parameters + ---------- + sanitized_code : str + The sanitized code to process. + + Implementation steps + --------------------- + + Make a deep copy of the source code. + + Get the focal class name by getting the name of the outermost class. + + Get a dictionary of the class name and class declarations in the source code. + + Create unused_classes, a dictionary to hold all the classes that are not used. We seed it with all the classes. + + Create a to_process stack to hold the classes to process. Seed it with the focal class. + + Create a processed_so_far set to hold all the processed classes to avoid reprocessing classes repeatedly. + + While to_process is not empty, pop the first class from the to_process stack, add it to processed_so_far set, + remove the inner classes from the current class, and get all the type invocations in the current class. + + Add all the type invocations to the to_process stack iff: + 1. they are not in the processed_so_far set, and 2. They are in the all_classes dictionary (i.e., they are defined in the class). + + Loop until to_process is empty. + + Finally, remove all the unused classes from the source code and prettify it. + """ + focal_class = self.__javasitter.frame_query_and_capture_output(query="(class_declaration name: (identifier) @name)", code_to_process=self.source_code) + + try: + # We use [0] because there may be several nested classes, + # we'll consider the outermost class as the focal class. 
+ focal_class_name = focal_class[0].node.text.decode() + except: + return "" + + pruned_source_code = deepcopy(sanitized_code) + + # Find the first class and we'll continue to operate on the inner classes. + inner_class_declarations: Captures = self.__javasitter.frame_query_and_capture_output("((class_declaration) @class_declaration)", pruned_source_code) + + # Store a dictionary of all the inner classes. + all_classes = dict() + for capture in inner_class_declarations: + inner_class = self.__javasitter.frame_query_and_capture_output(query="(class_declaration name: (identifier) @name)", code_to_process=capture.node.text.decode()) + all_classes[inner_class[0].node.text.decode()] = capture.node.text.decode() + + unused_classes: dict = deepcopy(all_classes) + + to_process = {focal_class_name} + + processed_so_far: Set = set() + + while to_process: + current_class_name = to_process.pop() + current_class_body = unused_classes.pop(current_class_name) + current_class_without_inner_class = current_class_body + processed_so_far.add(current_class_name) + + # Remove the body of inner classes from the current outer class. + inner_class_declarations: Captures = self.__javasitter.frame_query_and_capture_output("(class_body (class_declaration) @class_declaration)", current_class_body) + for capture in inner_class_declarations: + current_class_without_inner_class = current_class_without_inner_class.replace(capture.node.text.decode(), "") + + # Find all the type_references in the current class. 
+ type_references: Set[str] = self.__javasitter.get_all_type_invocations(current_class_without_inner_class) + to_process.update({type_reference for type_reference in type_references if type_reference in all_classes and not type_reference in processed_so_far}) + + for _, unused_class_body in unused_classes.items(): + pruned_source_code = pruned_source_code.replace(unused_class_body, "") + + return self.__javasitter.make_pruned_code_prettier(pruned_source_code) + + def _unused_methods(self, focal_method: str, declared_methods: Dict) -> Dict: + """ + Parameters + ---------- + focal_method : str + The focal method that acts the starting point. If nothing, every other method execpt this one will be unused + declared_methods : dict + A dictionary of all the declared methods in the focal class + + Returns + ------- + Dict[str, str] + A dictionary of unused methods, where the key is the method name and value is the method body. + + Implementation details + ---------------------- + 1. Create a dictionary to hold all the methods to be processed (call this to_process). + 2. Make a deep copy of the declared methods dict to initialize the unused_methods. The intuition here is that + every method is unused initially, and we'll pop them out as we encounter them in the class. + 3. Start by initializing the to_process queue with the focal method. + 4. There may be recursive or cyclic calls, let's create a set called processed_so_far to hold all the processed + methods so we don't keep cycling between the methods or getting stuck at an infinite loop. + 5. In while loop, dequeue the first element from the to_process dictionary, put it processed_so_far set, and + obtain all the invoked methods in that (dequeued) method. We will assume that returned invoked methods are + only those that are declared in the class. + 6. For each of the invoked methods, and enqueue them in the to_process queue iff it hasn't been seen in + processed_so_far + 7. Loop until to_process is empty. + 6. 
It is now safe to assume that all the remaining methods in unused_methods dictionary really are unused and + may be removed from the file. + """ + + unused_methods = deepcopy(declared_methods) # A deep copy of unused methods. + + # A stack to hold the methods to process. + to_process = [focal_method] # Remove this element from unused methods and put it + # in the to_process stack. + + # The set below holds all processed methods bodies. This helps avoid recursive and cyclical calls. + processed_so_far: Set = set() + + while to_process: + # Remove the current method from the to process stack + current_method_name = to_process.pop() + + # This method has been processed already, so we'll skip it. + if current_method_name in processed_so_far: + continue + current_method_body = unused_methods.pop(current_method_name) + processed_so_far.add(current_method_name) + # Below, we find all method invocations that are made inside current_method_body that are also declared in + # the class. We will get back an empty set if there are no more. + all_invoked_methods = self.__javasitter.get_call_targets(current_method_body, declared_methods=declared_methods) + # Add all the methods invoked in a call to to_process iff those methods are declared in the class. + to_process.extend([invoked_method_name for invoked_method_name in all_invoked_methods if invoked_method_name not in processed_so_far]) + + assert len(unused_methods) < len(declared_methods), "At least one of the declared methods (the focal method) must have be used?" + + return unused_methods + + def sanitize_focal_class(self, focal_method: str) -> str: + """Remove all methods except the focal method and its callees. + + Given the focal method name and the entire source code, the output will be the pruned source code. + + Parameters + ---------- + focal_method : str + The name of the focal method. + source_code_file : Path + The path to the source code file. + + Returns + ------- + str + The pruned source code. 
+ """ + + focal_method_name = self.__javasitter.get_method_name_from_declaration(focal_method) + + # Remove block comments + sanitized_code = self.__javasitter.remove_all_comments(self.sanitized_code) + + # The source code after removing + sanitized_code = self.keep_only_focal_method_and_its_callees(focal_method_name) + + # Focal method was found in the class, remove unused fields, imports, and classes. + sanitized_code = self.remove_unused_fields(sanitized_code) + + # Focal method was found in the class, remove unused fields, imports, and classes. + sanitized_code = self.remove_unused_imports(sanitized_code) + + # Focal method was found in the class, remove unused fields, imports, and classes. + sanitized_code = self.remove_unused_classes(sanitized_code) + + return sanitized_code diff --git a/cldk/utils/sanitization/java/TreesitterUtils.py b/cldk/utils/sanitization/java/TreesitterUtils.py new file mode 100644 index 0000000..0ece1d5 --- /dev/null +++ b/cldk/utils/sanitization/java/TreesitterUtils.py @@ -0,0 +1,505 @@ +import re +from copy import deepcopy +from typing import Dict, List, Any, LiteralString + +from cldk.analysis.java.treesitter import JavaSitter +from cldk.models.treesitter import Captures + +java_sitter = JavaSitter() + + +def _replace_in_source( + source_class_code: str, + original_test_method_dict: dict, + modified_test_method_dict: dict, +): + """ + Returns a modified source using original test methods and modified ones. + + Parameters: + ----------- + source_class_code : str + String containing code for a java class. + + original_test_method_dict: dict + Dictionary of original test methods in the java class. + + modified_test_method_dict: dict + Dictionary of modified test methods + + Returns: + -------- + str + modified source after removing duplicate test methods and merging decomposed ones. + + + Comments + -------- + # It's modifying the source code produced by an LLM. 
+ """ + modified_source = deepcopy(source_class_code) + for _, body in original_test_method_dict.items(): + modified_source = modified_source.replace(body, "") + modified_source = modified_source[: modified_source.rfind("}")] + for _, body in modified_test_method_dict.items(): + modified_source = modified_source + "\n" + body + modified_source = modified_source + "\n}" + return modified_source + + +def separate_assertions(source_method_code: str) -> tuple[str, str]: + """ + Separate assertions and non assertions parts + + Args: + source_method_code: test method body + + Returns: + tuple[str, str]: + assertions and non assertions parts + """ + code_split = source_method_code.splitlines() + assert_block = "" + code_block_without_assertions = "" + assertion_part = [] + query = """ + (method_invocation + name: (identifier) @method_name + ) + """ + captures: Captures = java_sitter.frame_query_and_capture_output(query, source_method_code) + for capture in captures: + if "method_name" in capture.name: + call_site_method_name = capture.node.text.decode() + if "assert" in call_site_method_name: + method_node = java_sitter.safe_ascend(capture.node, 1) + assertion_start_line = method_node.start_point[0] + assertion_end_line = method_node.end_point[0] + for i in range(assertion_start_line, assertion_end_line + 1): + assertion_part.append(i) + + for i in range(len(code_split)): + if i not in assertion_part: + code_block_without_assertions += code_split[i].strip() + "\n" + else: + assert_block += code_split[i].strip() + "\n" + return assert_block, code_block_without_assertions + + +def is_empty_test_class(source_class_code: str) -> bool: + """ + Checks if a test class has no test methods by looking for methods with @Test annotations + + Parameters: + ----------- + source_class_code : str + String containing code for a java class. 
+ + Returns: + -------- + bool + True if no test methods False if there are test methods + """ + test_methods_dict = java_sitter.get_test_methods(source_class_code) + print(test_methods_dict) + return not bool(test_methods_dict) + + +def get_all_field_access(source_class_code: str) -> Dict[str, list[list[int]]]: + """_summary_ + + Parameters: + ----------- + source_class_code : str + String containing code for a java class. + + Returns: + -------- + Dict[str, [[int, int], [int, int]]] + Dictionary with field names as keys and a list of starting and ending line, and starting and ending column. + """ + query = """ + (field_access + field:(identifier) @field_name + ) + """ + captures: Captures = java_sitter.frame_query_and_capture_output(query, source_class_code) + field_dict = {} + for capture in captures: + if capture.name == "field_name": + field_name = capture.node.text.decode() + field_node = java_sitter.safe_ascend(capture.node, 2) + start_line = field_node.start_point[0] + start_column = field_node.start_point[1] + end_line = field_node.end_point[0] + end_column = field_node.end_point[1] + start_list = [start_line, start_column] + end_list = [end_line, end_column] + position = [start_list, end_list] + field_dict[field_name] = position + return field_dict + + +def get_all_fields_with_annotations(source_class_code: str) -> Dict[str, Dict]: + """ + Returns a dictionary of field names and field bodies. + + Parameters: + ----------- + source_class_code : str + String containing code for a java class. + + Returns: + -------- + Dict[str,Dict] + Dictionary with field names as keys and a dictionary of annotation and body as values. 
+ """ + query = """ + (field_declaration + (variable_declarator + name: (identifier) @field_name + ) + ) + """ + captures: Captures = java_sitter.frame_query_and_capture_output(query, source_class_code) + field_dict = {} + for capture in captures: + if capture.name == "field_name": + field_name = capture.node.text.decode() + inner_dict = {} + annotation = None + field_node = java_sitter.safe_ascend(capture.node, 2) + body = field_node.text.decode() + for fc in field_node.children: + if fc.type == "modifiers": + for mc in fc.children: + if mc.type == "marker_annotation": + annotation = mc.text.decode() + inner_dict["annotation"] = annotation + inner_dict["body"] = body + field_dict[field_name] = inner_dict + return field_dict + + +def get_all_methods_with_test_with_lines(source_class_code: str) -> Dict[str, List[int]]: + """ + Returns a dictionary of method names and method bodies. + + Parameters: + ----------- + source_class_code : str + String containing code for a java class. + Returns: + -------- + Dict[str,List[int]] + Dictionary with test name as keys and + starting and ending lines as values. + """ + query = """ + (method_declaration + (modifiers + (marker_annotation + name: (identifier) @annotation) + ) + ) + """ + captures: Captures = java_sitter.frame_query_and_capture_output(query, source_class_code) + annotation_method_dict = {} + for capture in captures: + if capture.name == "annotation": + annotation = capture.node.text.decode() + if "Test" in annotation: + method_node = java_sitter.safe_ascend(capture.node, 3) + method_start_line = method_node.start_point[0] + method_end_line = method_node.end_point[0] + method_name = method_node.children[2].text.decode() + annotation_method_dict[method_name] = [method_start_line, method_end_line] + return annotation_method_dict + + +def _remove_duplicates_empties(test_method_dict: dict) -> tuple[dict[Any, Any], int | Any, int | Any]: + """ + removes all the duplicates in the test methods. 
+ + Parameters: + ----------- + code_block : str + Code block of a test method. + + Returns: + -------- + tuple[dict, int] + Dictionary of remaining methods after dedup and number of methods removed because of dedup. + """ + methods_kept = [] + num_dup_methods_removed = 0 + num_empty_methods_removed = 0 + block_set = set() + for test in test_method_dict.keys(): + block = test_method_dict[test] + # capture everything between opening and closing braces + block = block[block.find("{") + 1 : block.rfind("}")] + # remove white spaces,tabs and new lines + block = re.sub(r"[\n\t\s]*", "", block) + if not block: # empty method + num_empty_methods_removed = num_empty_methods_removed + 1 + if block in block_set: + num_dup_methods_removed = num_dup_methods_removed + 1 + else: + block_set.add(block) + methods_kept.append(test) + deduped_method_dict = {key: test_method_dict[key] for key in test_method_dict if key in methods_kept} + return ( + deduped_method_dict, + num_dup_methods_removed, + num_empty_methods_removed, + ) + + +def _compose_decomposed(test_method_dict: dict) -> tuple[dict, int]: + """ + merges all the test methods that only have assertions as different and rest of the code same. + + Parameters: + ----------- + code_block : str + Code block of a test method. + + Returns: + -------- + tuple[dict, int] + Dictionary of merged methods and number of methods removed because of merging. + """ + composed_test_method_dict = {} + num_merged_methods = 0 + block_minus_assert_dict = {} + for test in test_method_dict.keys(): + block = test_method_dict[test] + # capture everything between opening and closing braces + block = block[block.find("{") + 1 : block.rfind("}")] + # remove assertions and keep them aside + assert_block, block_without_assertions = _separate_assertions(block) + if block_without_assertions in block_minus_assert_dict: + existing_test = block_minus_assert_dict[block_without_assertions] + composed_test_method_dict[existing_test] = composed_test_method_dict[existing_test] + "\n" + assert_block + num_merged_methods = num_merged_methods + 1 + else: + block_minus_assert_dict[block_without_assertions] = test + composed_test_method_dict[test] = block + # now add back opening and closing braces + composed_test_method_dict = {k: "@Test\npublic void " + k + "()\n{" + v + "\n}" for k, v in composed_test_method_dict.items()} + return composed_test_method_dict, num_merged_methods + + +def _separate_assertions(code_block: str) -> tuple[str, str]: + """ + separate assertions from the code block of a test method. + + Parameters: + ----------- + code_block : str + Code block of a test method. + + Returns: + -------- + tuple[str,str] + assertions and code block without assertions.
+ """ + code_block_lines = code_block.split(";") + # remove new lines and tabs, but not spaces within lines (need to keep them for assertions) + code_block_lines = [re.sub(r"[\n\t]*", "", x) for x in code_block_lines] + # strip starting and trailing spaces + code_block_lines[:] = [x.strip() for x in code_block_lines] + # remove any empty lines + code_block_lines[:] = [x for x in code_block_lines if x] + code_block_lines_without_assertions = [] + code_block_without_assertions = "" + assert_lines = [] + for line in code_block_lines: + if line.startswith("assert"): + assert_lines.append(line) + else: + # now we can remove spaces within lines + line = re.sub(r"[\s]*", "", line) + code_block_lines_without_assertions.append(line) + # put back the assertion block like it should be + assert_block = ";\n".join(assert_lines) + ";" + if len(code_block_lines_without_assertions) > 0: + code_block_without_assertions = ";".join(code_block_lines_without_assertions) + ";" + return assert_block, code_block_without_assertions + + +def dedup_and_merge(source_class_code: str) -> tuple[LiteralString | Any, Any, Any, int]: + """ + Returns a modified source after removing duplicate test methods and merging decomposed ones. + + Parameters: + ----------- + source_class_code : str + String containing code for a java class. + + Returns: + -------- + str + modified source after removing duplicate test methods and merging decomposed ones. + """ + + test_method_dict = java_sitter.get_test_methods(source_class_code) + ( + deduped_method_dict, + num_dup_methods_removed, + num_empty_methods_removed, + ) = _remove_duplicates_empties(test_method_dict) + merged_method_dict, num_methods_merged = _compose_decomposed(deduped_method_dict) + modified_source = _replace_in_source(source_class_code, test_method_dict, merged_method_dict) + return ( + modified_source, + num_dup_methods_removed, + num_empty_methods_removed, + num_methods_merged, + ) + + +# TODO: This has to be moved to the test file!
+if __name__ == "__main__": + assert_code_check = """ + { + StringBuffer result = helpFormatter.renderOptions(sb, width, options, leftPad, descPad); + assertEquals(result.toString(), "StringBuffer sb = new StringBuffer();\n" + + "int width = 100;\n" + + "int leftPad = 1;\n" + + "int descPad = 1;\n" + + "StringBuffer result = helpFormatter.renderOptions(sb, width, options, leftPad, descPad);\n" + + "assertEquals(result.toString(), \"StringBuffer sb = new StringBuffer();\");\n"); + } + """ + java_code = """ + public class HelloWorld { + public static void main(String[] args) { + System.out.println("Hello, World!"); + } + } + """ + # tokens = ts1.get_lexical_tokens(code=java_code, filter_by_node_type=['identifier']) + # print(tokens) + generated_test_code = """ + @Test + public void feedDog_mergecandidate(){ + myDog.setWeight(myDog.name); + myDog.eat(); + assertEquals("Error2 when eating", 10, myDog.getWeight()); + } + """ + source_method = """ public void method1() { + System.out.println("Start of method1"); + // Call Method2 at random places + method2(); + System.out.println("Called method2 first time"); + method2(); + System.out.println("Called method2 second time"); + method2(); + System.out.println("Called method2 last time"); + } + """ + test_class = """ package acd; + import a; + public class DogTest { + private ModResortsCustomerInformation modResortsCustomerInformation; + @mock + private DataSource dataSource; + @mock + private Connection connection; + private PreparedStatement preparedStatement; + private ResultSet resultSet; + Dog myDog; + @Before + public void setUp(){ + System.out.println("This is run before"); + myDog = new Dog("Jimmy", "Beagle"); + } + @After + public void tearDown(){ + System.out.println("This is run after"); + } + @Test + public void createNewDog(){ + assertEquals("Error in creating a dog", "Jimmy", myDog.getName()); + } + @Test + public void feedDog(){ + myDog.setWeight(5); + myDog.eat(); + assertEquals("Error when eating", 10, 
myDog.getWeight()); + } + @Test + public void feedDog_mergecandidate(){ + myDog.setWeight(5); + myDog.eat(); + assertEquals("Error2 when eating", 10, myDog.getWeight()); + } + @Test + public void createNewDog_mergecandidate(){ + assertEquals("another error in creating a dog", "Jimmy", myDog.getName()); + } + @Test + public void feedDog2(){ + myDog.setWeight(5); + + myDog.eat(); + assertEquals("Error when eating", 10, myDog.getWeight()); + } + @Test + public void feedDog3(){ + myDog.setWeight(5); + myDog.eat(); + assertEquals("Error when eating", 10, myDog.getWeight()); + } + @Test + public void createNewDog2(){ + + assertEquals("Error in creating a dog", "Jimmy", myDog.getName()); + } + @Test + public void emptyTest(){ + + + } + } + """ + empty_test_class = """ + public class DogTest { + private ModResortsCustomerInformation modResortsCustomerInformation; + @mock + private DataSource dataSource; + @mock + private Connection connection; + private PreparedStatement preparedStatement; + private ResultSet resultSet; + Dog myDog; + @Before + public void setUp(){ + System.out.println("This is run before"); + myDog = new Dog("Jimmy", "Beagle"); + } + @After + public void tearDown(){ + System.out.println("This is run after"); + } + """ + # target_method_name = "public void method2()" + print(separate_assertions(source_method_code=assert_code_check)) + # print(ts.get_calling_lines(source_method,target_method_name)) + # print(ts.get_all_methods_with_test_with_lines(test_class)) + # print(ts.get_test_methods(test_class)) + # print(ts.get_all_field_access(generated_test_code)) + # test_method_dict = ts.get_test_methods(test_class) + # deduped_method_dict, num_methods_removed = ts._remove_duplicates(test_method_dict) + # print("num_methods_removed ",num_methods_removed) + # merged_method_dict, num_merged_methods = ts._compose_decomposed(deduped_method_dict) + # print("merged_methods ",num_merged_methods) + # print(test_method_dict) + # 
print(ts._replace_in_source(test_class,test_method_dict,merged_method_dict)) + # print(ts.get_all_methods_with_annotations(test_class, ["Test", "Before"])) + # print(ts.get_all_fields_with_annotations(test_class)) + # print(ts.dedup_and_merge(test_class)) + # print(ts.is_empty_test_class(empty_test_class)) diff --git a/cldk/utils/sanitization/java/__init__.py b/cldk/utils/sanitization/java/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/cldk/utils/sanitization/java/__init__.py @@ -0,0 +1 @@ + diff --git a/cldk/utils/treesitter/__init__.py b/cldk/utils/treesitter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/cldk/utils/treesitter/tree_sitter_utils.py b/cldk/utils/treesitter/tree_sitter_utils.py new file mode 100644 index 0000000..8621232 --- /dev/null +++ b/cldk/utils/treesitter/tree_sitter_utils.py @@ -0,0 +1,48 @@ +from tree_sitter import Query, Node + +from cldk.models.treesitter import Captures + + +class TreeSitterUtils: + def __frame_query_and_capture_output(self, query: str, code_to_process: str) -> Captures: + """Frame a query for the tree-sitter parser. + + Parameters + ---------- + query : str + The query to frame. + code_to_process : str + The code to process. + """ + framed_query: Query = self.language.query(query) + tree = self.parser.parse(bytes(code_to_process, "utf-8")) + return Captures(framed_query.captures(tree.root_node)) + + def __safe_ascend(self, node: Node, ascend_count: int) -> Node: + """Safely ascend the tree. If the node does not exist or if it has no parent, raise an error. + + Parameters + ---------- + node : Node + The node to ascend from. + ascend_count : int + The number of times to ascend the tree. + + Returns + ------- + Node + The node at the specified level of the tree. + + Raises + ------ + ValueError + If the node has no parent. 
+ """ + if node is None: + raise ValueError("Node does not exist.") + if node.parent is None: + raise ValueError("Node has no parent.") + if ascend_count == 0: + return node + else: + return self.__safe_ascend(node.parent, ascend_count - 1) \ No newline at end of file diff --git a/cldk.png b/docs/assets/cldk.png similarity index 100% rename from cldk.png rename to docs/assets/cldk.png diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..9bba721 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,49 @@ +[tool.poetry] +name = "cldk" +version = "0.1.0-dev" +description = "codellm-devkit: A python library for seamless interation with LLMs." +authors = ["Rahul Krishna ", "Rangeet Pan ", "Saurabh Sinhas ", + "Raju Pavuluri "] +license = "Apache 2.0" +readme = "README.md" +include = ["cldk/analysis/java/codeanalyzer/jar/*.jar"] + +[tool.poetry.dependencies] +python = ">=3.11" +pydantic = "^2.6.1" +pandas = "^2.2.0" +networkx = "^3.2.1" +pyarrow = "^15.0.0" +tree-sitter-languages = "^1.10.2" +tree-sitter = "^0.22.3" +rich = "^13.7.1" +wget = "^3.2" +requests = "^2.31.0" +tree-sitter-java = "^0.21.0" +tree-sitter-c = "^0.21.0" +tree-sitter-go = "^0.21.0" +tree-sitter-python = {git = "https://github.com/tree-sitter/tree-sitter-python", rev = "0f9047c"} # Points to 0.21.0 +tree-sitter-javascript = "^0.21.0" +# Test dependencies + +[tool.poetry.group.dev.dependencies] +toml = "^0.10.2" +pytest = "^7.4.3" +ipdb = "^0.13.13" +jupyter = "^1.0.0" + +# Documentation +mkdocs = "1.6.0" +mkdocstrings = "0.25.1" + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" + +[tool.black] +line-length = 180 + +[tool.cldk.testing] +sample-application = "tests/resources/java/application/" +sample-application-analysis-json = "tests/resources/java/analysis_db" +codeanalyzer-jar-path = "tests/resources/java/codeanalyzer/build/libs/" \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 
0000000..e69de29 diff --git a/tests/analysis/__init__.py b/tests/analysis/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/analysis/java/__init__.py b/tests/analysis/java/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/analysis/java/test_java.py b/tests/analysis/java/test_java.py new file mode 100644 index 0000000..01fcf2f --- /dev/null +++ b/tests/analysis/java/test_java.py @@ -0,0 +1,20 @@ +from typing import List, Tuple +from cldk import CLDK +from cldk.models.java.models import JMethodDetail + + +def test_get_class_call_graph(test_fixture): + # Initialize the CLDK object with the project directory, language, and analysis_backend. + cldk = CLDK(language="java") + + analysis = cldk.analysis( + project_path=test_fixture, + analysis_backend="codeanalyzer", + analysis_json_path="/tmp", + eager=True, + ) + class_call_graph: List[Tuple[JMethodDetail, JMethodDetail]] = analysis.get_class_call_graph( + qualified_class_name="com.ibm.websphere.samples.daytrader.impl.direct.TradeDirectDBUtils" + ) + + assert class_call_graph is not None diff --git a/tests/analysis/python/__init__.py b/tests/analysis/python/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/analysis/python/test_python.py b/tests/analysis/python/test_python.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/analysis/test_codeql.py b/tests/analysis/test_codeql.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..d5182ed --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,50 @@ +import toml +import shutil +import pytest +import zipfile +from pathlib import Path +from urllib.request import urlretrieve + + +@pytest.fixture(scope="session", autouse=True) +def test_fixture(): + """ + Returns the path to the test data directory. + + Yields: + Path : The path to the test data directory. 
+ """ + # ----------------------------------[ SETUP ]---------------------------------- + # Path to your pyproject.toml + pyproject_path = Path(__file__).parent.parent / "pyproject.toml" + + # Load the configuration + config = toml.load(pyproject_path) + + # Access the test data path + test_data_path = config["tool"]["cldk"]["testing"]["sample-application"] + + if not Path(test_data_path).exists(): + Path(test_data_path).mkdir(parents=True) + + url = "https://github.com/OpenLiberty/sample.daytrader8/archive/refs/tags/v1.2.zip" + filename = Path(test_data_path).absolute() / "v1.2.zip" + urlretrieve(url, filename) + + # Extract the zip file to the test data path + with zipfile.ZipFile(filename, "r") as zip_ref: + zip_ref.extractall(test_data_path) + + # Remove the zip file + filename.unlink() + # -------------------------------------------------------------------------------- + + # Daytrader8 sample application path + yield Path(test_data_path) / "sample.daytrader8-1.2" + + # -----------------------------------[ TEARDOWN ]---------------------------------- + # Remove the daytrader8 sample application that was downloaded for testing + for directory in Path(test_data_path).iterdir(): + if directory.exists() and directory.is_dir(): + shutil.rmtree(directory) + # --------------------------------------------------------------------------------- diff --git a/tests/example.py b/tests/example.py new file mode 100644 index 0000000..ad612e0 --- /dev/null +++ b/tests/example.py @@ -0,0 +1,98 @@ +"""Example: Use CLDK to build a code summarization model +""" + +from cldk import CLDK +from cldk.analysis.java import JavaAnalysis + +# Initialize the Codellm-DevKit object with the project directory, language, and analysis_backend. +ns = CLDK( + project_dir="/Users/rajupavuluri/development/sample.daytrader8/", + language="java", + analysis_json_path="/Users/rkrsn/Downloads/sample.daytrader8/", +) + + +# Get the java application view for the project. 
+java_analysis: JavaAnalysis = ns.get_analysis() + +classes_dict = ns.preprocessing.get_classes() +# print(classes_dict) +entry_point_classes_dict = ns.preprocessing.get_entry_point_classes() +print(entry_point_classes_dict) + +entry_point_methods_dict = ns.preprocessing.get_entry_point_methods() +print(entry_point_methods_dict) + + +# ##get the first class in this dictionary for testing purposes +test_class_name = next(iter(classes_dict)) +print(test_class_name) +test_class = classes_dict[test_class_name] +# print(test_class) +# print(test_class.is_entry_point) + +# constructors = ns.preprocessing.get_all_constructors(test_class_name) +# print(constructors) + +# fields = ns.preprocessing.get_all_fields(test_class_name) +# print("fields :", fields) + +# methods = ns.preprocessing.get_all_methods_in_class(test_class_name) +# # print("number of methods in class ",test_class_name, ": ",len(methods)) +# nested_classes = ns.preprocessing.get_all_nested_classes(test_class_name) +# # print("nested_classes: ",nested_classes) +# extended_classes = ns.preprocessing.get_extended_classes(test_class_name) +# # print("extended_classes: ",extended_classes) +# implemented_interfaces = ns.preprocessing.get_implemented_interfaces( +# test_class_name +# ) +# # print("implemented_interfaces: ",implemented_interfaces) +# class_result = ns.preprocessing.get_class(test_class_name) +# print("class_result: ", class_result) +# java_file_name = ns.preprocessing.get_java_file(test_class_name) +# # print("java_file_name ",java_file_name) +# all_methods = ns.preprocessing.get_all_methods_in_application() +# # print(all_methods) +# method = ns.preprocessing.get_method( +# "com.ibm.websphere.samples.daytrader.util.Log", +# "public static void trace(String message)", +# ) +# print(method) +# # Get the call graph. 
+ +# cg = ns.preprocessing.get_call_graph() +# print(cg) +# # print(ns.preprocessing.get_call_graph_json()) + +# # print(cg.edges) +# # d = ns.preprocessing.get_all_callers("com.ibm.websphere.samples.daytrader.util.Log","public static void trace(String message)") +# # print("caller details::") +# # print(d) +# # v = ns.preprocessing.get_all_callees("com.ibm.websphere.samples.daytrader.impl.ejb3.MarketSummarySingleton","private void updateMarketSummary()") +# # print("callee details::") +# # print(v) + +# """ +# # Get the user specified method. +# method: JCallable = app.get_method("com.example.foo.Bar.baz") # <- User specified method. + +# # Get the slices that contain the method. +# slices: nx.Generator = ns.preprocessing.get_slices_containing_method(method, sdg=app.sdg) + +# # Optional: Get samples for RAG from (say) elasticsearch +# few_shot_samples: List[str] = ns.prompting.rag( +# database={"hostname": "https://localhost:9200", "index": "summarization"} +# ).retrive_few_shot_samples(method=method, slices=slices) + +# # Natively we'll support PDL as the prompting engine to get summaries from the LLM. + +# summaries: List[str] = ns.prompting(engine="pdl").summarize(method, context=slices, few_shot_samples=few_shot_samples) + +# # Optionally, we will also support other open-source engines such as LMQL, Guidance, user defined Jinja, etc. +# summaries: List[str] = ns.prompting(engine="lmql").summarize(slices=slices, few_shot_samples=few_shot_samples) +# summaries: List[str] = ns.prompting(engine="guidance").summarize(slices=slices, few_shot_samples=few_shot_samples) +# summaries: List[str] = ns.prompting(engine="jinja", template="