|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Generates published versions based on unstructured quick-start-module |
| 3 | + |
| 4 | +import json |
| 5 | +from urllib.request import urlopen |
| 6 | +from typing import Any, Dict, Optional, Union |
| 7 | + |
| 8 | + |
class ConfigStr:
    """One parsed configuration key.

    A key is a comma-separated string with exactly five fields which are
    stored, in order, as the attributes ``version``, ``conf_type``, ``os``,
    ``accel`` and ``extra``.
    """
    version: str
    conf_type: str
    os: str
    accel: str
    extra: str

    # Attribute names in the order the fields appear inside a key.
    _FIELDS = ("version", "conf_type", "os", "accel", "extra")

    @staticmethod
    def parse(val: str) -> "ConfigStr":
        """Split *val* on commas and return the populated ``ConfigStr``.

        Raises:
            ValueError: if *val* does not contain exactly five fields.
        """
        vals = val.split(",")
        # Explicit check instead of ``assert``: asserts vanish under
        # ``python -O`` and ``zip`` would then silently drop/skip fields.
        if len(vals) != len(ConfigStr._FIELDS):
            raise ValueError(
                f"Expected {len(ConfigStr._FIELDS)} comma-separated fields, "
                f"got {len(vals)}: {val!r}"
            )
        rc = ConfigStr()
        for name, field in zip(ConfigStr._FIELDS, vals):
            setattr(rc, name, field)
        return rc

    def __repr__(self) -> str:
        """Return the repr of the instance attribute dictionary."""
        return repr(vars(self))
| 27 | + |
| 28 | + |
class LibTorchInstruction:
    """Download instruction parsed from an HTML snippet.

    ``note`` holds the text lines that are not part of a label/link pair
    (joined with ``<br />``), or ``None`` when nothing is left over.
    ``versions`` maps each label line to the URL on the line that follows
    it, or is ``None`` when the snippet contains no links.
    """
    note: Optional[str]
    versions: Union[Dict[str, str], str, None]

    def __init__(self, note: Optional[str] = None, versions: Union[Dict[str, str], str, None] = None) -> None:
        self.note = note
        self.versions = versions

    @staticmethod
    def parse(val: str) -> "LibTorchInstruction":
        """Parse a ``<br />``-separated snippet into note and versions.

        A line containing ``<a href='URL'>URL</a>`` is a link line; the
        stripped text of the line right before it becomes its key in
        ``versions``. Both lines are removed from the note.

        Raises:
            ValueError: if a line holds more than one link, if the link text
                differs from its href, or if a link sits on the first line
                and therefore has no preceding label line.
        """
        import re
        href_pattern = re.compile(r"<a href='([^']*)'>([^<]*)</a>")
        line_separator = "<br />"
        lines = val.split(line_separator)
        versions = {}
        consumed = set()
        for idx, line in enumerate(lines):
            urls = href_pattern.findall(line)
            if not urls:
                continue
            # There should be only one URL per line, and the value inside
            # and outside of the anchor should match.
            if len(urls) != 1:
                raise ValueError(f"Expected one link per line, got {len(urls)}: {line!r}")
            href, text = urls[0]
            if href != text.rstrip():
                raise ValueError(f"Link text does not match its href: {urls!r}")
            # Without this guard ``lines[idx - 1]`` would wrap around to the
            # LAST line and silently use it as the label.
            if idx == 0:
                raise ValueError(f"Link on the first line has no label line before it: {line!r}")
            versions[lines[idx - 1].strip()] = href
            consumed.update((idx - 1, idx))
        remaining = [line for idx, line in enumerate(lines) if idx not in consumed]
        # Splitting an empty string yields [""]; treat that as "no note".
        if remaining == [""]:
            remaining = []
        return LibTorchInstruction(note=line_separator.join(remaining) if remaining else None,
                                   versions=versions if versions else None)

    def __repr__(self) -> str:
        """Return the repr of the instance attribute dictionary."""
        return repr(vars(self))
| 63 | + |
| 64 | + |
class PyTorchInstruction:
    """Install instruction split into an optional note and a command."""
    note: Optional[str]
    command: Optional[str]

    def __init__(self, note: Optional[str] = None, command: Optional[str] = None) -> None:
        self.note = note
        self.command = command

    @staticmethod
    def parse(val: str) -> "PyTorchInstruction":
        """Separate the install command from the surrounding note text.

        *val* is split on ``<br />``. The command is the last segment if it
        starts with ``pip3 install`` or ``conda install``, otherwise the
        first segment if that one does. The other segments, rejoined with
        ``<br />``, form the note (``None`` when the command is the only
        segment). If neither end is a command, the whole of *val* is the
        note and the command is ``None``.
        """
        separator = "<br />"
        install_prefixes = ("pip3 install", "conda install")
        segments = val.split(separator)
        first, last = segments[0], segments[-1]

        # The trailing segment takes precedence over the leading one.
        if last.startswith(install_prefixes):
            command, others = last, segments[:-1]
        elif first.startswith(install_prefixes):
            command, others = first, segments[1:]
        else:
            return PyTorchInstruction(note=val, command=None)

        note = separator.join(others) if len(segments) > 1 else None
        return PyTorchInstruction(note=note, command=command)

    def __repr__(self) -> str:
        """Return the repr of the instance attribute dictionary."""
        return repr(vars(self))
| 92 | + |
| 93 | + |
class PublishedAccVersion:
    """Parsed instructions grouped by config type, each keyed by ``conf.accel``."""
    libtorch: Dict[str, LibTorchInstruction]
    conda: Dict[str, PyTorchInstruction]
    pip: Dict[str, PyTorchInstruction]

    def __init__(self):
        # Assignment order fixes the key order of ``__dict__`` (and so of repr).
        self.pip = {}
        self.conda = {}
        self.libtorch = {}

    def __repr__(self) -> str:
        """Return the repr of the instance attribute dictionary."""
        return repr(vars(self))

    def add_instruction(self, conf: ConfigStr, val: str) -> None:
        """Parse *val* and store it under ``conf.accel`` in the matching table.

        Raises:
            RuntimeError: if ``conf.conf_type`` is not ``libtorch``,
                ``conda`` or ``pip``.
        """
        kind = conf.conf_type
        if kind == "libtorch":
            self.libtorch[conf.accel] = LibTorchInstruction.parse(val)
            return
        if kind in ("conda", "pip"):
            # conda and pip share the parser; only the target table differs.
            getattr(self, kind)[conf.accel] = PyTorchInstruction.parse(val)
            return
        raise RuntimeError(f"Unknown config type {conf.conf_type}")
| 116 | + |
| 117 | + |
class PublishedOSVersion:
    """One ``PublishedAccVersion`` per operating system: linux, macos, windows."""
    linux: PublishedAccVersion
    macos: PublishedAccVersion
    windows: PublishedAccVersion

    def __init__(self):
        # Assignment order fixes the key order of ``__dict__`` (and so of repr).
        self.linux = PublishedAccVersion()
        self.macos = PublishedAccVersion()
        self.windows = PublishedAccVersion()

    def add_instruction(self, conf: ConfigStr, val: str) -> None:
        """Forward the instruction to the container selected by ``conf.os``.

        Raises:
            RuntimeError: if ``conf.os`` is not ``linux``, ``macos`` or
                ``windows``.
        """
        target_os = conf.os
        if target_os not in ("linux", "macos", "windows"):
            raise RuntimeError(f"Unknown os type {conf.os}")
        # The attribute names are exactly the accepted os strings.
        getattr(self, target_os).add_instruction(conf, val)

    def __repr__(self) -> str:
        """Return the repr of the instance attribute dictionary."""
        return repr(vars(self))
| 140 | + |
| 141 | + |
class PublishedVersions:
    """Published instructions for every version, keyed by version string.

    ``latest_stable`` and ``latest_lts`` are the keys that replace the
    ``stable`` and ``lts`` placeholder keys once parsing is done.
    """
    latest_stable: str
    latest_lts: str
    # Annotation only: a class-level ``dict()`` default would be one object
    # shared by every instance. The real dict is created in ``__init__``.
    versions: Dict[str, "PublishedOSVersion"]

    def __init__(self, latest_stable: str, latest_lts: str) -> None:
        self.latest_stable = latest_stable
        self.latest_lts = latest_lts
        self.versions = dict()

    def parse_objects(self, objects: Dict[str, str]) -> None:
        """Populate ``versions`` from a mapping of config key to raw text.

        Each key is parsed with ``ConfigStr.parse`` and its value is handed
        to the ``PublishedOSVersion`` of the key's version, which is created
        on first use. Afterwards the ``stable`` and ``lts`` entries, if
        present, are re-keyed to ``latest_stable`` and ``latest_lts``.
        """
        for key, val in objects.items():
            conf = ConfigStr.parse(key)
            if conf.version not in self.versions:
                self.versions[conf.version] = PublishedOSVersion()
            self.versions[conf.version].add_instruction(conf, val)
        # Replace the placeholder keys with concrete version names; stable
        # is handled before lts, exactly as listed here.
        for alias, concrete in (("stable", self.latest_stable), ("lts", self.latest_lts)):
            if alias in self.versions:
                self.versions[concrete] = self.versions.pop(alias)
| 162 | + |
| 163 | + |
def get_objects(commit_hash: str = "0ba2a203045bc94d165d52e56c87ceaa463f4284",
                timeout: float = 30.0) -> Dict[str, str]:
    """
    Extract install commands as they are currently hardcoded
    in pytorch.github.io/assets/quick-start-module.js

    Downloads the file at revision *commit_hash* from
    raw.githubusercontent.com, locates the ``var object = {`` ... ``};``
    literal and returns it parsed as JSON.

    Args:
        commit_hash: revision of pytorch/pytorch.github.io to read.
        timeout: seconds passed to ``urlopen`` so that a stalled connection
            cannot block forever.

    Raises:
        RuntimeError: if the object literal cannot be located.
    """
    raw_base = "raw.githubusercontent.com"
    obj_start = "var object = {"
    obj_end = "};"
    url = (
        f"https://{raw_base}/pytorch/pytorch.github.io/"
        f"{commit_hash}/assets/quick-start-module.js"
    )
    with urlopen(url, timeout=timeout) as response:
        # NOTE(review): latin1 never fails to decode but would garble any
        # non-ASCII UTF-8 content - confirm the file is pure ASCII.
        raw_data = response.read().decode("latin1")
    # Validate each marker before its index is reused as a search bound.
    start_idx = raw_data.find(obj_start)
    if start_idx < 0:
        raise RuntimeError("Unexpected raw_data")
    end_idx = raw_data.find(obj_end, start_idx)
    if end_idx < 0:
        raise RuntimeError("Unexpected raw_data")
    # Adjust start and end indexes: begin at the opening brace and stop at
    # the last double quote before the closing marker.
    start_idx = raw_data.find("{", start_idx, end_idx)
    end_idx = raw_data.rfind('"', start_idx, end_idx)
    if start_idx < 0 or end_idx < 0:
        raise RuntimeError("Unexpected raw_data")
    # Re-append the quote and brace; whatever sat between the last quote and
    # the closing brace (presumably a trailing comma) is dropped this way.
    return json.loads(raw_data[start_idx:end_idx] + '"}')
| 182 | + |
| 183 | + |
def dump_to_file(fname: str, o: Any) -> None:
    """Write *o* to the file *fname* as JSON indented by two spaces.

    Any value the json module cannot encode natively is replaced by its
    instance ``__dict__``.
    """
    def attributes_of(obj):
        # Fallback for non-JSON types: serialize the attribute dictionary.
        return obj.__dict__

    with open(fname, "w") as out:
        json.dump(o, out, indent=2, default=attributes_of)
| 191 | + |
| 192 | + |
def main() -> None:
    """Fetch the install objects, parse them and write published_versions.json."""
    published = PublishedVersions(latest_stable="1.9.0", latest_lts="lts-1.8.2")
    published.parse_objects(get_objects())
    dump_to_file("published_versions.json", published)


if __name__ == "__main__":
    main()
0 commit comments