Skip to content

Commit 51b34e4

Browse files
committed
Add gen_published_versions.py
Convers unstructured quick-start-module.js in more structured json
1 parent 0ba2a20 commit 51b34e4

File tree

1 file changed

+201
-0
lines changed

1 file changed

+201
-0
lines changed

scripts/gen_published_versions.py

+201
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
#!/usr/bin/env python3
2+
# Generates published versions based on unstructured quick-start-module
3+
4+
import json
5+
from urllib.request import urlopen
6+
from typing import Any, Dict, Optional, Union
7+
8+
9+
class ConfigStr:
10+
version: str
11+
conf_type: str
12+
os: str
13+
accel: str
14+
extra: str
15+
16+
@staticmethod
17+
def parse(val: str) -> "ConfigStr":
18+
vals = val.split(",")
19+
assert len(vals) == 5
20+
rc = ConfigStr()
21+
for k, v in zip(["version", "conf_type", "os", "accel", "extra"], vals):
22+
rc.__setattr__(k, v)
23+
return rc
24+
25+
def __repr__(self) -> str:
26+
return self.__dict__.__repr__()
27+
28+
29+
class LibTorchInstruction:
30+
note: Optional[str]
31+
versions: Union[Dict[str, str], str, None]
32+
33+
def __init__(self, note: Optional[str] = None, versions: Union[Dict[str, str], str, None] = None) -> None:
34+
self.note = note
35+
self.versions = versions
36+
37+
@staticmethod
38+
def parse(val: str) -> "LibTorchInstruction":
39+
import re
40+
href_pattern = re.compile("<a href=\'([^']*)\'>([^<]*)</a>")
41+
line_separator = "<br />"
42+
lines = val.split(line_separator)
43+
versions = {}
44+
idx_to_delete = set()
45+
for idx, line in enumerate(lines):
46+
url = href_pattern.findall(line)
47+
if len(url) == 0:
48+
continue
49+
# There should be only one URL per line and value inside and outside of URL shoudl match
50+
assert len(url) == 1
51+
assert url[0][0] == url[0][1].rstrip(), url
52+
versions[lines[idx - 1].strip()] = url[0][0]
53+
idx_to_delete.add(idx - 1)
54+
idx_to_delete.add(idx)
55+
lines = [lines[idx] for idx in range(len(lines)) if idx not in idx_to_delete]
56+
if len(lines) == 1 and len(lines[0]) == 0:
57+
lines = []
58+
return LibTorchInstruction(note=line_separator.join(lines) if len(lines) > 0 else None,
59+
versions=versions if len(versions) > 0 else None)
60+
61+
def __repr__(self) -> str:
62+
return self.__dict__.__repr__()
63+
64+
65+
class PyTorchInstruction:
66+
note: Optional[str]
67+
command: Optional[str]
68+
69+
def __init__(self, note: Optional[str] = None, command: Optional[str] = None) -> None:
70+
self.note = note
71+
self.command = command
72+
73+
@staticmethod
74+
def parse(val: str) -> "PyTorchInstruction":
75+
def is_cmd(cmd: str) -> bool:
76+
return cmd.startswith("pip3 install") or cmd.startswith("conda install")
77+
line_separator = "<br />"
78+
lines = val.split(line_separator)
79+
if is_cmd(lines[-1]):
80+
note = line_separator.join(lines[:-1]) if len(lines) > 1 else None
81+
command = lines[-1]
82+
elif is_cmd(lines[0]):
83+
note = line_separator.join(lines[1:]) if len(lines) > 1 else None
84+
command = lines[0]
85+
else:
86+
note = val
87+
command = None
88+
return PyTorchInstruction(note=note, command=command)
89+
90+
def __repr__(self) -> str:
91+
return self.__dict__.__repr__()
92+
93+
94+
class PublishedAccVersion:
95+
libtorch: Dict[str, LibTorchInstruction]
96+
conda: Dict[str, PyTorchInstruction]
97+
pip: Dict[str, PyTorchInstruction]
98+
99+
def __init__(self):
100+
self.pip = dict()
101+
self.conda = dict()
102+
self.libtorch = dict()
103+
104+
def __repr__(self) -> str:
105+
return self.__dict__.__repr__()
106+
107+
def add_instruction(self, conf: ConfigStr, val: str) -> None:
108+
if conf.conf_type == "libtorch":
109+
self.libtorch[conf.accel] = LibTorchInstruction.parse(val)
110+
elif conf.conf_type == "conda":
111+
self.conda[conf.accel] = PyTorchInstruction.parse(val)
112+
elif conf.conf_type == "pip":
113+
self.pip[conf.accel] = PyTorchInstruction.parse(val)
114+
else:
115+
raise RuntimeError(f"Unknown config type {conf.conf_type}")
116+
117+
118+
class PublishedOSVersion:
119+
linux: PublishedAccVersion
120+
macos: PublishedAccVersion
121+
windows: PublishedAccVersion
122+
123+
def __init__(self):
124+
self.linux = PublishedAccVersion()
125+
self.macos = PublishedAccVersion()
126+
self.windows = PublishedAccVersion()
127+
128+
def add_instruction(self, conf: ConfigStr, val: str) -> None:
129+
if conf.os == "linux":
130+
self.linux.add_instruction(conf, val)
131+
elif conf.os == "macos":
132+
self.macos.add_instruction(conf, val)
133+
elif conf.os == "windows":
134+
self.windows.add_instruction(conf, val)
135+
else:
136+
raise RuntimeError(f"Unknown os type {conf.os}")
137+
138+
def __repr__(self) -> str:
139+
return self.__dict__.__repr__()
140+
141+
142+
class PublishedVersions:
143+
latest_stable: str
144+
latest_lts: str
145+
versions: Dict[str, PublishedOSVersion] = dict()
146+
147+
def __init__(self, latest_stable: str, latest_lts: str) -> None:
148+
self.latest_stable = latest_stable
149+
self.latest_lts = latest_lts
150+
self.versions = dict()
151+
152+
def parse_objects(self, objects: Dict[str, str]) -> None:
153+
for key, val in objects.items():
154+
conf = ConfigStr.parse(key)
155+
if conf.version not in self.versions:
156+
self.versions[conf.version] = PublishedOSVersion()
157+
self.versions[conf.version].add_instruction(conf, val)
158+
if 'stable' in self.versions:
159+
self.versions[self.latest_stable] = self.versions.pop('stable')
160+
if 'lts' in self.versions:
161+
self.versions[self.latest_lts] = self.versions.pop('lts')
162+
163+
164+
def get_objects(commit_hash: str = "0ba2a203045bc94d165d52e56c87ceaa463f4284") -> Dict[str, str]:
165+
"""
166+
Extract install commands as they are currently hardcoded
167+
in pytorch.github.io/assets/quick-start-module.js
168+
"""
169+
raw_base = "raw.githubusercontent.com"
170+
obj_start = "var object = {"
171+
obj_end = "};"
172+
with urlopen(f"https://{raw_base}/pytorch/pytorch.github.io/{commit_hash}/assets/quick-start-module.js") as url:
173+
raw_data = url.read().decode("latin1")
174+
start_idx = raw_data.find(obj_start)
175+
end_idx = raw_data.find(obj_end, start_idx)
176+
# Adjust start end end indexes
177+
start_idx = raw_data.find("{", start_idx, end_idx)
178+
end_idx = raw_data.rfind('"', start_idx, end_idx)
179+
if any(x < 0 for x in [start_idx, end_idx]):
180+
raise RuntimeError("Unexpected raw_data")
181+
return json.loads(raw_data[start_idx:end_idx] + '"}')
182+
183+
184+
def dump_to_file(fname: str, o: Any) -> None:
185+
class DictEncoder(json.JSONEncoder):
186+
def default(self, o):
187+
return o.__dict__
188+
189+
with open(fname, "w") as fp:
190+
json.dump(o, fp, indent=2, cls=DictEncoder)
191+
192+
193+
def main() -> None:
194+
install_objects = get_objects()
195+
rc = PublishedVersions(latest_stable="1.9.0", latest_lts="lts-1.8.2")
196+
rc.parse_objects(install_objects)
197+
dump_to_file("published_versions.json", rc)
198+
199+
200+
if __name__ == "__main__":
201+
main()

0 commit comments

Comments
 (0)