"""Write CSV listings of core ATen operators and prims operators.

Running this module as a script creates ``build/ir/aten_ops.csv`` and
``build/ir/prims_ops.csv`` (relative to the current working directory). Each
file has an ``Operator,Schema`` header followed by one fully quoted row per
operator.
"""

import csv
import os
from collections import OrderedDict
from pathlib import Path

import torch
import torch._prims as prims
from torchgen.gen import parse_native_yaml


# Directory three levels above the one containing this file.
ROOT = Path(__file__).absolute().parents[3]
NATIVE_FUNCTION_YAML_PATH = ROOT / "aten/src/ATen/native/native_functions.yaml"
TAGS_YAML_PATH = ROOT / "aten/src/ATen/native/tags.yaml"

BUILD_DIR = "build/ir"
ATEN_OPS_CSV_FILE = "aten_ops.csv"
PRIMS_OPS_CSV_FILE = "prims_ops.csv"


def get_aten():
    """Return ``(name, schema)`` pairs for every native function tagged "core".

    The native function and tag YAML files are parsed with
    ``parse_native_yaml``. Names are prefixed with ``aten.`` and the pairs are
    returned sorted by the unprefixed operator name. Every ``*`` in a schema
    is escaped as ``\\*`` (presumably so the text survives reStructuredText
    rendering -- confirm against the consumer of the CSV).
    """
    parsed_yaml = parse_native_yaml(NATIVE_FUNCTION_YAML_PATH, TAGS_YAML_PATH)

    # Keyed by operator name; a later entry with the same name replaces an
    # earlier one.
    aten_ops = OrderedDict()
    for function in parsed_yaml.native_functions:
        if "core" in function.tags:
            aten_ops[str(function.func.name)] = function

    return [
        (f"aten.{key}", str(op.func).replace("*", r"\*"))
        for key, op in sorted(aten_ops.items())
    ]


def get_prims():
    """Return ``(name, schema)`` pairs for the operators exported by prims.

    Only names in ``prims.__all__`` that resolve to a
    ``torch._ops.OpOverload`` are included; everything else is skipped. The
    ``.default`` suffix is removed from the operator's string form, and every
    ``*`` in the schema is escaped as ``\\*``.
    """
    op_schema_pairs = []
    for exported_name in prims.__all__:
        op_overload = getattr(prims, exported_name, None)
        if not isinstance(op_overload, torch._ops.OpOverload):
            continue

        op_name = str(op_overload).replace(".default", "")
        # The schema is read from the overload packet, not from the overload.
        schema = op_overload.overloadpacket.schema.replace("*", r"\*")
        op_schema_pairs.append((op_name, schema))

    return op_schema_pairs


def _write_ops_csv(path, op_schema_pairs):
    """Write the ``Operator,Schema`` table for *op_schema_pairs* to *path*.

    Operator names are wrapped in double backticks. Rows are emitted through
    ``csv.writer`` with every field quoted, so a double quote inside a schema
    (for example a string default value) is escaped by doubling instead of
    producing a malformed row.
    """
    # newline="" is what the csv module expects for files it writes to; the
    # writer then controls the line ending itself via lineterminator.
    with open(path, "w", newline="", encoding="utf-8") as f:
        # The header stays unquoted, exactly as it has always been written.
        f.write("Operator,Schema\n")
        writer = csv.writer(f, quoting=csv.QUOTE_ALL, lineterminator="\n")
        writer.writerows(
            (f"``{name}``", schema) for name, schema in op_schema_pairs
        )


def main():
    """Collect both operator lists and write them as CSV files in BUILD_DIR."""
    aten_ops_list = get_aten()
    prims_ops_list = get_prims()

    os.makedirs(BUILD_DIR, exist_ok=True)

    _write_ops_csv(os.path.join(BUILD_DIR, ATEN_OPS_CSV_FILE), aten_ops_list)
    _write_ops_csv(os.path.join(BUILD_DIR, PRIMS_OPS_CSV_FILE), prims_ops_list)


if __name__ == "__main__":
    main()