|
1 | 1 | """
|
2 | 2 | Experimental script for bulk generation of MaD models based on a list of projects.
|
3 |
| -
|
4 |
| -Currently the script only targets Rust. |
5 | 3 | """
|
6 | 4 |
|
7 | 5 | import os.path
|
@@ -221,6 +219,114 @@ def build_databases_from_projects(language: str, extractor_options, projects: Li
|
221 | 219 | for project, project_dir in project_dirs
|
222 | 220 | ]
|
223 | 221 | return database_results
|
| 222 | + |
def github(url: str, pat: str, extra_headers: dict[str, str] | None = None) -> dict:
    """
    Download a JSON file from GitHub using a personal access token (PAT).

    Args:
        url: The URL to download the JSON file from.
        pat: Personal Access Token for GitHub API authentication.
        extra_headers: Additional headers to include in the request.

    Returns:
        The JSON response as a dictionary.

    Exits the process if the request does not return HTTP 200.
    """
    # `None` sentinel instead of a mutable `{}` default: a shared dict default
    # is a single object reused across calls and a classic Python footgun.
    headers = {"Authorization": f"token {pat}"} | (extra_headers or {})
    response = requests.get(url, headers=headers)
    if response.status_code != 200:
        print(f"Failed to download JSON: {response.status_code} {response.text}")
        sys.exit(1)
    else:
        return response.json()
| 240 | + |
def download_artifact(url: str, artifact_name: str, pat: str) -> str:
    """
    Download a GitHub Actions artifact from a given URL.

    Args:
        url: The URL to download the artifact from.
        artifact_name: The name of the artifact (used for naming the downloaded file).
        pat: Personal Access Token for GitHub API authentication.

    Returns:
        The path to the downloaded artifact file.

    Exits the process if the download fails.
    """
    headers = {"Authorization": f"token {pat}", "Accept": "application/vnd.github+json"}
    # Stream so large artifacts are written in chunks rather than held in
    # memory; the `with` block releases the connection even on early exit
    # (the original leaked the streamed response).
    with requests.get(url, stream=True, headers=headers) as response:
        if response.status_code != 200:
            print(f"Failed to download file. Status code: {response.status_code}")
            sys.exit(1)
        zip_name = artifact_name + ".zip"
        target_zip = os.path.join(build_dir, zip_name)
        with open(target_zip, "wb") as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)
    print(f"Download complete: {target_zip}")
    return target_zip
| 264 | + |
def remove_extension(filename: str) -> str:
    """
    Strip every extension from *filename* (e.g. "db.tar.gz" -> "db").
    """
    while "." in filename:
        root, ext = os.path.splitext(filename)
        if not ext:
            # splitext made no progress: names like ".gitignore" keep their
            # leading dot and would otherwise loop forever.
            break
        filename = root
    return filename
| 269 | + |
def pretty_name_from_artifact_name(artifact_name: str) -> str:
    """
    Extract the project's pretty name from a DCA artifact name.

    Artifact names are "___"-separated; the second field is the pretty name.
    """
    fields = artifact_name.split("___")
    return fields[1]
| 272 | + |
def download_dca_databases(experiment_name: str, pat: str, projects) -> List[tuple[str, str | None]]:
    """
    Download databases from a DCA experiment.

    Args:
        experiment_name: The name of the DCA experiment to download databases from.
        pat: Personal Access Token for GitHub API authentication.
        projects: List of projects to download databases for.

    Returns:
        List of (project_name, database_dir) pairs, where database_dir is None if the download failed.
    """
    database_results = []
    # Build the lookup structures once, outside the loop: the original rebuilt
    # a list of project names for every target (O(n) membership test each).
    project_names = {project["name"] for project in projects}
    project_order = {project["name"]: index for index, project in enumerate(projects)}
    print("\n=== Finding projects ===")
    response = github(
        f"https://raw.githubusercontent.com/github/codeql-dca-main/data/{experiment_name}/reports/downloads.json",
        pat,
    )
    targets = response["targets"]
    # The target key itself is unused; only the download metadata matters.
    for data in targets.values():
        analyzed_database = data["downloads"]["analyzed_database"]
        artifact_name = analyzed_database["artifact_name"]
        pretty_name = pretty_name_from_artifact_name(artifact_name)

        if pretty_name not in project_names:
            print(f"Skipping {pretty_name} as it is not in the list of projects")
            continue

        repository = analyzed_database["repository"]
        run_id = analyzed_database["run_id"]
        print(f"=== Finding artifact: {artifact_name} ===")
        artifacts_response = github(
            f"https://api.github.com/repos/{repository}/actions/runs/{run_id}/artifacts",
            pat,
            {"Accept": "application/vnd.github+json"},
        )
        artifact_map = {artifact["name"]: artifact for artifact in artifacts_response["artifacts"]}
        print(f"=== Downloading artifact: {artifact_name} ===")
        archive_download_url = artifact_map[artifact_name]["archive_download_url"]
        artifact_zip_location = download_artifact(archive_download_url, artifact_name, pat)
        print(f"=== Extracting artifact: {artifact_name} ===")
        # The database is in a zip file, which contains a tar.gz file with the DB.
        # First extract the zip to build_dir/artifact_name ...
        artifact_unzipped_location = os.path.join(build_dir, artifact_name)
        with zipfile.ZipFile(artifact_zip_location, "r") as zip_ref:
            zip_ref.extractall(artifact_unzipped_location)
        # ... then untar every archive inside it into the same directory.
        for entry in os.listdir(artifact_unzipped_location):
            artifact_tar_location = os.path.join(artifact_unzipped_location, entry)
            with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
                # NOTE(review): extractall on a downloaded tarball is exposed to
                # path-traversal member names; consider `filter="data"` (3.12+).
                tar_ref.extractall(artifact_unzipped_location)
            database_results.append(
                (pretty_name, os.path.join(artifact_unzipped_location, remove_extension(entry)))
            )
    print(f"\n=== Extracted {len(database_results)} databases ===")

    # Sort the database results based on the order in the projects file.
    # An O(1) dict-index key replaces the original cmp_to_key comparator,
    # which rescanned `projects` linearly on every comparison.
    return sorted(database_results, key=lambda result: project_order[result[0]])
224 | 330 |
|
def get_destination_for_project(config, name: str) -> str:
    """Return the output directory for project *name* under the configured destination root."""
    destination_root = config["destination"]
    return os.path.join(destination_root, name)
|
@@ -266,6 +372,16 @@ def main(config, args) -> None:
|
266 | 372 | case "repo":
|
267 | 373 | extractor_options = config.get("extractor_options", [])
|
268 | 374 | database_results = build_databases_from_projects(language, extractor_options, projects)
|
| 375 | + case "dca": |
| 376 | + experiment_name = args.dca |
| 377 | + if experiment_name is None: |
| 378 | + print("ERROR: --dca argument is required for DCA strategy") |
| 379 | + sys.exit(1) |
| 380 | + pat = args.pat |
| 381 | + if pat is None: |
| 382 | + print("ERROR: --pat argument is required for DCA strategy") |
| 383 | + sys.exit(1) |
| 384 | + database_results = download_dca_databases(experiment_name, pat, projects) |
269 | 385 |
|
270 | 386 | # Phase 3: Generate models for all projects
|
271 | 387 | print("\n=== Phase 3: Generating models ===")
|
@@ -293,6 +409,8 @@ def main(config, args) -> None:
|
293 | 409 | if __name__ == "__main__":
|
294 | 410 | parser = argparse.ArgumentParser()
|
295 | 411 | parser.add_argument("--config", type=str, help="Path to the configuration file.", required=True)
|
| 412 | + parser.add_argument("--dca", type=str, help="Name of a DCA run that built all the projects", required=False) |
| 413 | + parser.add_argument("--pat", type=str, help="PAT token to grab DCA databases (the same as the one you use for DCA)", required=False) |
296 | 414 | parser.add_argument("--lang", type=str, help="The language to generate models for", required=True)
|
297 | 415 | parser.add_argument("--with-sources", action="store_true", help="Generate sources", required=False)
|
298 | 416 | parser.add_argument("--with-sinks", action="store_true", help="Generate sinks", required=False)
|
|
0 commit comments