Skip to content

Commit a0280ce

Browse files
committed
Add architecture ectraction logic
1 parent f5e1ad5 commit a0280ce

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

scripts/gen_quick_start_module.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,8 @@ class OperatingSystem(Enum):
3030
ENABLE = "enable"
3131
DISABLE = "disable"
3232

33-
# Mapping json to release matrix is here for now
34-
# TBD drive the mapping via:
35-
# 1. Scanning release matrix and picking 2 latest cuda versions and 1 latest rocm
36-
# 2. Possibility to override the scanning algorithm with arguments passed from workflow
37-
acc_arch_ver_map = {
33+
# Mapping json to release matrix default values
34+
acc_arch_ver_default = {
3835
"nightly": {
3936
"accnone": ("cpu", ""),
4037
"cuda.x": ("cuda", "11.6"),
@@ -49,6 +46,11 @@ class OperatingSystem(Enum):
4946
}
5047
}
5148

49+
# Initialize arch version to default values
50+
# these default values will be overwritten by
51+
# extracted values from the release marix
52+
acc_arch_ver_map = acc_arch_ver_default
53+
5254
LIBTORCH_DWNL_INSTR = {
5355
PRE_CXX11_ABI: "Download here (Pre-cxx11 ABI):",
5456
CXX11_ABI: "Download here (cxx11 ABI):",
@@ -163,6 +165,25 @@ def gen_install_matrix(versions) -> Dict[str, str]:
163165
result[key] = "<br />".join(lines)
164166
return result
165167

168+
# This method is used for extracting two latest verisons of cuda and
169+
# last verion of rocm. It will modify the acc_arch_ver_map object used
170+
# to update getting started page.
171+
def extract_arch_ver_map(release_matrix):
172+
for chan in ("nightly", "release"):
173+
cuda_ver_list = {
174+
x["desired_cuda"]: x["gpu_arch_version"] for x in release_matrix[chan]["linux"]
175+
if x["gpu_arch_type"] == "cuda"
176+
}
177+
rocm_ver_list = {
178+
x["desired_cuda"]: x["gpu_arch_version"] for x in release_matrix[chan]["linux"]
179+
if x["gpu_arch_type"] == "rocm"
180+
}
181+
cuda_list = sorted(cuda_ver_list.values())[-2:]
182+
for cuda_ver, label in zip(cuda_list, ["cuda.x", "cuda.y"]):
183+
acc_arch_ver_map[chan][label] = ("cuda", cuda_ver)
184+
acc_arch_ver_map[chan]["rocm5.x"] = ("rocm", max(rocm_ver_list.values()))
185+
186+
166187
def main():
167188
parser = argparse.ArgumentParser()
168189
parser.add_argument('--autogenerate', dest='autogenerate', action='store_true')
@@ -178,6 +199,7 @@ def main():
178199
for osys in OperatingSystem:
179200
release_matrix[val][osys.value] = read_matrix_for_os(osys, val)
180201

202+
extract_arch_ver_map(release_matrix)
181203
for val in ("nightly", "release"):
182204
update_versions(versions, release_matrix[val], val)
183205

0 commit comments

Comments
 (0)