From d781b511e4a0635442b541bb4be883c5fe3f4fba Mon Sep 17 00:00:00 2001 From: Rangeet Pan Date: Mon, 14 Apr 2025 11:02:45 -0400 Subject: [PATCH] Support preliminary details of conditional statements #125 --- .../commons/treesitter/treesitter_java.py | 58 +++++++++++++++++++ tests/analysis/java/test_java_sitter.py | 50 +++++++++++++--- 2 files changed, 99 insertions(+), 9 deletions(-) diff --git a/cldk/analysis/commons/treesitter/treesitter_java.py b/cldk/analysis/commons/treesitter/treesitter_java.py index 3c0033d..bc2eb7d 100644 --- a/cldk/analysis/commons/treesitter/treesitter_java.py +++ b/cldk/analysis/commons/treesitter/treesitter_java.py @@ -94,6 +94,64 @@ def get_raw_ast(self, code: str) -> Tree: """ return self.parser.parse(bytes(code, "utf-8")) + def get_all_conditional_statements(self, source_code: str) -> List[dict]: + """ + Get all conditional statements given a Java code + Args: + source_code: + + Returns: + List[dict]: Each element represents a conditional statement. Key start line number, + value : {"condition":, "code":, "start_line":, "end_line":, + "condition_type":[FOR|WHILE|DOWHILE|IF|ELSE|CATCH]} + """ + tree = self.parser.parse(source_code.encode('utf-8')) + root = tree.root_node + + conditionals = [] + + def traverse(node): + condition_type = None + condition_node = None + if node.type == 'if_statement': + condition_type = 'IF' + condition_node = node.child_by_field_name('condition') + elif node.type == 'for_statement': + condition_type = 'FOR' + condition_node = node.child_by_field_name('condition') or node + elif node.type == 'while_statement': + condition_type = 'WHILE' + condition_node = node.child_by_field_name('condition') + elif node.type == 'do_statement': + condition_type = 'DOWHILE' + condition_node = node.child_by_field_name('condition') + elif node.type == 'catch_clause': + condition_type = 'CATCH' + condition_node = node.child_by_field_name('parameter') + elif node.type == 'else_clause': + condition_type = 'ELSE' + condition_node = None + + if condition_type: + condition_text = str(self.__get_node_text(source_code, condition_node)) if condition_node else None + code_text = str(self.__get_node_text(source_code, node), 'utf-8') + conditionals.append({ + "condition": condition_text, + "code": code_text, + "start_line": node.start_point[0] + 1, + "end_line": node.end_point[0] + 1, + "condition_type": condition_type + }) + + for child in node.children: + traverse(child) + + traverse(root) + return conditionals + + def __get_node_text(self, source_code, node): + return source_code[node.start_byte:node.end_byte].encode('utf-8') + def get_all_imports(self, source_code: str) -> Set[str]: """Get a list of all the imports in a class. diff --git a/tests/analysis/java/test_java_sitter.py b/tests/analysis/java/test_java_sitter.py index 73ae639..6b2712b 100644 --- a/tests/analysis/java/test_java_sitter.py +++ b/tests/analysis/java/test_java_sitter.py @@ -30,7 +30,8 @@ def test_method_is_not_in_class(test_fixture): java_sitter = TreesitterJava() # Get a test source file and send its contents - filename = os.path.join(test_fixture, "src/main/java/com/ibm/websphere/samples/daytrader/beans/MarketSummaryDataBean.java") + filename = os.path.join(test_fixture, + "src/main/java/com/ibm/websphere/samples/daytrader/beans/MarketSummaryDataBean.java") with open(filename, "r", encoding="utf-8") as file: class_body = file.read() @@ -48,7 +49,8 @@ def test_is_parsable(test_fixture): java_sitter = TreesitterJava() # Get a test source file and send its contents - filename = os.path.join(test_fixture, "src/main/java/com/ibm/websphere/samples/daytrader/beans/MarketSummaryDataBean.java") + filename = os.path.join(test_fixture, + "src/main/java/com/ibm/websphere/samples/daytrader/beans/MarketSummaryDataBean.java") with open(filename, "r", encoding="utf-8") as file: code = file.read() @@ -161,7 +163,8 @@ def test_get_all_interfaces(test_fixture): java_sitter = TreesitterJava() # Get a test source file with interfaces - filename = os.path.join(test_fixture, "src/main/java/com/ibm/websphere/samples/daytrader/impl/direct/TradeDirect.java") + filename = os.path.join(test_fixture, + "src/main/java/com/ibm/websphere/samples/daytrader/impl/direct/TradeDirect.java") with open(filename, "r", encoding="utf-8") as file: code = file.read() @@ -261,6 +264,29 @@ def test_get_call_targets(): # TODO: This test case needs to be written +def test_conditional_statement(): + """conditional statement""" + java_sitter = TreesitterJava() + + source_method_code = """ + public static BigDecimal computeHoldingsTotal(Collection holdingDataBeans) { + BigDecimal holdingsTotal = new BigDecimal(0.0).setScale(SCALE); + if (holdingDataBeans == null) { + return holdingsTotal; + } + Iterator it = holdingDataBeans.iterator(); + while (it.hasNext()) { + HoldingDataBean holdingData = (HoldingDataBean) it.next(); + BigDecimal total = holdingData.getPurchasePrice().multiply(new BigDecimal(holdingData.getQuantity())); + holdingsTotal = holdingsTotal.add(total); + } + return holdingsTotal.setScale(SCALE); + } + """ + conditional_statements = java_sitter.get_all_conditional_statements(source_code=source_method_code) + assert len(conditional_statements) > 0 + + def test_get_calling_lines(): """get the calling lines""" java_sitter = TreesitterJava() @@ -301,7 +327,8 @@ def test_get_test_methods(test_fixture): # TODO: Need to find an example with test methods # Get a test source file with interfaces - filename = os.path.join(test_fixture, "src/main/java/com/ibm/websphere/samples/daytrader/impl/direct/TradeDirect.java") + filename = os.path.join(test_fixture, + "src/main/java/com/ibm/websphere/samples/daytrader/impl/direct/TradeDirect.java") with open(filename, "r", encoding="utf-8") as file: code = file.read() @@ -316,7 +343,8 @@ def test_get_methods_with_annotations(test_fixture): java_sitter = TreesitterJava() # Get a test source file with annotations - filename = os.path.join(test_fixture, "src/main/java/com/ibm/websphere/samples/daytrader/web/prims/PingJDBCRead2JSP.java") + filename = os.path.join(test_fixture, + "src/main/java/com/ibm/websphere/samples/daytrader/web/prims/PingJDBCRead2JSP.java") with open(filename, "r", encoding="utf-8") as file: code = file.read() @@ -337,7 +365,8 @@ def test_get_all_type_invocations(test_fixture): java_sitter = TreesitterJava() # Get a test source file - filename = os.path.join(test_fixture, "src/main/java/com/ibm/websphere/samples/daytrader/web/prims/PingJDBCRead2JSP.java") + filename = os.path.join(test_fixture, + "src/main/java/com/ibm/websphere/samples/daytrader/web/prims/PingJDBCRead2JSP.java") with open(filename, "r", encoding="utf-8") as file: code = file.read() @@ -380,7 +409,8 @@ def test_get_lexical_tokens(test_fixture): java_sitter = TreesitterJava() # Get a test source file - filename = os.path.join(test_fixture, "src/main/java/com/ibm/websphere/samples/daytrader/web/prims/PingJDBCRead2JSP.java") + filename = os.path.join(test_fixture, + "src/main/java/com/ibm/websphere/samples/daytrader/web/prims/PingJDBCRead2JSP.java") with open(filename, "r", encoding="utf-8") as file: code = file.read() @@ -396,7 +426,8 @@ def test_remove_all_comments(test_fixture): java_sitter = TreesitterJava() # Get a test source file - filename = os.path.join(test_fixture, "src/main/java/com/ibm/websphere/samples/daytrader/web/prims/PingJDBCRead2JSP.java") + filename = os.path.join(test_fixture, + "src/main/java/com/ibm/websphere/samples/daytrader/web/prims/PingJDBCRead2JSP.java") with open(filename, "r", encoding="utf-8") as file: code = file.read() @@ -413,7 +444,8 @@ def test_make_pruned_code_prettier(test_fixture): java_sitter = TreesitterJava() # Get a test source file - filename = os.path.join(test_fixture, "src/main/java/com/ibm/websphere/samples/daytrader/web/prims/PingJDBCRead2JSP.java") + filename = os.path.join(test_fixture, + "src/main/java/com/ibm/websphere/samples/daytrader/web/prims/PingJDBCRead2JSP.java") with open(filename, "r", encoding="utf-8") as file: code = file.read()