Skip to content

Reads slim JSON from codeanalyzer v2.0.0. #59

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions cldk/analysis/analysis_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
class AnalysisLevel(str, Enum):
    """Analysis levels.

    Every member's value is the space-separated, human-readable name of the
    level. The class also derives from ``str``, so a member compares equal to
    its plain string value.
    """

    # Each member name is bound exactly once: ``Enum`` raises ``TypeError``
    # when a member name is assigned a second time in the class body.
    symbol_table = "symbol table"
    call_graph = "call graph"
    program_dependency_graph = "program dependency graph"
    system_dependency_graph = "system dependency graph"
74 changes: 10 additions & 64 deletions cldk/analysis/java/codeanalyzer/codeanalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
# limitations under the License.
################################################################################

"""
Codeanalyzer module
"""

import re
import json
import shlex
Expand Down Expand Up @@ -120,61 +116,6 @@ def __init__(
else:
self.call_graph: DiGraph | None = None

@staticmethod
def _download_or_update_code_analyzer(filepath: Path) -> str:
    """Downloads the codeanalyzer jar from the latest release on GitHub, if needed.

    The timestamp (``YYYYMMDDTHHMMSS``) embedded in the release's download URL
    is compared with the timestamp in the name of the first ``*.jar`` found in
    ``filepath``. The jar is downloaded when no jar exists, replaced when the
    existing one is older, and kept otherwise.

    Args:
        filepath (Path): The directory in which the codeanalyzer jar is stored.

    Returns:
        str: The path to the codeanalyzer jar file to use.

    Raises:
        Exception: If the release metadata cannot be fetched, the release has
            no ``codeanalyzer.jar`` asset, or the asset's download URL does
            not contain a timestamp.
    """
    url = "https://api.github.com/repos/IBM/codenet-minerva-code-analyzer/releases/latest"
    date_format = "%Y%m%dT%H%M%S"
    pattern = r"(\d{8}T\d{6})"

    # A timeout keeps a stalled connection from blocking the caller forever.
    response = requests.get(url, timeout=60)
    if response.status_code != 200:
        raise Exception(f"Failed to fetch release warn: {response.status_code} {response.text}")

    download_url = next(
        (asset["browser_download_url"] for asset in response.json().get("assets", []) if asset["name"] == "codeanalyzer.jar"),
        None,
    )
    if download_url is None:
        # Without this check the function would fall through and return None.
        raise Exception("Latest release does not contain a codeanalyzer.jar asset.")

    match = re.search(pattern, download_url)
    if not match:
        raise Exception(f"Release URL {download_url} does not contain a datetime pattern.")
    datetime_str = match.group(0)
    latest_jar = filepath / f"codeanalyzer.{datetime_str}.jar"

    # Look for codeanalyzer.YYYYMMDDTHHMMSS.jar in the filepath
    current_codeanalyzer_jars = list(filepath.glob("*.jar"))
    if not current_codeanalyzer_jars:
        logger.info("Codeanalzyer jar is not found. Downloading the latest version.")
        urlretrieve(download_url, latest_jar)
        return str(latest_jar)

    current_codeanalyzer_jar = current_codeanalyzer_jars[0]
    current_match = re.search(pattern, str(current_codeanalyzer_jar))
    if not current_match:
        # The existing jar carries no timestamp, so it cannot be compared
        # against the release; keep using it as-is.
        return str(current_codeanalyzer_jar)

    latest_is_newer = datetime.strptime(datetime_str, date_format) > datetime.strptime(current_match.group(0), date_format)
    if not latest_is_newer:
        logger.info("Codeanalzyer jar is already at the latest version.")
        return str(current_codeanalyzer_jar)

    logger.info("Codeanalzyer jar is outdated. Downloading the latest version.")
    # Remove the older codeanalyzer jar(s) before downloading the newer one
    for jarfile in current_codeanalyzer_jars:
        jarfile.unlink()
    urlretrieve(download_url, latest_jar)
    return str(latest_jar)

def _get_application(self) -> JApplication:
"""Returns the application view of the Java code.

Expand Down Expand Up @@ -204,14 +145,15 @@ def _get_codeanalyzer_exec(self) -> List[str]:

if self.analysis_backend_path:
analysis_backend_path = Path(self.analysis_backend_path)
logger.info(f"Using codeanalyzer.jar from {analysis_backend_path}")
codeanalyzer_exec = shlex.split(f"java -jar {analysis_backend_path / 'codeanalyzer.jar'}")
logger.info(f"Using codeanalyzer jar from {analysis_backend_path}")
codeanalyzer_jar_file = next(analysis_backend_path.rglob("codeanalyzer-*.jar"), None)
codeanalyzer_exec = shlex.split(f"java -jar {codeanalyzer_jar_file}")
else:
# Since the path to codeanalyzer.jar was not provided, we'll download the latest version from GitHub.
with resources.as_file(resources.files("cldk.analysis.java.codeanalyzer.jar")) as codeanalyzer_jar_path:
# Download the codeanalyzer jar if it doesn't exist, update if it's outdated,
# do nothing if it's up-to-date.
codeanalyzer_jar_file = self._download_or_update_code_analyzer(codeanalyzer_jar_path)
codeanalyzer_jar_file = next(codeanalyzer_jar_path.rglob("codeanalyzer-*.jar"), None)
codeanalyzer_exec = shlex.split(f"java -jar {codeanalyzer_jar_file}")
return codeanalyzer_exec

Expand Down Expand Up @@ -372,11 +314,15 @@ def _generate_call_graph(self, using_symbol_table) -> DiGraph:
{
"type": jge.type,
"weight": jge.weight,
"calling_lines": tsu.get_calling_lines(jge.source.method.code, jge.target.method.signature),
"calling_lines": (
tsu.get_calling_lines(jge.source.method.code, jge.target.method.signature, jge.target.method.is_constructor)
if not jge.source.method.is_implicit or not jge.target.method.is_implicit
else []
),
},
)
for jge in sdg
if jge.type == "CONTROL_DEP" or jge.type == "CALL_DEP"
if jge.type == "CALL_DEP" # or jge.type == "CONTROL_DEP"
]
for jge in sdg:
cg.add_node(
Expand Down
82 changes: 43 additions & 39 deletions cldk/analysis/java/treesitter/javasitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"""
JavaSitter module
"""

from itertools import groupby
from typing import List, Set, Dict
from tree_sitter import Language, Node, Parser, Query, Tree
Expand All @@ -26,6 +25,10 @@

from cldk.models.treesitter import Captures

import logging

logger = logging.getLogger(__name__)


class JavaSitter:
"""
Expand All @@ -51,8 +54,7 @@ def method_is_not_in_class(self, method_name: str, class_body: str) -> bool:
bool
True if the method is not in the class, False otherwise.
"""
methods_in_class = self.frame_query_and_capture_output("(method_declaration name: (identifier) @name)",
class_body)
methods_in_class = self.frame_query_and_capture_output("(method_declaration name: (identifier) @name)", class_body)

return method_name not in {method.node.text.decode() for method in methods_in_class}

Expand Down Expand Up @@ -103,8 +105,7 @@ def get_all_imports(self, source_code: str) -> Set[str]:
Returns:
Set[str]: A set of all the imports in the class.
"""
import_declerations: Captures = self.frame_query_and_capture_output(
query="(import_declaration (scoped_identifier) @name)", code_to_process=source_code)
import_declerations: Captures = self.frame_query_and_capture_output(query="(import_declaration (scoped_identifier) @name)", code_to_process=source_code)
return {capture.node.text.decode() for capture in import_declerations}

def get_pacakge_name(self, source_code: str) -> str:
Expand All @@ -116,8 +117,7 @@ def get_pacakge_name(self, source_code: str) -> str:
Returns:
str: The package name.
"""
package_name: Captures = self.frame_query_and_capture_output(query="((package_declaration) @name)",
code_to_process=source_code)
package_name: Captures = self.frame_query_and_capture_output(query="((package_declaration) @name)", code_to_process=source_code)
if package_name:
return package_name[0].node.text.decode().replace("package ", "").replace(";", "")
return None
Expand All @@ -143,8 +143,7 @@ def get_superclass(self, source_code: str) -> str:
Returns:
Set[str]: A set of all the superclasses in the class.
"""
superclass: Captures = self.frame_query_and_capture_output(
query="(class_declaration (superclass (type_identifier) @superclass))", code_to_process=source_code)
superclass: Captures = self.frame_query_and_capture_output(query="(class_declaration (superclass (type_identifier) @superclass))", code_to_process=source_code)

if len(superclass) == 0:
return ""
Expand All @@ -161,9 +160,7 @@ def get_all_interfaces(self, source_code: str) -> Set[str]:
Set[str]: A set of all the interfaces implemented by the class.
"""

interfaces = self.frame_query_and_capture_output(
"(class_declaration (super_interfaces (type_list (type_identifier) @interface)))",
code_to_process=source_code)
interfaces = self.frame_query_and_capture_output("(class_declaration (super_interfaces (type_list (type_identifier) @interface)))", code_to_process=source_code)
return {interface.node.text.decode() for interface in interfaces}

def frame_query_and_capture_output(self, query: str, code_to_process: str) -> Captures:
Expand All @@ -182,8 +179,7 @@ def frame_query_and_capture_output(self, query: str, code_to_process: str) -> Ca

def get_method_name_from_declaration(self, method_name_string: str) -> str:
    """Get the method name from the method signature.

    Args:
        method_name_string (str): Source text expected to contain a method declaration.

    Returns:
        str: The decoded text of the first ``name`` identifier captured from a
            ``method_declaration`` node.
    """
    # The query is issued exactly once; the capture is a pure lookup, so
    # repeating it only duplicates work.
    captures: Captures = self.frame_query_and_capture_output("(method_declaration name: (identifier) @method_name)", method_name_string)

    # NOTE(review): indexing [0] presumably fails when nothing was captured
    # (i.e. the text holds no method declaration) - confirm against Captures.
    return captures[0].node.text.decode()

Expand All @@ -192,8 +188,12 @@ def get_method_name_from_invocation(self, method_invocation: str) -> str:
Using the tree-sitter query, extract the method name from the method invocation.
"""

captures: Captures = self.frame_query_and_capture_output(
"(method_invocation object: (identifier) @class_name name: (identifier) @method_name)", method_invocation)
captures: Captures = self.frame_query_and_capture_output("(method_invocation name: (identifier) @method_name)", method_invocation)
return captures[0].node.text.decode()

def get_identifier_from_arbitrary_statement(self, statement: str) -> str:
    """Return the text of the first identifier captured in ``statement``.

    Args:
        statement (str): The source text to run the identifier query against.

    Returns:
        str: The decoded text of the first captured ``identifier`` node.
    """
    identifier_query = "(identifier) @identifier"
    matches: Captures = self.frame_query_and_capture_output(identifier_query, statement)
    # Only the first capture is of interest; any later ones are ignored.
    first_match = matches[0]
    return first_match.node.text.decode()

def safe_ascend(self, node: Node, ascend_count: int) -> Node:
Expand Down Expand Up @@ -260,7 +260,7 @@ def get_call_targets(self, method_body: str, declared_methods: dict) -> Set[str]
)
return call_targets

def get_calling_lines(self, source_method_code: str, target_method_name: str) -> List[int]:
def get_calling_lines(self, source_method_code: str, target_method_name: str, is_target_method_a_constructor: bool) -> List[int]:
"""
Returns a list of line numbers in source method where target method is called.

Expand All @@ -272,26 +272,34 @@ def get_calling_lines(self, source_method_code: str, target_method_name: str) ->
target_method_name : str
target method name, or a method signature from which the name is extracted

is_target_method_a_constructor : bool
True if target method is a constructor, False otherwise.

Returns:
--------
List[int]
List of line numbers within the source method code block.
"""
query = "(method_invocation name: (identifier) @method_name)"
if not source_method_code:
return []
query = "(object_creation_expression (type_identifier) @object_name) (object_creation_expression type: (scoped_type_identifier (type_identifier) @type_name)) (method_invocation name: (identifier) @method_name)"

# if target_method_name is a method signature, get the method name
# if it is not a signature, we will just keep the passed method name

target_method_name = target_method_name.split("(")[0] # remove the arguments from the constructor name
try:
target_method_name = self.get_method_name_from_declaration(target_method_name)
except Exception:
pass

captures: Captures = self.frame_query_and_capture_output(query, source_method_code)
# Find the line numbers where target method calls happen in source method
target_call_lines = []
for c in captures:
method_name = c.node.text.decode()
if method_name == target_method_name:
target_call_lines.append(c.node.start_point[0])
captures: Captures = self.frame_query_and_capture_output(query, source_method_code)
# Find the line numbers where target method calls happen in source method
target_call_lines = []
for c in captures:
method_name = c.node.text.decode()
if method_name == target_method_name:
target_call_lines.append(c.node.start_point[0])
except:
logger.warning(f"Unable to get calling lines for {target_method_name} in {source_method_code}.")
return []

return target_call_lines

def get_test_methods(self, source_class_code: str) -> Dict[str, str]:
Expand Down Expand Up @@ -398,8 +406,7 @@ def get_method_return_type(self, source_code: str) -> str:
The return type of the method.
"""

type_references: Captures = self.frame_query_and_capture_output(
"(method_declaration type: ((type_identifier) @type_id))", source_code)
type_references: Captures = self.frame_query_and_capture_output("(method_declaration type: ((type_identifier) @type_id))", source_code)

return type_references[0].node.text.decode()

Expand All @@ -426,9 +433,9 @@ def collect_leaf_token_values(node):
if len(node.children) == 0:
if filter_by_node_type is not None:
if node.type in filter_by_node_type:
lexical_tokens.append(code[node.start_byte: node.end_byte])
lexical_tokens.append(code[node.start_byte : node.end_byte])
else:
lexical_tokens.append(code[node.start_byte: node.end_byte])
lexical_tokens.append(code[node.start_byte : node.end_byte])
else:
for child in node.children:
collect_leaf_token_values(child)
Expand Down Expand Up @@ -462,11 +469,9 @@ def remove_all_comments(self, source_code: str) -> str:
pruned_source_code = self.make_pruned_code_prettier(source_code)

# Remove all comment lines: the comment lines start with / (for // and /*) or * (for multiline comments).
comment_blocks: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)",
code_to_process=source_code)
comment_blocks: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)", code_to_process=source_code)

comment_lines: Captures = self.frame_query_and_capture_output(query="((line_comment) @comment_line)",
code_to_process=source_code)
comment_lines: Captures = self.frame_query_and_capture_output(query="((line_comment) @comment_line)", code_to_process=source_code)

for capture in comment_blocks:
pruned_source_code = pruned_source_code.replace(capture.node.text.decode(), "")
Expand All @@ -490,8 +495,7 @@ def make_pruned_code_prettier(self, pruned_code: str) -> str:
The prettified pruned code.
"""
# First remove remaining block comments
block_comments: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)",
code_to_process=pruned_code)
block_comments: Captures = self.frame_query_and_capture_output(query="((block_comment) @comment_block)", code_to_process=pruned_code)

for capture in block_comments:
pruned_code = pruned_code.replace(capture.node.text.decode(), "")
Expand Down
38 changes: 28 additions & 10 deletions cldk/models/java/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
"""
Models module
"""

import re
from contextvars import ContextVar
from typing import Dict, List, Optional
Expand Down Expand Up @@ -64,7 +63,7 @@ class JCallableParameter(BaseModel):
modifiers (List[str]): The modifiers applied to the parameter.
"""

name: str
name: str | None
type: str
annotations: List[str]
modifiers: List[str]
Expand Down Expand Up @@ -361,10 +360,30 @@ class JGraphEdges(BaseModel):
@field_validator("source", "target", mode="before")
@classmethod
def validate_source(cls, value) -> JMethodDetail:
    """Resolve a raw graph-edge endpoint into a ``JMethodDetail``.

    The endpoint is looked up in ``_CALLABLES_LOOKUP_TABLE`` under the key
    ``(type_declaration, signature)``. When no entry exists, a placeholder
    ``JCallable`` flagged ``is_implicit=True`` is built from the endpoint's
    ``callable_declaration`` text, stored in the table under the same key,
    and used instead.

    Args:
        value: Mapping with the keys ``"type_declaration"`` and
            ``"signature"``; ``"callable_declaration"`` is read only when the
            callable is not already in the lookup table.

    Returns:
        JMethodDetail: The callable, its declaration, and its declaring type.
    """
    type_declaration, signature = value["type_declaration"], value["signature"]
    lookup_key = (type_declaration, signature)

    j_callable = _CALLABLES_LOOKUP_TABLE.get(lookup_key)
    if j_callable is None:
        # Build the placeholder only on a miss, so known callables do not pay
        # for constructing a model that is immediately thrown away.
        callable_declaration = value["callable_declaration"]
        # Parameter types are the comma-separated text between the first "("
        # and the next ")". Empty pieces are dropped so that a no-argument
        # declaration such as "foo()" yields no parameters at all.
        parameter_text = callable_declaration.partition("(")[2].partition(")")[0]
        parameter_types = [piece.strip() for piece in parameter_text.split(",") if piece.strip()]
        j_callable = JCallable(
            signature=signature,
            is_implicit=True,
            is_constructor="<init>" in callable_declaration,
            comment="",
            annotations=[],
            modifiers=[],
            thrown_exceptions=[],
            declaration="",
            parameters=[JCallableParameter(name=None, type=t, annotations=[], modifiers=[]) for t in parameter_types],
            code="",
            start_line=-1,
            end_line=-1,
            referenced_types=[],
            accessed_fields=[],
            call_sites=[],
            variable_declarations=[],
            cyclomatic_complexity=0,
        )
        # Cache the placeholder so later edges touching the same callable
        # share a single object.
        _CALLABLES_LOOKUP_TABLE[lookup_key] = j_callable

    return JMethodDetail(method_declaration=j_callable.declaration, klass=type_declaration, method=j_callable)
Expand All @@ -391,9 +410,8 @@ class JApplication(BaseModel):
@field_validator("symbol_table", mode="after")
@classmethod
def validate_source(cls, symbol_table):

# Populate the lookup table for callables
for file_path, j_compulation_unit in symbol_table.items():
for _, j_compulation_unit in symbol_table.items():
for type_declaration, jtype in j_compulation_unit.type_declarations.items():
for callable_declaration, j_callable in jtype.callable_declarations.items():
_CALLABLES_LOOKUP_TABLE[(file_path, type_declaration, callable_declaration)] = j_callable
for __, j_callable in jtype.callable_declarations.items():
_CALLABLES_LOOKUP_TABLE[(type_declaration, j_callable.signature)] = j_callable
Loading