Skip to content

Commit a1f45ff

Browse files
committed
Add a script to lift CERT risk assessment tags from help files
1 parent 12df863 commit a1f45ff

File tree

1 file changed

+160
-0
lines changed

1 file changed

+160
-0
lines changed

scripts/add_risk_assessment_tags.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Add risk assessment tags to rule package JSON files.
4+
5+
This script:
6+
1. Iterates through each JSON file in rule_packages directory
7+
2. Looks for CERT-C or CERT-C++ sections
8+
3. For each query of each rule, finds the corresponding markdown file
9+
4. Extracts risk assessment data from the markdown file
10+
5. Adds risk assessment data as tags to each query in the JSON file
11+
"""
12+
13+
import os
14+
import json
15+
import re
16+
import glob
17+
from bs4 import BeautifulSoup
18+
import logging
19+
20+
# Configure the root logger once at import time so that every message the
# script emits carries a timestamp and its severity level.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Module-level logger used by all functions in this script.
logger = logging.getLogger(__name__)
22+
23+
def find_rule_packages():
    """Return the paths of all ``*.json`` files under ``rule_packages``.

    The repository root is taken to be the parent of the directory that
    contains this script; ``rule_packages`` is searched recursively.

    Returns:
        A list of path strings, in whatever order ``glob`` yields them.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    packages_root = os.path.join(os.path.dirname(script_dir), "rule_packages")
    json_pattern = os.path.join(packages_root, "**", "*.json")
    return glob.glob(json_pattern, recursive=True)
28+
29+
def extract_risk_assessment_from_md(md_file_path):
    """Extract the risk assessment table from a markdown help file.

    The file is searched for a ``## Risk Assessment`` heading followed by an
    HTML ``<table>``. The ``<th>`` cells of the table's first row are used as
    keys and the ``<td>`` cells of its second row as the matching values.

    Args:
        md_file_path: Path of the markdown file to read (decoded as UTF-8).

    Returns:
        A dict mapping each header text to its value text, both stripped of
        surrounding whitespace. The dict is empty when the section or table is
        missing, the table has fewer than two rows, or the header and value
        counts differ. It is also empty when reading or parsing raises; such
        errors are logged rather than propagated.
    """
    risk_data = {}

    try:
        with open(md_file_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # The section runs from its heading up to the next heading that
        # starts a line, or to the end of the file when it is the last
        # section. Anchoring the terminator to a line start keeps a "##"
        # that merely appears inside running text from truncating the
        # section before the table.
        risk_section_match = re.search(
            r'## Risk Assessment(.*?)(?=^##|\Z)',
            content,
            re.DOTALL | re.MULTILINE,
        )
        if not risk_section_match:
            logger.warning(f"No Risk Assessment section found in {md_file_path}")
            return risk_data

        risk_section = risk_section_match.group(1)

        # Accept an opening tag with attributes (e.g. <table class="...">)
        # as well as the bare <table> form.
        table_match = re.search(r'<table\b[^>]*>.*?</table>', risk_section, re.DOTALL)
        if not table_match:
            logger.warning(f"No risk assessment table found in {md_file_path}")
            return risk_data

        soup = BeautifulSoup(table_match.group(0), 'html.parser')

        rows = soup.find_all('tr')
        if len(rows) < 2:  # Need at least a header row and a data row
            logger.warning(f"Incomplete risk assessment table in {md_file_path}")
            return risk_data

        headers = [th.get_text().strip() for th in rows[0].find_all('th')]
        values = [td.get_text().strip() for td in rows[1].find_all('td')]

        if len(headers) != len(values):
            logger.warning(f"Header and value count mismatch in {md_file_path}")
            return risk_data

        # Pair each header with the value in the same column.
        risk_data.update(zip(headers, values))

    except Exception as e:
        # Deliberately broad: one unreadable or malformed help file must not
        # abort the whole run, so the failure is logged and an empty (or
        # partial) result is returned instead.
        logger.error(f"Error extracting risk assessment from {md_file_path}: {e}")

    return risk_data
78+
79+
def find_md_file(rule_id, short_name, language):
    """Locate the markdown file for a query of the given rule.

    Looks in ``<repo root>/<language>/cert/src/rules/<rule_id>/`` for
    ``<short_name>.md`` first and ``<rule_id>.md`` second, where the repo root
    is the parent of the directory containing this script.

    Returns:
        The first candidate path that exists, or ``None`` (after logging a
        warning) when neither exists.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    repo_root = os.path.dirname(script_dir)

    # Preferred name first; some files are named after the rule ID instead.
    for stem in (short_name, rule_id):
        candidate = os.path.join(
            repo_root, language, "cert", "src", "rules", rule_id, f"{stem}.md"
        )
        if os.path.exists(candidate):
            return candidate

    logger.warning(f"Could not find markdown file for {language} rule {rule_id} ({short_name})")
    return None
94+
95+
def process_rule_package(rule_package_file):
    """Add risk assessment tags to the queries of one rule package JSON file.

    For each rule under the top-level ``"CERT-C"`` and ``"CERT-C++"`` keys,
    every query that has a ``"short_name"`` is matched to its markdown file.
    Each property of that file's risk assessment table is appended to the
    query's ``"tags"`` list as ``external/cert/<property>/<value>`` unless the
    tag is already present. The file is rewritten (``indent=2``) only when at
    least one tag was added.

    Args:
        rule_package_file: Path of the JSON file to read and possibly rewrite.

    Returns:
        None. Any exception raised while processing is logged, not propagated.
    """
    try:
        with open(rule_package_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        modified = False

        # Each CERT section key maps to the language directory that holds
        # its markdown files.
        for cert_key, language in (("CERT-C", "c"), ("CERT-C++", "cpp")):
            if cert_key not in data:
                continue

            for rule_id, rule_data in data[cert_key].items():
                if "queries" not in rule_data:
                    continue

                for query in rule_data["queries"]:
                    if "short_name" not in query:
                        continue

                    md_file = find_md_file(rule_id, query["short_name"], language)
                    if not md_file:
                        continue

                    risk_data = extract_risk_assessment_from_md(md_file)
                    if not risk_data:
                        continue

                    existing_tags = query.setdefault("tags", [])
                    for tag in _risk_assessment_tags(risk_data):
                        if tag in existing_tags:
                            continue
                        existing_tags.append(tag)
                        modified = True
                        logger.info(f"Added tag {tag} to {rule_id} ({query['short_name']})")

        # Save the modified data back to the file only if something changed,
        # so untouched packages are not rewritten.
        if modified:
            with open(rule_package_file, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            logger.info(f"Updated {rule_package_file}")
        else:
            logger.info(f"No changes made to {rule_package_file}")

    except Exception as e:
        # Deliberately broad: a single bad package is logged and skipped so
        # the remaining packages can still be processed by the caller.
        logger.error(f"Error processing {rule_package_file}: {e}")


def _risk_assessment_tags(risk_data):
    """Return the ``external/cert/<property>/<value>`` tags for a risk table."""
    tags = []
    for key, value in risk_data.items():
        key_sanitized = key.lower().replace(" ", "-")
        # Sanitize the value the same way as the key so that a multi-word
        # value cannot produce a tag containing spaces.
        value_sanitized = value.lower().replace(" ", "-")

        if key_sanitized == "rule":
            # Skip rule, as that is already in the rule ID.
            continue
        if not key_sanitized or not value_sanitized:
            # An empty cell would yield a dangling tag such as
            # "external/cert/severity/".
            continue

        tags.append(f"external/cert/{key_sanitized}/{value_sanitized}")
    return tags
145+
146+
def main():
    """Run the tagging pass over every discovered rule package file.

    Logs the start of the run, the number of files found, each file as it is
    processed, and the end of the run.
    """
    logger.info("Starting risk assessment tag addition process")

    package_paths = find_rule_packages()
    logger.info(f"Found {len(package_paths)} rule package files")

    for package_path in package_paths:
        logger.info(f"Processing {package_path}")
        process_rule_package(package_path)

    logger.info("Completed risk assessment tag addition process")
158+
159+
# Run the tagging pass only when this file is executed as a script, so that
# importing the module has no side effects beyond logging configuration.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)