Skip to content

Commit 5051790

Browse files
committed
Bulk generator: Add DCA support.
1 parent e721fc0 commit 5051790

File tree

2 files changed

+128
-2
lines changed

2 files changed

+128
-2
lines changed

cpp/misc/bulk_generation_targets.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
{
2+
"strategy": "dca",
3+
"targets": [
4+
{ "name": "openssl" },
5+
{ "name": "sqlite" }
6+
],
7+
"destination": "cpp/ql/lib/ext/generated"
8+
}

misc/scripts/models-as-data/rust_bulk_generate_mad.py

Lines changed: 120 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""
22
Experimental script for bulk generation of MaD models based on a list of projects.
3-
4-
Currently the script only targets Rust.
53
"""
64

75
import os.path
@@ -221,6 +219,114 @@ def build_databases_from_projects(language: str, extractor_options, projects: Li
221219
for project, project_dir in project_dirs
222220
]
223221
return database_results
222+
223+
def github(url: str, pat: str, extra_headers: dict[str, str] | None = None) -> dict:
    """
    Download a JSON file from GitHub using a personal access token (PAT).

    Args:
        url: The URL to download the JSON file from.
        pat: Personal Access Token for GitHub API authentication.
        extra_headers: Additional headers to include in the request.

    Returns:
        The JSON response as a dictionary.

    Exits the process if the request does not return HTTP 200.
    """
    # `None` sentinel instead of a mutable `{}` default (flake8-bugbear B006).
    headers = {"Authorization": f"token {pat}"} | (extra_headers or {})
    response = requests.get(url, headers=headers)
    if response.status_code != 200:
        print(f"Failed to download JSON: {response.status_code} {response.text}")
        sys.exit(1)
    return response.json()
240+
241+
def download_artifact(url: str, artifact_name: str, pat: str) -> str:
    """
    Download a GitHub Actions artifact from a given URL.

    Args:
        url: The URL to download the artifact from.
        artifact_name: The name of the artifact (used for naming the downloaded file).
        pat: Personal Access Token for GitHub API authentication.

    Returns:
        The path to the downloaded artifact file.

    Exits the process if the request does not return HTTP 200.
    """
    headers = {"Authorization": f"token {pat}", "Accept": "application/vnd.github+json"}
    response = requests.get(url, stream=True, headers=headers)
    # Fail fast so the happy path is not nested (was an if/else with the
    # error branch at the bottom).
    if response.status_code != 200:
        print(f"Failed to download file. Status code: {response.status_code}")
        sys.exit(1)
    # snake_case per PEP 8 (was `zipName`).
    target_zip = os.path.join(build_dir, artifact_name + ".zip")
    with open(target_zip, "wb") as file:
        # Stream in chunks so large databases are not held in memory at once.
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)
    print(f"Download complete: {target_zip}")
    return target_zip
264+
265+
def remove_extension(filename: str) -> str:
    """
    Strip every extension from *filename* (e.g. "db.tar.gz" -> "db").

    Args:
        filename: A bare file name (not a path).

    Returns:
        The name with all trailing extensions removed.
    """
    while "." in filename:
        root, ext = os.path.splitext(filename)
        if not ext:
            # splitext treats a leading dot as part of the name (e.g.
            # splitext(".bashrc") == (".bashrc", "")), so without this guard
            # such names would never shrink and the loop would spin forever.
            break
        filename = root
    return filename
269+
270+
def pretty_name_from_artifact_name(artifact_name: str) -> str:
    """Return the project name embedded in a DCA artifact name.

    Artifact names are "___"-separated; the second segment is the
    human-readable project name. Raises IndexError if the separator is
    missing, exactly like the original direct index.
    """
    segments = artifact_name.split("___")
    return segments[1]
272+
273+
def download_dca_databases(experiment_name: str, pat: str, projects) -> List[tuple[str, str | None]]:
    """
    Download databases from a DCA experiment.

    Args:
        experiment_name: The name of the DCA experiment to download databases from.
        pat: Personal Access Token for GitHub API authentication.
        projects: List of projects to download databases for.

    Returns:
        List of (project_name, database_dir) pairs, where database_dir is None if the download failed.
    """
    database_results = []
    print("\n=== Finding projects ===")
    response = github(
        f"https://raw.githubusercontent.com/github/codeql-dca-main/data/{experiment_name}/reports/downloads.json",
        pat,
    )
    # Hoist the membership test out of the loop and use a set for O(1) lookups
    # (was a list comprehension rebuilt on every iteration).
    project_names = {project["name"] for project in projects}
    # Only the values are used; the target key was unused in the original loop.
    for data in response["targets"].values():
        analyzed_database = data["downloads"]["analyzed_database"]
        artifact_name = analyzed_database["artifact_name"]
        pretty_name = pretty_name_from_artifact_name(artifact_name)

        if pretty_name not in project_names:
            print(f"Skipping {pretty_name} as it is not in the list of projects")
            continue

        repository = analyzed_database["repository"]
        run_id = analyzed_database["run_id"]
        print(f"=== Finding artifact: {artifact_name} ===")
        # Separate name so the downloads.json response above is not clobbered.
        artifacts_response = github(
            f"https://api.github.com/repos/{repository}/actions/runs/{run_id}/artifacts",
            pat,
            {"Accept": "application/vnd.github+json"},
        )
        artifact_map = {artifact["name"]: artifact for artifact in artifacts_response["artifacts"]}
        print(f"=== Downloading artifact: {artifact_name} ===")
        archive_download_url = artifact_map[artifact_name]["archive_download_url"]
        artifact_zip_location = download_artifact(archive_download_url, artifact_name, pat)
        print(f"=== Extracting artifact: {artifact_name} ===")
        # The database is in a zip file, which contains a tar.gz file with the DB.
        # First we open the zip file
        with zipfile.ZipFile(artifact_zip_location, "r") as zip_ref:
            artifact_unzipped_location = os.path.join(build_dir, artifact_name)
            # And then we extract it to build_dir/artifact_name
            zip_ref.extractall(artifact_unzipped_location)
            # And then we iterate over the contents of the extracted directory
            # and extract the tar.gz files inside it
            for entry in os.listdir(artifact_unzipped_location):
                artifact_tar_location = os.path.join(artifact_unzipped_location, entry)
                with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
                    # And we just untar it to the same directory as the zip file
                    tar_ref.extractall(artifact_unzipped_location)
                database_results.append(
                    (pretty_name, os.path.join(artifact_unzipped_location, remove_extension(entry)))
                )
    print(f"\n=== Extracted {len(database_results)} databases ===")

    # Sort the database results based on the order in the projects file.
    # A precomputed index map makes this O(n log n) overall, instead of the
    # original cmp_to_key comparator that scanned `projects` on every compare.
    order = {project["name"]: index for index, project in enumerate(projects)}
    return sorted(database_results, key=lambda result: order[result[0]])
224330

225331
def get_destination_for_project(config, name: str) -> str:
226332
return os.path.join(config["destination"], name)
@@ -266,6 +372,16 @@ def main(config, args) -> None:
266372
case "repo":
267373
extractor_options = config.get("extractor_options", [])
268374
database_results = build_databases_from_projects(language, extractor_options, projects)
375+
case "dca":
376+
experiment_name = args.dca
377+
if experiment_name is None:
378+
print("ERROR: --dca argument is required for DCA strategy")
379+
sys.exit(1)
380+
pat = args.pat
381+
if pat is None:
382+
print("ERROR: --pat argument is required for DCA strategy")
383+
sys.exit(1)
384+
database_results = download_dca_databases(experiment_name, pat, projects)
269385

270386
# Phase 3: Generate models for all projects
271387
print("\n=== Phase 3: Generating models ===")
@@ -293,6 +409,8 @@ def main(config, args) -> None:
293409
if __name__ == "__main__":
294410
parser = argparse.ArgumentParser()
295411
parser.add_argument("--config", type=str, help="Path to the configuration file.", required=True)
412+
parser.add_argument("--dca", type=str, help="Name of a DCA run that built all the projects", required=False)
413+
parser.add_argument("--pat", type=str, help="PAT token to grab DCA databases (the same as the one you use for DCA)", required=False)
296414
parser.add_argument("--lang", type=str, help="The language to generate models for", required=True)
297415
parser.add_argument("--with-sources", action="store_true", help="Generate sources", required=False)
298416
parser.add_argument("--with-sinks", action="store_true", help="Generate sinks", required=False)

0 commit comments

Comments
 (0)