@@ -30,11 +30,8 @@ class OperatingSystem(Enum):
30
30
ENABLE = "enable"
31
31
DISABLE = "disable"
32
32
33
- # Mapping json to release matrix is here for now
34
- # TBD drive the mapping via:
35
- # 1. Scanning release matrix and picking 2 latest cuda versions and 1 latest rocm
36
- # 2. Possibility to override the scanning algorithm with arguments passed from workflow
37
- acc_arch_ver_map = {
33
+ # Mapping json to release matrix default values
34
+ acc_arch_ver_default = {
38
35
"nightly" : {
39
36
"accnone" : ("cpu" , "" ),
40
37
"cuda.x" : ("cuda" , "11.6" ),
@@ -49,6 +46,11 @@ class OperatingSystem(Enum):
49
46
}
50
47
}
51
48
49
+ # Initialize arch version to default values
50
+ # these default values will be overwritten by
51
+ # extracted values from the release marix
52
+ acc_arch_ver_map = acc_arch_ver_default
53
+
52
54
LIBTORCH_DWNL_INSTR = {
53
55
PRE_CXX11_ABI : "Download here (Pre-cxx11 ABI):" ,
54
56
CXX11_ABI : "Download here (cxx11 ABI):" ,
@@ -163,6 +165,25 @@ def gen_install_matrix(versions) -> Dict[str, str]:
163
165
result [key ] = "<br />" .join (lines )
164
166
return result
165
167
168
+ # This method is used for extracting two latest verisons of cuda and
169
+ # last verion of rocm. It will modify the acc_arch_ver_map object used
170
+ # to update getting started page.
171
+ def extract_arch_ver_map (release_matrix ):
172
+ for chan in ("nightly" , "release" ):
173
+ cuda_ver_list = {
174
+ x ["desired_cuda" ]: x ["gpu_arch_version" ] for x in release_matrix [chan ]["linux" ]
175
+ if x ["gpu_arch_type" ] == "cuda"
176
+ }
177
+ rocm_ver_list = {
178
+ x ["desired_cuda" ]: x ["gpu_arch_version" ] for x in release_matrix [chan ]["linux" ]
179
+ if x ["gpu_arch_type" ] == "rocm"
180
+ }
181
+ cuda_list = sorted (cuda_ver_list .values ())[- 2 :]
182
+ for cuda_ver , label in zip (cuda_list , ["cuda.x" , "cuda.y" ]):
183
+ acc_arch_ver_map [chan ][label ] = ("cuda" , cuda_ver )
184
+ acc_arch_ver_map [chan ]["rocm5.x" ] = ("rocm" , max (rocm_ver_list .values ()))
185
+
186
+
166
187
def main ():
167
188
parser = argparse .ArgumentParser ()
168
189
parser .add_argument ('--autogenerate' , dest = 'autogenerate' , action = 'store_true' )
@@ -178,6 +199,7 @@ def main():
178
199
for osys in OperatingSystem :
179
200
release_matrix [val ][osys .value ] = read_matrix_for_os (osys , val )
180
201
202
+ extract_arch_ver_map (release_matrix )
181
203
for val in ("nightly" , "release" ):
182
204
update_versions (versions , release_matrix [val ], val )
183
205
0 commit comments