Skip to content

Update get entrypoint classes and methods #85

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions cldk/analysis/c/clang/clang_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@
from clang.cindex import Config
from pathlib import Path
from typing import List, Optional
from cldk.models.c import CFunction, CMacro, CCallSite, CTranslationUnit, CApplication
from cldk.models.c import CFunction, CCallSite, CTranslationUnit, CApplication
import logging
from ipdb import set_trace

from cldk.models.c.models import CInclude, CParameter, CVariable, StorageClass

Expand Down Expand Up @@ -34,14 +33,15 @@ def find_libclang() -> str:

# On Linux, we check various common installation paths
elif system == "Linux":
from pathlib import Path

lib_paths = [Path("/usr/lib"), Path("/usr/lib64")]
possible_paths = [
"/usr/lib/llvm-14/lib/libclang.so",
"/usr/lib/llvm-13/lib/libclang.so",
"/usr/lib/llvm-12/lib/libclang.so",
"/usr/lib/x86_64-linux-gnu/libclang-14.so.1",
"/usr/lib/libclang.so",
str(p) for base in lib_paths if base.exists()
for p in base.rglob("libclang*.so*")
]
install_instructions = "Install libclang using: sudo apt-get install libclang-dev"

install_instructions = "Install libclang development package using your system's package manager"
else:
raise RuntimeError(f"Unsupported operating system: {system}")

Expand Down
33 changes: 17 additions & 16 deletions cldk/analysis/java/codeanalyzer/codeanalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from itertools import chain, groupby
from pdb import set_trace
import re
import json
Expand Down Expand Up @@ -863,20 +864,20 @@ def get_class_call_graph(self, qualified_class_name: str, method_name: str | Non
return graph_edges

def get_all_entry_point_methods(self) -> Dict[str, Dict[str, JCallable]]:
"""Returns a dictionary of all entry point methods in the Java code with
qualified class name as the key and a dictionary of methods in that class as the value.
"""Returns a dictionary of all entry point methods in the Java code.

Returns:
Dict[str, Dict[str, JCallable]]: A dictionary of dictionaries of entry point
methods in the Java code.
Dict[str, Dict[str, JCallable]]: A dictionary of all entry point methods in the Java code.
"""

class_method_dict = {}
class_dict = self.get_all_classes()
for k, v in class_dict.items():
entry_point_methods = {method_name: callable_decl for (method_name, callable_decl) in v.callable_declarations.items() if callable_decl.is_entry_point is True}
class_method_dict[k] = entry_point_methods
return class_method_dict
methods = chain.from_iterable(
((typename, method, callable)
for method, callable in methods.items() if callable.is_entrypoint)
for typename, methods in self.get_all_methods_in_application().items()
)
return {
typename: {method: callable for _, method, callable in group}
for typename, group in groupby(methods, key=lambda x: x[0])
}

def get_all_entry_point_classes(self) -> Dict[str, JType]:
"""Returns a dictionary of all entry point classes in the Java code.
Expand All @@ -886,8 +887,8 @@ def get_all_entry_point_classes(self) -> Dict[str, JType]:
with qualified class names as keys.
"""

class_dict = {}
symtab = self.get_symbol_table()
for val in symtab.values():
class_dict.update((k, v) for k, v in val.type_declarations.items() if v.is_entry_point is True)
return class_dict
return {
typename: klass
for typename, klass in self.get_all_classes().items()
if klass.is_entrypoint_class
}
2 changes: 0 additions & 2 deletions cldk/models/java/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,4 @@
JGraphEdges,
)

from .constants_namespace import ConstantsNamespace

__all__ = ["JApplication", "JCallable", "JType", "JCompilationUnit", "JGraphEdges", "ConstantsNamespace"]
41 changes: 0 additions & 41 deletions cldk/models/java/constants_namespace.py

This file was deleted.

66 changes: 2 additions & 64 deletions cldk/models/java/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,6 @@
from pdb import set_trace
from pydantic import BaseModel, field_validator, model_validator

from .constants_namespace import ConstantsNamespace

constants = ConstantsNamespace()
context_concrete_class = ContextVar("context_concrete_class") # context var to store class concreteness
_CALLABLES_LOOKUP_TABLE = dict()


Expand Down Expand Up @@ -179,6 +175,7 @@ class JCallable(BaseModel):
referenced_types: List[str]
accessed_fields: List[str]
call_sites: List[JCallSite]
is_entrypoint: bool = False
variable_declarations: List[JVariableDeclaration]
cyclomatic_complexity: int | None

Expand All @@ -188,33 +185,6 @@ def __hash__(self):
"""
return hash(self.declaration)

@model_validator(mode="after")
def detect_entrypoint_method(self):
# check first if the class in which this method exists is concrete or not, by looking at the context var
if context_concrete_class.get():
# convert annotations to the form GET, POST even if they are @GET or @GET('/ID') etc.
annotations_cleaned = [match for annotation in self.annotations for match in re.findall(r"@(.*?)(?:\(|$)", annotation)]

param_type_list = [val.type for val in self.parameters]
# check the param types against known servlet param types
if any(substring in string for substring in param_type_list for string in constants.ENTRY_POINT_METHOD_SERVLET_PARAM_TYPES):
# check if this method is over-riding (only methods that override doGet / doPost etc. will be flagged as first level entry points)
if "Override" in annotations_cleaned:
self.is_entry_point = True
return self

# now check the cleaned annotations against known javax ws annotations
if any(substring in string for substring in annotations_cleaned for string in constants.ENTRY_POINT_METHOD_JAVAX_WS_ANNOTATIONS):
self.is_entry_point = True
return self

# check the cleaned annotations against known spring rest method annotations
if any(substring in string for substring in annotations_cleaned for string in constants.ENTRY_POINT_METHOD_SPRING_ANNOTATIONS):
self.is_entry_point = True
return self
return self


class JType(BaseModel):
"""Represents a Java class or interface.

Expand Down Expand Up @@ -257,44 +227,12 @@ class JType(BaseModel):
modifiers: List[str] = []
annotations: List[str] = []
parent_type: str
is_entrypoint_class: bool = False
nested_type_declerations: List[str] = []
callable_declarations: Dict[str, JCallable] = {}
field_declarations: List[JField] = []
enum_constants: List[JEnumConstant] = []

# first get the data in raw form and check if the class is concrete or not, before any model validation is done
# for this we assume if a class is not an interface or abstract it is concrete
# for abstract classes we will check the modifiers
@model_validator(mode="before")
def check_concrete_class(cls, values):
"""Detects if the class is concrete based on its properties."""
values["is_concrete_class"] = False
if values.get("is_class_or_interface_declaration") and not values.get("is_interface"):
if "abstract" not in values.get("modifiers"):
values["is_concrete_class"] = True
# since the methods in this class need access to the concrete class flag,
# we will store this in a context var - this is a hack
token = context_concrete_class.set(values["is_concrete_class"])
return values

# after model validation is done we populate the is_entry_point flag by checking
# if the class extends or implements known servlet classes
@model_validator(mode="after")
def check_concrete_entry_point(self):
"""Detects if the class is entry point based on its properties."""
if self.is_concrete_class:
if any(substring in string for substring in (self.extends_list + self.implements_list) for string in constants.ENTRY_POINT_SERVLET_CLASSES):
self.is_entry_point = True
return self
# Handle spring classes
# clean annotations - take out @ and any paranehesis along with info in them.
annotations_cleaned = [match for annotation in self.annotations for match in re.findall(r"@(.*?)(?:\(|$)", annotation)]
if any(substring in string for substring in annotations_cleaned for string in constants.ENTRY_POINT_CLASS_SPRING_ANNOTATIONS):
self.is_entry_point = True
return self
# context_concrete.reset()
return self


class JCompilationUnit(BaseModel):
"""Represents a compilation unit in Java.
Expand Down
Loading